Skip to content

Commit f9826c0

Browse files
authored
Add optional rows parameter to unpack_ragged and apply_ragged (#272)
* Add optional rows parameter to unpack_ragged and apply_ragged * Fix conflict resolve from rowsize_to_index
1 parent afbf318 commit f9826c0

File tree

2 files changed

+71
-10
lines changed

2 files changed

+71
-10
lines changed

clouddrift/analysis.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def apply_ragged(
1818
arrays: list[np.ndarray],
1919
rowsize: list[int],
2020
*args: tuple,
21+
rows: Union[int, Iterable[int]] = None,
2122
executor: futures.Executor = futures.ThreadPoolExecutor(max_workers=None),
2223
**kwargs: dict,
2324
) -> Union[tuple[np.ndarray], np.ndarray]:
@@ -45,6 +46,9 @@ def apply_ragged(
4546
List of integers specifying the number of data points in each row.
4647
*args : tuple
4748
Additional arguments to pass to ``func``.
49+
rows : int or Iterable[int], optional
50+
The row(s) of the ragged array to apply ``func`` to. If ``rows`` is
51+
``None`` (default), then ``func`` will be applied to all rows.
4852
executor : concurrent.futures.Executor, optional
4953
Executor to use for concurrent execution. Default is ``ThreadPoolExecutor``
5054
with the default number of ``max_workers``.
@@ -72,6 +76,15 @@ def apply_ragged(
7276
array([1., 1., 2., 2., 2., 3., 3., 3., 3.]),
7377
array([1., 1., 1., 1., 1., 1., 1., 1., 1.]))
7478
79+
To apply ``func`` to only a subset of rows, use the ``rows`` argument:
80+
81+
>>> u1, v1 = apply_ragged(velocity_from_position, [x, y, t], rowsize, rows=0, coord_system="cartesian")
82+
array([1., 1.]),
83+
array([1., 1.]))
84+
>>> u1, v1 = apply_ragged(velocity_from_position, [x, y, t], rowsize, rows=[0, 1], coord_system="cartesian")
85+
array([1., 1., 2., 2., 2.]),
86+
array([1., 1., 1., 1., 1.]))
87+
7588
Raises
7689
------
7790
ValueError
@@ -88,7 +101,7 @@ def apply_ragged(
88101
raise ValueError("The sum of rowsize must equal the length of arr.")
89102

90103
# split the array(s) into trajectories
91-
arrays = [unpack_ragged(arr, rowsize) for arr in arrays]
104+
arrays = [unpack_ragged(arr, rowsize, rows) for arr in arrays]
92105
iter = [[arrays[i][j] for i in range(len(arrays))] for j in range(len(arrays[0]))]
93106

94107
# parallel execution
@@ -1088,7 +1101,9 @@ def subset(
10881101

10891102

10901103
def unpack_ragged(
1091-
ragged_array: np.ndarray, rowsize: np.ndarray[int]
1104+
ragged_array: np.ndarray,
1105+
rowsize: np.ndarray[int],
1106+
rows: Union[int, Iterable[int]] = None,
10921107
) -> list[np.ndarray]:
10931108
"""Unpack a ragged array into a list of regular arrays.
10941109
@@ -1103,6 +1118,8 @@ def unpack_ragged(
11031118
rowsize : array-like
11041119
An array of integers whose values is the size of each row in the ragged
11051120
array
1121+
rows : int or Iterable[int], optional
1122+
A row or list of rows to unpack. Default is None, which unpacks all rows.
11061123
11071124
Returns
11081125
-------
@@ -1119,6 +1136,8 @@ def unpack_ragged(
11191136
11201137
lon = unpack_ragged(ds.lon, ds["rowsize"]) # return a list[xr.DataArray] (slower)
11211138
lon = unpack_ragged(ds.lon.values, ds["rowsize"]) # return a list[np.ndarray] (faster)
1139+
first_lon = unpack_ragged(ds.lon.values, ds["rowsize"], rows=0) # return only the first row
1140+
first_two_lons = unpack_ragged(ds.lon.values, ds["rowsize"], rows=[0, 1]) # return first two rows
11221141
11231142
Looping over trajectories in a ragged Xarray Dataset to compute velocities
11241143
for each:
@@ -1133,4 +1152,10 @@ def unpack_ragged(
11331152
u, v = velocity_from_position(lon, lat, time)
11341153
"""
11351154
indices = rowsize_to_index(rowsize)
1136-
return [ragged_array[indices[n] : indices[n + 1]] for n in range(indices.size - 1)]
1155+
1156+
if rows is None:
1157+
rows = range(indices.size - 1)
1158+
if isinstance(rows, int):
1159+
rows = [rows]
1160+
1161+
return [ragged_array[indices[n] : indices[n + 1]] for n in rows]

tests/analysis_tests.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,23 @@ def test_velocity_ndarray(self):
739739
)
740740
)
741741

742+
def test_with_rows(self):
743+
y = apply_ragged(
744+
lambda x: x**2,
745+
np.array([1, 2, 3, 4]),
746+
[2, 2],
747+
rows=0,
748+
)
749+
self.assertTrue(np.all(y == np.array([1, 4])))
750+
751+
y = apply_ragged(
752+
lambda x: x**2,
753+
np.array([1, 2, 3, 4]),
754+
[2, 2],
755+
rows=[0, 1],
756+
)
757+
self.assertTrue(np.all(y == np.array([1, 4, 9, 16])))
758+
742759
def test_velocity_dataarray(self):
743760
for executor in [futures.ThreadPoolExecutor(), futures.ProcessPoolExecutor()]:
744761
u, v = apply_ragged(
@@ -891,13 +908,32 @@ def test_unpack_ragged(self):
891908
np.all([lon[n].size == ds["rowsize"][n] for n in range(len(lon))])
892909
)
893910

911+
def test_unpack_ragged_rows(self):
912+
ds = sample_ragged_array().to_xarray()
913+
x = ds.lon.values
914+
rowsize = ds.rowsize.values
894915

895-
class rowsize_to_index_tests(unittest.TestCase):
896-
def test_rowsize_to_index(self):
897-
rowsize = [2, 3, 4]
898-
expected = np.array([0, 2, 5, 9])
899-
self.assertTrue(np.all(rowsize_to_index(rowsize) == expected))
900-
self.assertTrue(np.all(rowsize_to_index(np.array(rowsize)) == expected))
901916
self.assertTrue(
902-
np.all(rowsize_to_index(xr.DataArray(data=rowsize)) == expected)
917+
all(
918+
np.array_equal(a, b)
919+
for a, b in zip(
920+
unpack_ragged(x, rowsize, None), unpack_ragged(x, rowsize)
921+
)
922+
)
923+
)
924+
self.assertTrue(
925+
all(
926+
np.array_equal(a, b)
927+
for a, b in zip(
928+
unpack_ragged(x, rowsize, 0), unpack_ragged(x, rowsize)[:1]
929+
)
930+
)
931+
)
932+
self.assertTrue(
933+
all(
934+
np.array_equal(a, b)
935+
for a, b in zip(
936+
unpack_ragged(x, rowsize, [0, 1]), unpack_ragged(x, rowsize)[:2]
937+
)
938+
)
903939
)

0 commit comments

Comments
 (0)