33"""
44
55import numpy as np
6- from typing import Tuple , Union , Iterable
6+ from typing import Tuple , Union , Iterable , Callable
77import xarray as xr
88import pandas as pd
99from concurrent import futures
@@ -545,21 +545,22 @@ def subset(
545545 obs_dim_name : str = "obs" ,
546546 full_trajectories = False ,
547547) -> xr .Dataset :
548- """Subset the dataset as a function of one or many criteria. The criteria are
549- passed as a dictionary, where a variable to subset is assigned to either a
550- range (valuemin, valuemax), a list [value1, value2, valueN], or a single value.
548+ """Subset a ragged array dataset as a function of one or more criteria.
549+ The criteria are passed with a dictionary, where a dictionary key
550+ is a variable to subset and the associated dictionary value is either a range
551+ (valuemin, valuemax), a list [value1, value2, valueN], a single value, or a
552+ masking function applied to every row of the ragged array using ``apply_ragged``.
551553
552- This function relies on specific names of the dataset dimensions and the
553- rowsize variables. The default expected values are listed in the Parameters
554- section, however, if your dataset uses different names for these dimensions
555- and variables, you can specify them using the optional arguments.
554+ This function needs to know the names of the dimensions of the ragged array dataset
555+ (`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
556+ Default values are provided for these arguments (see below), but they can be changed if needed.
556557
557558 Parameters
558559 ----------
559560 ds : xr.Dataset
560- Lagrangian dataset stored in two-dimensional or ragged array format
561+ Dataset stored as ragged arrays
561562 criteria : dict
562- dictionary containing the variables and the ranges/values to subset
563+ dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset
563564 id_var_name : str, optional
564565 Name of the variable containing the ID of the trajectories (default is "id")
565566 rowsize_var_name : str, optional
@@ -569,7 +570,7 @@ def subset(
569570 obs_dim_name : str, optional
570571 Name of the observation dimension (default is "obs")
571572 full_trajectories : bool, optional
572- If True, it returns the complete trajectories where at least one observation
573+ If True, it returns the complete trajectories (rows) where at least one observation
573574 matches the criteria, rather than just the segments where the criteria are satisfied.
574575 Default is False.
575576
@@ -581,7 +582,8 @@ def subset(
581582 Examples
582583 --------
583584 Criteria are combined on any data or metadata variables part of the Dataset.
584- The following examples are based on the GDP dataset.
585+ The following examples are based on NOAA GDP datasets which can be accessed with the
586+ ``clouddrift.datasets`` module.
585587
586588 Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
587589
@@ -629,6 +631,12 @@ def subset(
629631
630632 >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))})
631633
634+ You can also use a function to filter the data. For example, retrieve every other observation
635+ of each trajectory (row):
636+
637+ >>> func = (lambda arr: ((arr - arr[0]) % 2) == 0)
638+ >>> subset(ds, {"time": func})
639+
632640 Raises
633641 ------
634642 ValueError
@@ -644,9 +652,19 @@ def subset(
644652 for key in criteria .keys ():
645653 if key in ds or key in ds .dims :
646654 if ds [key ].dims == (traj_dim_name ,):
647- mask_traj = np .logical_and (mask_traj , _mask_var (ds [key ], criteria [key ]))
655+ mask_traj = np .logical_and (
656+ mask_traj ,
657+ _mask_var (
658+ ds [key ], criteria [key ], ds [rowsize_var_name ], traj_dim_name
659+ ),
660+ )
648661 elif ds [key ].dims == (obs_dim_name ,):
649- mask_obs = np .logical_and (mask_obs , _mask_var (ds [key ], criteria [key ]))
662+ mask_obs = np .logical_and (
663+ mask_obs ,
664+ _mask_var (
665+ ds [key ], criteria [key ], ds [rowsize_var_name ], obs_dim_name
666+ ),
667+ )
650668 else :
651669 raise ValueError (f"Unknown variable '{ key } '." )
652670
@@ -752,19 +770,26 @@ def unpack(
752770
753771def _mask_var (
754772 var : xr .DataArray ,
755- criterion : Union [tuple , list , np .ndarray , xr .DataArray , bool , float , int ],
773+ criterion : Union [tuple , list , np .ndarray , xr .DataArray , bool , float , int , Callable ],
774+ rowsize : xr .DataArray = None ,
775+ dim_name : str = "dim_0" ,
756776) -> xr .DataArray :
757777 """Return the mask of a subset of the data matching a test criterion.
758778
759779 Parameters
760780 ----------
761781 var : xr.DataArray
762782 DataArray to be subset by the criterion
763- criterion : array-like
764- The criterion can take three forms:
783+ criterion : array-like or scalar or Callable
784+ The criterion can take four forms:
765785 - tuple: (min, max) defining a range
766786 - list, np.ndarray, or xr.DataArray: An array-like defining multiples values
767787 - scalar: value defining a single value
788+ - function: a function returning a mask, applied to each row using ``apply_ragged`` (or directly to the whole variable when it has one value per row)
789+ rowsize : xr.DataArray, optional
790+ List of integers specifying the number of data points in each row
791+ dim_name : str, optional
792+ Name of the masked dimension (default is "dim_0")
768793
769794 Examples
770795 --------
@@ -784,6 +809,12 @@ def _mask_var(
784809 array([False, False, False, True, False])
785810 Dimensions without coordinates: dim_0
786811
812+ >>> rowsize = xr.DataArray(data=[2, 3])
813+ >>> _mask_var(x, lambda arr: arr==arr[0]+1, rowsize, "dim_0")
814+ <xarray.DataArray (dim_0: 5)>
815+ array([False, True, False, True, False])
816+ Dimensions without coordinates: dim_0
817+
787818 Returns
788819 -------
789820 mask : xr.DataArray
@@ -794,6 +825,19 @@ def _mask_var(
794825 elif isinstance (criterion , (list , np .ndarray , xr .DataArray )):
795826 # select multiple values
796827 mask = np .isin (var , criterion )
828+ elif callable (criterion ):
829+ # mask directly created by applying `criterion` function
830+ if len (var ) == len (rowsize ):
831+ mask = criterion (var )
832+ else :
833+ mask = apply_ragged (criterion , var , rowsize )
834+
835+ mask = xr .DataArray (data = mask , dims = [dim_name ]).astype (bool )
836+
837+ if not len (var ) == len (mask ):
838+ raise ValueError (
839+ "The `Callable` function must return a masked array that matches the length of the variable to filter."
840+ )
797841 else : # select one specific value
798842 mask = var == criterion
799843 return mask
0 commit comments