@@ -18,6 +18,7 @@ def apply_ragged(
1818 arrays : list [np .ndarray ],
1919 rowsize : list [int ],
2020 * args : tuple ,
21+ rows : Union [int , Iterable [int ]] = None ,
2122 executor : futures .Executor = futures .ThreadPoolExecutor (max_workers = None ),
2223 ** kwargs : dict ,
2324) -> Union [tuple [np .ndarray ], np .ndarray ]:
@@ -45,6 +46,9 @@ def apply_ragged(
4546 List of integers specifying the number of data points in each row.
4647 *args : tuple
4748 Additional arguments to pass to ``func``.
49+ rows : int or Iterable[int], optional
50+ The row(s) of the ragged array to apply ``func`` to. If ``rows`` is
51+ ``None`` (default), then ``func`` will be applied to all rows.
4852 executor : concurrent.futures.Executor, optional
4953 Executor to use for concurrent execution. Default is ``ThreadPoolExecutor``
5054 with the default number of ``max_workers``.
@@ -72,6 +76,15 @@ def apply_ragged(
7276 array([1., 1., 2., 2., 2., 3., 3., 3., 3.]),
7377 array([1., 1., 1., 1., 1., 1., 1., 1., 1.]))
7478
79+ To apply ``func`` to only a subset of rows, use the ``rows`` argument:
80+
81+ >>> u1, v1 = apply_ragged(velocity_from_position, [x, y, t], rowsize, rows=0, coord_system="cartesian")
82+ array([1., 1.]),
83+ array([1., 1.]))
84+ >>> u1, v1 = apply_ragged(velocity_from_position, [x, y, t], rowsize, rows=[0, 1], coord_system="cartesian")
85+ array([1., 1., 2., 2., 2.]),
86+ array([1., 1., 1., 1., 1.]))
87+
7588 Raises
7689 ------
7790 ValueError
@@ -88,7 +101,7 @@ def apply_ragged(
88101 raise ValueError ("The sum of rowsize must equal the length of arr." )
89102
90103 # split the array(s) into trajectories
91- arrays = [unpack_ragged (arr , rowsize ) for arr in arrays ]
104+ arrays = [unpack_ragged (arr , rowsize , rows ) for arr in arrays ]
92105 iter = [[arrays [i ][j ] for i in range (len (arrays ))] for j in range (len (arrays [0 ]))]
93106
94107 # parallel execution
@@ -1088,7 +1101,9 @@ def subset(
10881101
10891102
10901103def unpack_ragged (
1091- ragged_array : np .ndarray , rowsize : np .ndarray [int ]
1104+ ragged_array : np .ndarray ,
1105+ rowsize : np .ndarray [int ],
1106+ rows : Union [int , Iterable [int ]] = None ,
10921107) -> list [np .ndarray ]:
10931108 """Unpack a ragged array into a list of regular arrays.
10941109
@@ -1103,6 +1118,8 @@ def unpack_ragged(
11031118 rowsize : array-like
11041119 An array of integers whose values is the size of each row in the ragged
11051120 array
1121+ rows : int or Iterable[int], optional
1122+ A row or list of rows to unpack. Default is None, which unpacks all rows.
11061123
11071124 Returns
11081125 -------
@@ -1119,6 +1136,8 @@ def unpack_ragged(
11191136
11201137 lon = unpack_ragged(ds.lon, ds["rowsize"]) # return a list[xr.DataArray] (slower)
11211138 lon = unpack_ragged(ds.lon.values, ds["rowsize"]) # return a list[np.ndarray] (faster)
1139+ first_lon = unpack_ragged(ds.lon.values, ds["rowsize"], rows=0) # return only the first row
1140+ first_two_lons = unpack_ragged(ds.lon.values, ds["rowsize"], rows=[0, 1]) # return first two rows
11221141
11231142 Looping over trajectories in a ragged Xarray Dataset to compute velocities
11241143 for each:
@@ -1133,4 +1152,10 @@ def unpack_ragged(
11331152 u, v = velocity_from_position(lon, lat, time)
11341153 """
11351154 indices = rowsize_to_index (rowsize )
1136- return [ragged_array [indices [n ] : indices [n + 1 ]] for n in range (indices .size - 1 )]
1155+
1156+ if rows is None :
1157+ rows = range (indices .size - 1 )
1158+ if isinstance (rows , int ):
1159+ rows = [rows ]
1160+
1161+ return [ragged_array [indices [n ] : indices [n + 1 ]] for n in rows ]
0 commit comments