Skip to content

Commit f702928

Browse files
Merge pull request #350 from vadmbertr/subset-callable
Allow `Callable` criterion in `ragged.subset`
2 parents 22ea170 + c438051 commit f702928

3 files changed

Lines changed: 83 additions & 18 deletions

File tree

clouddrift/ragged.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import numpy as np
6-
from typing import Tuple, Union, Iterable
6+
from typing import Tuple, Union, Iterable, Callable
77
import xarray as xr
88
import pandas as pd
99
from concurrent import futures
@@ -545,21 +545,22 @@ def subset(
545545
obs_dim_name: str = "obs",
546546
full_trajectories=False,
547547
) -> xr.Dataset:
548-
"""Subset the dataset as a function of one or many criteria. The criteria are
549-
passed as a dictionary, where a variable to subset is assigned to either a
550-
range (valuemin, valuemax), a list [value1, value2, valueN], or a single value.
548+
"""Subset a ragged array dataset as a function of one or more criteria.
549+
The criteria are passed with a dictionary, where a dictionary key
550+
is a variable to subset and the associated dictionary value is either a range
551+
(valuemin, valuemax), a list [value1, value2, valueN], a single value, or a
552+
masking function applied to every row of the ragged array using ``apply_ragged``.
551553
552-
This function relies on specific names of the dataset dimensions and the
553-
rowsize variables. The default expected values are listed in the Parameters
554-
section, however, if your dataset uses different names for these dimensions
555-
and variables, you can specify them using the optional arguments.
554+
This function needs to know the names of the dimensions of the ragged array dataset
555+
(`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
556+
Default values are provided for these arguments (see below), but they can be changed if needed.
556557
557558
Parameters
558559
----------
559560
ds : xr.Dataset
560-
Lagrangian dataset stored in two-dimensional or ragged array format
561+
Dataset stored as ragged arrays
561562
criteria : dict
562-
dictionary containing the variables and the ranges/values to subset
563+
dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset
563564
id_var_name : str, optional
564565
Name of the variable containing the ID of the trajectories (default is "id")
565566
rowsize_var_name : str, optional
@@ -569,7 +570,7 @@ def subset(
569570
obs_dim_name : str, optional
570571
Name of the observation dimension (default is "obs")
571572
full_trajectories : bool, optional
572-
If True, it returns the complete trajectories where at least one observation
573+
If True, it returns the complete trajectories (rows) where at least one observation
573574
matches the criteria, rather than just the segments where the criteria are satisfied.
574575
Default is False.
575576
@@ -581,7 +582,8 @@ def subset(
581582
Examples
582583
--------
583584
Criteria are combined on any data or metadata variables part of the Dataset.
584-
The following examples are based on the GDP dataset.
585+
The following examples are based on NOAA GDP datasets which can be accessed with the
586+
``clouddrift.datasets`` module.
585587
586588
Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
587589
@@ -629,6 +631,12 @@ def subset(
629631
630632
>>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))})
631633
634+
You can also use a function to filter the data. For example, retrieve every other observation
635+
of each trajectory (row):
636+
637+
>>> func = (lambda arr: ((arr - arr[0]) % 2) == 0)
638+
>>> subset(ds, {"time": func})
639+
632640
Raises
633641
------
634642
ValueError
@@ -644,9 +652,19 @@ def subset(
644652
for key in criteria.keys():
645653
if key in ds or key in ds.dims:
646654
if ds[key].dims == (traj_dim_name,):
647-
mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key]))
655+
mask_traj = np.logical_and(
656+
mask_traj,
657+
_mask_var(
658+
ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name
659+
),
660+
)
648661
elif ds[key].dims == (obs_dim_name,):
649-
mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key]))
662+
mask_obs = np.logical_and(
663+
mask_obs,
664+
_mask_var(
665+
ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name
666+
),
667+
)
650668
else:
651669
raise ValueError(f"Unknown variable '{key}'.")
652670

@@ -752,19 +770,26 @@ def unpack(
752770

753771
def _mask_var(
754772
var: xr.DataArray,
755-
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int],
773+
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable],
774+
rowsize: xr.DataArray = None,
775+
dim_name: str = "dim_0",
756776
) -> xr.DataArray:
757777
"""Return the mask of a subset of the data matching a test criterion.
758778
759779
Parameters
760780
----------
761781
var : xr.DataArray
762782
DataArray to be subset by the criterion
763-
criterion : array-like
764-
The criterion can take three forms:
783+
criterion : array-like or scalar or Callable
784+
The criterion can take four forms:
765785
- tuple: (min, max) defining a range
766786
- list, np.ndarray, or xr.DataArray: An array-like defining multiples values
767787
- scalar: value defining a single value
788+
- function: a function applied against each trajectory using ``apply_ragged`` and returning a mask
789+
rowsize : xr.DataArray, optional
790+
List of integers specifying the number of data points in each row
791+
dim_name : str, optional
792+
Name of the masked dimension (default is "dim_0")
768793
769794
Examples
770795
--------
@@ -784,6 +809,12 @@ def _mask_var(
784809
array([False, False, False, True, False])
785810
Dimensions without coordinates: dim_0
786811
812+
>>> rowsize = xr.DataArray(data=[2, 3])
813+
>>> _mask_var(x, lambda arr: arr==arr[0]+1, rowsize, "dim_0")
814+
<xarray.DataArray (dim_0: 5)>
815+
array([False, True, False, True, False])
816+
Dimensions without coordinates: dim_0
817+
787818
Returns
788819
-------
789820
mask : xr.DataArray
@@ -794,6 +825,19 @@ def _mask_var(
794825
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
795826
# select multiple values
796827
mask = np.isin(var, criterion)
828+
elif callable(criterion):
829+
# mask directly created by applying `criterion` function
830+
if len(var) == len(rowsize):
831+
mask = criterion(var)
832+
else:
833+
mask = apply_ragged(criterion, var, rowsize)
834+
835+
mask = xr.DataArray(data=mask, dims=[dim_name]).astype(bool)
836+
837+
if not len(var) == len(mask):
838+
raise ValueError(
839+
"The `Callable` function must return a masked array that matches the length of the variable to filter."
840+
)
797841
else: # select one specific value
798842
mask = var == criterion
799843
return mask

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "clouddrift"
7-
version = "0.29.0"
7+
version = "0.30.0"
88
authors = [
99
{ name="Shane Elipot", email="[email protected]" },
1010
{ name="Philippe Miron", email="[email protected]" },

tests/ragged_tests.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,27 @@ def test_subset_by_rows(self):
700700
self.assertTrue(all(ds_sub["id"] == [1, 2]))
701701
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))
702702

703+
def test_subset_callable(self):
704+
func = (
705+
lambda arr: ((arr - arr[0]) % 2) == 0
706+
) # test keeping obs every two time intervals
707+
ds_sub = subset(self.ds, {"time": func})
708+
self.assertTrue(all(ds_sub["id"] == [1, 3, 2]))
709+
self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))
710+
self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4]))
711+
712+
func = lambda arr: arr <= 2 # keep id larger or equal to 2
713+
ds_sub = subset(self.ds, {"id": func})
714+
self.assertTrue(all(ds_sub["id"] == [1, 2]))
715+
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))
716+
717+
def test_subset_callable_wrong_dim(self):
718+
func = lambda arr: [arr, arr] # returns 2 values per element
719+
with self.assertRaises(ValueError):
720+
subset(self.ds, {"time": func})
721+
with self.assertRaises(ValueError):
722+
subset(self.ds, {"id": func})
723+
703724

704725
class unpack_tests(unittest.TestCase):
705726
def test_unpack(self):

0 commit comments

Comments
 (0)