Skip to content

Commit 3c902e8

Browse files
philippemironPhilippe Mironmilancurcic
authored
apply ragged (#128)
* initial commit * initial commit * lint * lint * apply changes * fix merge in the comment * missing spaces * lint * added max_workers * forgot max_workers in the func.. * Docstring and type hints * Expand tests * Reorder arguments * Test with additional args * Test for args and kwargs each * Fix and test for passing arrs as a scalar DataArray * Rename arrs -> arrays (too similar to args) --------- Co-authored-by: Philippe Miron <[email protected]> Co-authored-by: milancurcic <[email protected]>
1 parent f095843 commit 3c902e8

File tree

2 files changed

+153
-2
lines changed

2 files changed

+153
-2
lines changed

clouddrift/analysis.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,93 @@
11
import numpy as np
2-
from typing import Optional, Tuple
2+
from typing import Optional, Tuple, Union
33
import xarray as xr
4+
from concurrent import futures
45
from clouddrift.haversine import distance, bearing
6+
from clouddrift.dataformat import unpack_ragged
7+
8+
9+
def apply_ragged(
10+
func: callable,
11+
arrays: list[np.ndarray],
12+
rowsize: list[int],
13+
*args: tuple,
14+
max_workers: int = None,
15+
**kwargs: dict,
16+
) -> Union[tuple[np.ndarray], np.ndarray]:
17+
"""Apply a function to a ragged array.
18+
19+
The function ``func`` will be applied to each contiguous row of ``arrays`` as
20+
indicated by row sizes ``rowsize``. The output of ``func`` will be
21+
concatenated into a single ragged array.
22+
23+
This function uses ``concurrent.futures.ThreadPoolExecutor`` to run ``func``
24+
in multiple threads. The number of threads can be controlled by the
25+
``max_workers`` argument, which is passed down to ``ThreadPoolExecutor``.
26+
27+
Parameters
28+
----------
29+
func : callable
30+
Function to apply to each row of each ragged array in ``arrays``.
31+
arrays : list[np.ndarray] or np.ndarray
32+
An array or a list of arrays to apply ``func`` to.
33+
rowsize : list
34+
List of integers specifying the number of data points in each row.
35+
*args : tuple
36+
Additional arguments to pass to ``func``.
37+
max_workers : int, optional
38+
Number of threads to use. If None, the number of threads will be equal
39+
to the ``max_workers`` default value of ``concurrent.futures.ThreadPoolExecutor``.
40+
**kwargs : dict
41+
Additional keyword arguments to pass to ``func``.
42+
43+
Returns
44+
-------
45+
out : tuple[np.ndarray] or np.ndarray
46+
Output array(s) from ``func``.
47+
48+
Examples
49+
--------
50+
>>> def func(x, y):
51+
... return x + y
52+
>>> x = np.arange(10)
53+
>>> y = np.arange(10, 20)
54+
>>> apply_ragged(func, [x, y], [5, 5])
55+
array([10, 12, 14, 16, 18, 20, 22, 24, 26, 28])
56+
57+
Raises
58+
------
59+
ValueError
60+
If the sum of ``rowsize`` does not equal the length of ``arrays``.
61+
"""
62+
# make sure the arrays is iterable
63+
if type(arrays) not in [list, tuple]:
64+
arrays = [arrays]
65+
# validate rowsize
66+
for arr in arrays:
67+
if not sum(rowsize) == len(arr):
68+
raise ValueError("The sum of rowsize must equal the length of arr.")
69+
70+
# split the array(s) into trajectories
71+
arrays = [unpack_ragged(arr, rowsize) for arr in arrays]
72+
iter = [[arrays[i][j] for i in range(len(arrays))] for j in range(len(arrays[0]))]
73+
74+
# combine other arguments
75+
for arg in iter:
76+
if args:
77+
arg.append(*args)
78+
79+
# parallel execution
80+
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
81+
res = executor.map(lambda x: func(*x, **kwargs), iter)
82+
# concatenate the outputs
83+
res = list(res)
84+
if isinstance(res[0], tuple): # more than 1 parameter
85+
outputs = []
86+
for i in range(len(res[0])):
87+
outputs.append(np.concatenate([r[i] for r in res]))
88+
return tuple(outputs)
89+
else:
90+
return np.concatenate(res)
591

692

793
def segment(
@@ -60,6 +146,7 @@ def segment(
60146
>>> segment(x, 0.5, rowsize=segment(x, -0.5))
61147
array([2, 2, 2, 2])
62148
"""
149+
63150
if rowsize is None:
64151
if tolerance >= 0:
65152
exceeds_tolerance = np.diff(x) > tolerance

tests/analysis_tests.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from clouddrift.analysis import segment, velocity_from_position
1+
from clouddrift.analysis import segment, velocity_from_position, apply_ragged
22
from clouddrift.haversine import EARTH_RADIUS_METERS
33
import unittest
44
import numpy as np
@@ -132,3 +132,67 @@ def test_time_axis(self):
132132
self.assertTrue(np.all(vf == expected_vf))
133133
self.assertTrue(np.all(uf.shape == expected_uf.shape))
134134
self.assertTrue(np.all(vf.shape == expected_vf.shape))
135+
136+
137+
class apply_ragged_tests(unittest.TestCase):
138+
def setUp(self):
139+
self.rowsize = [2, 3, 4]
140+
self.x = np.array([1, 2, 10, 12, 14, 30, 33, 36, 39])
141+
self.y = np.arange(0, len(self.x))
142+
self.t = np.array([1, 2, 1, 2, 3, 1, 2, 3, 4])
143+
144+
def test_simple(self):
145+
y = apply_ragged(lambda x: x**2, np.array([1, 2, 3, 4]), [2, 2])
146+
self.assertTrue(np.all(y == np.array([1, 4, 9, 16])))
147+
148+
def test_simple_dataarray(self):
149+
y = apply_ragged(
150+
lambda x: x**2,
151+
xr.DataArray(data=[1, 2, 3, 4], coords={"obs": [1, 2, 3, 4]}),
152+
[2, 2],
153+
)
154+
self.assertTrue(np.all(y == np.array([1, 4, 9, 16])))
155+
156+
def test_simple_with_args(self):
157+
y = apply_ragged(lambda x, p: x**p, np.array([1, 2, 3, 4]), [2, 2], 2)
158+
self.assertTrue(np.all(y == np.array([1, 4, 9, 16])))
159+
160+
def test_simple_with_kwargs(self):
161+
y = apply_ragged(lambda x, p: x**p, np.array([1, 2, 3, 4]), [2, 2], p=2)
162+
self.assertTrue(np.all(y == np.array([1, 4, 9, 16])))
163+
164+
def test_velocity_ndarray(self):
165+
u, v = apply_ragged(
166+
velocity_from_position,
167+
[self.x, self.y, self.t],
168+
self.rowsize,
169+
coord_system="cartesian",
170+
)
171+
self.assertIsNone(
172+
np.testing.assert_allclose(u, [1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0])
173+
)
174+
self.assertIsNone(
175+
np.testing.assert_allclose(v, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
176+
)
177+
178+
def test_velocity_dataarray(self):
179+
u, v = apply_ragged(
180+
velocity_from_position,
181+
[
182+
xr.DataArray(data=self.x),
183+
xr.DataArray(data=self.y),
184+
xr.DataArray(data=self.t),
185+
],
186+
xr.DataArray(data=self.rowsize),
187+
coord_system="cartesian",
188+
)
189+
self.assertIsNone(
190+
np.testing.assert_allclose(u, [1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0])
191+
)
192+
self.assertIsNone(
193+
np.testing.assert_allclose(v, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
194+
)
195+
196+
def test_bad_rowsize_raises(self):
197+
with self.assertRaises(ValueError):
198+
y = apply_ragged(lambda x: x**2, np.array([1, 2, 3, 4]), [2])

0 commit comments

Comments
 (0)