Skip to content

Commit 33103dc

Browse files
vadmbertrPhilippe Mironselipotkevinsantana11
authored
Subset callable using multiple variables (#361)
* sevaral variables can be passed to the callable returning the mask * add some examples. update typings. * add raise TypeError * consistent typing. add unit tests * minor changes * docstring edits * refactor and add tests. revert np.all. small updates to the docstring * docstring edits * Added one Exception to validate all variable share the same dimension * bump version * Update pyproject.toml Add myself as an author for conda forge invite * black v24 --------- Co-authored-by: Vadim Bertrand <vadim.bertrand@univ-grenoble-alpes.fr> Co-authored-by: Philippe Miron <philippe.miron@dtn.com> Co-authored-by: Shane Elipot <selipot@miami.edu> Co-authored-by: Kevin Santana <kevinsantana11@gmail.com>
1 parent 2dd835c commit 33103dc

File tree

3 files changed

+104
-25
lines changed

3 files changed

+104
-25
lines changed

clouddrift/ragged.py

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -545,20 +545,21 @@ def subset(
545545
obs_dim_name: str = "obs",
546546
full_trajectories=False,
547547
) -> xr.Dataset:
548-
"""Subset a ragged array dataset as a function of one or more criteria.
548+
"""Subset a ragged array xarray dataset as a function of one or more criteria.
549549
The criteria are passed with a dictionary, where a dictionary key
550550
is a variable to subset and the associated dictionary value is either a range
551551
(valuemin, valuemax), a list [value1, value2, valueN], a single value, or a
552-
masking function applied to every row of the ragged array using ``apply_ragged``.
552+
masking function applied to any variable of the dataset.
553553
554554
This function needs to know the names of the dimensions of the ragged array dataset
555555
(`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
556-
Default values are provided for these arguments (see below), but they can be changed if needed.
556+
Default values corresponds to the clouddrift convention ("traj", "obs", and "rowsize") but should
557+
be changed as needed.
557558
558559
Parameters
559560
----------
560561
ds : xr.Dataset
561-
Dataset stored as ragged arrays.
562+
Xarray dataset composed of ragged arrays.
562563
criteria : dict
563564
Dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset.
564565
id_var_name : str, optional
@@ -570,19 +571,19 @@ def subset(
570571
obs_dim_name : str, optional
571572
Name of the observation dimension (default is "obs").
572573
full_trajectories : bool, optional
573-
If True, it returns the complete trajectories (rows) where at least one observation
574-
matches the criteria, rather than just the segments where the criteria are satisfied.
575-
Default is False.
574+
If True, the function returns complete rows (trajectories) for which the criteria
575+
are matched at least once. Default is False which means that only segments matching the criteria
576+
are returned when filtering along the observation dimension.
576577
577578
Returns
578579
-------
579580
xr.Dataset
580-
subset Dataset matching the criterion(a)
581+
Subset xarray dataset matching the criterion(a).
581582
582583
Examples
583584
--------
584-
Criteria are combined on any data or metadata variables part of the Dataset.
585-
The following examples are based on NOAA GDP datasets which can be accessed with the
585+
Criteria are combined on any data (with dimension "obs") or metadata (with dimension "traj") variables
586+
part of the Dataset. The following examples are based on NOAA GDP datasets which can be accessed with the
586587
``clouddrift.datasets`` module.
587588
588589
Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
@@ -607,7 +608,7 @@ def subset(
607608
608609
>>> subset(ds, {"rowsize": (0, 1000)})
609610
610-
Retrieve specific drifters from their IDs:
611+
Retrieve specific drifters using their IDs:
611612
612613
>>> subset(ds, {"id": [2578, 2582, 2583]})
613614
@@ -637,10 +638,32 @@ def subset(
637638
>>> func = (lambda arr: ((arr - arr[0]) % 2) == 0)
638639
>>> subset(ds, {"time": func})
639640
641+
The filtering function can accept several input variables passed as a tuple. For example, retrieve
642+
drifters released in the Mediterranean Sea, but exclude those released in the Bay of Biscay and the Black Sea:
643+
644+
>>> def mediterranean_mask(lon: xr.DataArray, lat: xr.DataArray) -> xr.DataArray:
645+
>>> # Mediterranean Sea bounding box
646+
>>> in_med = np.logical_and(-6.0327 <= lon, np.logical_and(lon <= 36.2173,
647+
>>> np.logical_and(30.2639 <= lat, lat <= 45.7833)))
648+
>>> # Bay of Biscay
649+
>>> in_biscay = np.logical_and(lon <= -0.1462, lat >= 43.2744)
650+
>>> # Black Sea
651+
>>> in_blacksea = np.logical_and(lon >= 27.4437, lat >= 40.9088)
652+
>>> return np.logical_and(in_med, np.logical_not(np.logical_or(in_biscay, in_blacksea)))
653+
>>> subset(ds, {("start_lon", "start_lat"): mediterranean_mask})
654+
640655
Raises
641656
------
642657
ValueError
643-
If one of the variable in a criterion is not found in the Dataset
658+
If one of the variable in a criterion is not found in the Dataset.
659+
TypeError
660+
If one of the `criteria` key is a tuple while its associated value is not a `Callable` criterion.
661+
TypeError
662+
If variables of a `criterion` key associated to a `Callable` do not share the same dimension.
663+
664+
See Also
665+
--------
666+
:func:`apply_ragged`
644667
"""
645668
mask_traj = xr.DataArray(
646669
data=np.ones(ds.sizes[traj_dim_name], dtype="bool"), dims=[traj_dim_name]
@@ -650,19 +673,30 @@ def subset(
650673
)
651674

652675
for key in criteria.keys():
653-
if key in ds or key in ds.dims:
654-
if ds[key].dims == (traj_dim_name,):
676+
if np.all(np.isin(key, ds.variables) | np.isin(key, ds.dims)):
677+
if isinstance(key, tuple):
678+
criterion = [ds[k] for k in key]
679+
if not all(c.dims == criterion[0].dims for c in criterion):
680+
raise TypeError(
681+
"Variables passed to the Callable must share the same dimension."
682+
)
683+
criterion_dims = criterion[0].dims
684+
else:
685+
criterion = ds[key]
686+
criterion_dims = criterion.dims
687+
688+
if criterion_dims == (traj_dim_name,):
655689
mask_traj = np.logical_and(
656690
mask_traj,
657691
_mask_var(
658-
ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name
692+
criterion, criteria[key], ds[rowsize_var_name], traj_dim_name
659693
),
660694
)
661-
elif ds[key].dims == (obs_dim_name,):
695+
elif criterion_dims == (obs_dim_name,):
662696
mask_obs = np.logical_and(
663697
mask_obs,
664698
_mask_var(
665-
ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name
699+
criterion, criteria[key], ds[rowsize_var_name], obs_dim_name
666700
),
667701
)
668702
else:
@@ -769,7 +803,7 @@ def unpack(
769803

770804

771805
def _mask_var(
772-
var: xr.DataArray,
806+
var: Union[xr.DataArray, list[xr.DataArray]],
773807
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable],
774808
rowsize: xr.DataArray = None,
775809
dim_name: str = "dim_0",
@@ -778,8 +812,8 @@ def _mask_var(
778812
779813
Parameters
780814
----------
781-
var : xr.DataArray
782-
DataArray to be subset by the criterion
815+
var : xr.DataArray or list[xr.DataArray]
816+
DataArray or list of DataArray (only applicable if the criterion is a Callable) to be used by the criterion
783817
criterion : array-like or scalar or Callable
784818
The criterion can take four forms:
785819
- tuple: (min, max) defining a range
@@ -815,26 +849,41 @@ def _mask_var(
815849
array([False, True, False, True, False])
816850
Dimensions without coordinates: dim_0
817851
852+
>>> y = xr.DataArray(data=np.arange(0, 5)+2)
853+
>>> rowsize = xr.DataArray(data=[2, 3])
854+
>>> _mask_var([x, y], lambda var1, var2: ((var1 * var2) % 2) == 0, rowsize, "dim_0")
855+
<xarray.DataArray (dim_0: 5)>
856+
array([True, False, True, False, True])
857+
Dimensions without coordinates: dim_0
858+
818859
Returns
819860
-------
820861
mask : xr.DataArray
821862
The mask of the subset of the data matching the criteria
822863
"""
864+
if not callable(criterion) and isinstance(var, list):
865+
raise TypeError(
866+
"The `var` parameter can be a `list` only if the `criterion` is a `Callable`."
867+
)
868+
823869
if isinstance(criterion, tuple): # min/max defining range
824870
mask = np.logical_and(var >= criterion[0], var <= criterion[1])
825871
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
826872
# select multiple values
827873
mask = np.isin(var, criterion)
828874
elif callable(criterion):
829875
# mask directly created by applying `criterion` function
830-
if len(var) == len(rowsize):
831-
mask = criterion(var)
876+
if not isinstance(var, list):
877+
var = [var]
878+
879+
if len(var[0]) == len(rowsize):
880+
mask = criterion(*var)
832881
else:
833882
mask = apply_ragged(criterion, var, rowsize)
834883

835884
mask = xr.DataArray(data=mask, dims=[dim_name]).astype(bool)
836885

837-
if not len(var) == len(mask):
886+
if not len(var[0]) == len(mask):
838887
raise ValueError(
839888
"The `Callable` function must return a masked array that matches the length of the variable to filter."
840889
)

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "clouddrift"
7-
version = "0.30.0"
7+
version = "0.31.0"
88
authors = [
99
{ name="Shane Elipot", email="selipot@miami.edu" },
1010
{ name="Philippe Miron", email="philippemiron@gmail.com" },
11-
{ name="Milan Curcic", email="mcurcic@miami.edu" }
11+
{ name="Milan Curcic", email="mcurcic@miami.edu" },
12+
{ name="Kevin Santana", email="kevinsantana11@gmail.com" }
1213
]
1314
description = "Accelerating the use of Lagrangian data for atmospheric, oceanic, and climate sciences"
1415
readme = "README.md"

tests/ragged_tests.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,13 +714,42 @@ def test_subset_callable(self):
714714
self.assertTrue(all(ds_sub["id"] == [1, 2]))
715715
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))
716716

717+
def test_subset_callable_tuple(self):
718+
func = lambda arr1, arr2: np.logical_and(
719+
arr1 >= 0, arr2 >= 30
720+
) # keep positive longitude and latitude larger or equal than 30
721+
ds_sub = subset(self.ds, {("lon", "lat"): func})
722+
self.assertTrue(all(ds_sub["id"] == [1, 2]))
723+
self.assertTrue(all(ds_sub["rowsize"] == [2, 2]))
724+
self.assertTrue(all(ds_sub["lon"] >= 0))
725+
self.assertTrue(all(ds_sub["lat"] >= 30))
726+
717727
def test_subset_callable_wrong_dim(self):
718728
func = lambda arr: [arr, arr] # returns 2 values per element
719729
with self.assertRaises(ValueError):
720730
subset(self.ds, {"time": func})
721731
with self.assertRaises(ValueError):
722732
subset(self.ds, {"id": func})
723733

734+
def test_subset_callable_wrong_type(self):
735+
rows = [0, 2] # test extracting first and third rows
736+
with self.assertRaises(TypeError): # passing a tuple when a string is expected
737+
subset(self.ds, {("traj",): rows})
738+
739+
def test_subset_callable_tuple_unknown_var(self):
740+
func = lambda arr1, arr2: np.logical_and(
741+
arr1 >= 0, arr2 >= 30
742+
) # keep positive longitude and latitude larger or equal than 30
743+
with self.assertRaises(ValueError):
744+
subset(self.ds, {("a", "lat"): func})
745+
746+
def test_subset_callable_tuple_not_same_dimension(self):
747+
func = lambda arr1, arr2: np.logical_and(
748+
arr1 >= 0, arr2 >= 30
749+
) # keep positive longitude and latitude larger or equal than 30
750+
with self.assertRaises(TypeError):
751+
subset(self.ds, {("id", "lat"): func})
752+
724753

725754
class unpack_tests(unittest.TestCase):
726755
def test_unpack(self):

0 commit comments

Comments
 (0)