@@ -545,20 +545,21 @@ def subset(
545545 obs_dim_name : str = "obs" ,
546546 full_trajectories = False ,
547547) -> xr .Dataset :
548- """Subset a ragged array dataset as a function of one or more criteria.
548+ """Subset a ragged array xarray dataset as a function of one or more criteria.
549549 The criteria are passed with a dictionary, where a dictionary key
550550 is a variable to subset and the associated dictionary value is either a range
551551 (valuemin, valuemax), a list [value1, value2, valueN], a single value, or a
552- masking function applied to every row of the ragged array using ``apply_ragged`` .
552+ masking function applied to any variable of the dataset .
553553
554554 This function needs to know the names of the dimensions of the ragged array dataset
555555 (`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
556- Default values are provided for these arguments (see below), but they can be changed if needed.
556+ Default values corresponds to the clouddrift convention ("traj", "obs", and "rowsize") but should
557+ be changed as needed.
557558
558559 Parameters
559560 ----------
560561 ds : xr.Dataset
561- Dataset stored as ragged arrays.
562+ Xarray dataset composed of ragged arrays.
562563 criteria : dict
563564 Dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset.
564565 id_var_name : str, optional
@@ -570,19 +571,19 @@ def subset(
570571 obs_dim_name : str, optional
571572 Name of the observation dimension (default is "obs").
572573 full_trajectories : bool, optional
573- If True, it returns the complete trajectories (rows) where at least one observation
574- matches the criteria, rather than just the segments where the criteria are satisfied.
575- Default is False .
574+ If True, the function returns complete rows (trajectories) for which the criteria
575+ are matched at least once. Default is False which means that only segments matching the criteria
576+ are returned when filtering along the observation dimension .
576577
577578 Returns
578579 -------
579580 xr.Dataset
580- subset Dataset matching the criterion(a)
581+ Subset xarray dataset matching the criterion(a).
581582
582583 Examples
583584 --------
584- Criteria are combined on any data or metadata variables part of the Dataset.
585- The following examples are based on NOAA GDP datasets which can be accessed with the
585+ Criteria are combined on any data (with dimension "obs") or metadata (with dimension "traj") variables
586+ part of the Dataset. The following examples are based on NOAA GDP datasets which can be accessed with the
586587 ``clouddrift.datasets`` module.
587588
588589 Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
@@ -607,7 +608,7 @@ def subset(
607608
608609 >>> subset(ds, {"rowsize": (0, 1000)})
609610
610- Retrieve specific drifters from their IDs:
611+ Retrieve specific drifters using their IDs:
611612
612613 >>> subset(ds, {"id": [2578, 2582, 2583]})
613614
@@ -637,10 +638,32 @@ def subset(
637638 >>> func = (lambda arr: ((arr - arr[0]) % 2) == 0)
638639 >>> subset(ds, {"time": func})
639640
641+ The filtering function can accept several input variables passed as a tuple. For example, retrieve
642+ drifters released in the Mediterranean Sea, but exclude those released in the Bay of Biscay and the Black Sea:
643+
644+ >>> def mediterranean_mask(lon: xr.DataArray, lat: xr.DataArray) -> xr.DataArray:
645+ >>> # Mediterranean Sea bounding box
646+ >>> in_med = np.logical_and(-6.0327 <= lon, np.logical_and(lon <= 36.2173,
647+ >>> np.logical_and(30.2639 <= lat, lat <= 45.7833)))
648+ >>> # Bay of Biscay
649+ >>> in_biscay = np.logical_and(lon <= -0.1462, lat >= 43.2744)
650+ >>> # Black Sea
651+ >>> in_blacksea = np.logical_and(lon >= 27.4437, lat >= 40.9088)
652+ >>> return np.logical_and(in_med, np.logical_not(np.logical_or(in_biscay, in_blacksea)))
653+ >>> subset(ds, {("start_lon", "start_lat"): mediterranean_mask})
654+
640655 Raises
641656 ------
642657 ValueError
643- If one of the variable in a criterion is not found in the Dataset
658+ If one of the variable in a criterion is not found in the Dataset.
659+ TypeError
660+ If one of the `criteria` key is a tuple while its associated value is not a `Callable` criterion.
661+ TypeError
662+ If variables of a `criterion` key associated to a `Callable` do not share the same dimension.
663+
664+ See Also
665+ --------
666+ :func:`apply_ragged`
644667 """
645668 mask_traj = xr .DataArray (
646669 data = np .ones (ds .sizes [traj_dim_name ], dtype = "bool" ), dims = [traj_dim_name ]
@@ -650,19 +673,30 @@ def subset(
650673 )
651674
652675 for key in criteria .keys ():
653- if key in ds or key in ds .dims :
654- if ds [key ].dims == (traj_dim_name ,):
676+ if np .all (np .isin (key , ds .variables ) | np .isin (key , ds .dims )):
677+ if isinstance (key , tuple ):
678+ criterion = [ds [k ] for k in key ]
679+ if not all (c .dims == criterion [0 ].dims for c in criterion ):
680+ raise TypeError (
681+ "Variables passed to the Callable must share the same dimension."
682+ )
683+ criterion_dims = criterion [0 ].dims
684+ else :
685+ criterion = ds [key ]
686+ criterion_dims = criterion .dims
687+
688+ if criterion_dims == (traj_dim_name ,):
655689 mask_traj = np .logical_and (
656690 mask_traj ,
657691 _mask_var (
658- ds [ key ] , criteria [key ], ds [rowsize_var_name ], traj_dim_name
692+ criterion , criteria [key ], ds [rowsize_var_name ], traj_dim_name
659693 ),
660694 )
661- elif ds [ key ]. dims == (obs_dim_name ,):
695+ elif criterion_dims == (obs_dim_name ,):
662696 mask_obs = np .logical_and (
663697 mask_obs ,
664698 _mask_var (
665- ds [ key ] , criteria [key ], ds [rowsize_var_name ], obs_dim_name
699+ criterion , criteria [key ], ds [rowsize_var_name ], obs_dim_name
666700 ),
667701 )
668702 else :
@@ -769,7 +803,7 @@ def unpack(
769803
770804
771805def _mask_var (
772- var : xr .DataArray ,
806+ var : Union [ xr .DataArray , list [ xr . DataArray ]] ,
773807 criterion : Union [tuple , list , np .ndarray , xr .DataArray , bool , float , int , Callable ],
774808 rowsize : xr .DataArray = None ,
775809 dim_name : str = "dim_0" ,
@@ -778,8 +812,8 @@ def _mask_var(
778812
779813 Parameters
780814 ----------
781- var : xr.DataArray
782- DataArray to be subset by the criterion
815+ var : xr.DataArray or list[xr.DataArray]
816+ DataArray or list of DataArray (only applicable if the criterion is a Callable) to be used by the criterion
783817 criterion : array-like or scalar or Callable
784818 The criterion can take four forms:
785819 - tuple: (min, max) defining a range
@@ -815,26 +849,41 @@ def _mask_var(
815849 array([False, True, False, True, False])
816850 Dimensions without coordinates: dim_0
817851
852+ >>> y = xr.DataArray(data=np.arange(0, 5)+2)
853+ >>> rowsize = xr.DataArray(data=[2, 3])
854+ >>> _mask_var([x, y], lambda var1, var2: ((var1 * var2) % 2) == 0, rowsize, "dim_0")
855+ <xarray.DataArray (dim_0: 5)>
856+ array([True, False, True, False, True])
857+ Dimensions without coordinates: dim_0
858+
818859 Returns
819860 -------
820861 mask : xr.DataArray
821862 The mask of the subset of the data matching the criteria
822863 """
864+ if not callable (criterion ) and isinstance (var , list ):
865+ raise TypeError (
866+ "The `var` parameter can be a `list` only if the `criterion` is a `Callable`."
867+ )
868+
823869 if isinstance (criterion , tuple ): # min/max defining range
824870 mask = np .logical_and (var >= criterion [0 ], var <= criterion [1 ])
825871 elif isinstance (criterion , (list , np .ndarray , xr .DataArray )):
826872 # select multiple values
827873 mask = np .isin (var , criterion )
828874 elif callable (criterion ):
829875 # mask directly created by applying `criterion` function
830- if len (var ) == len (rowsize ):
831- mask = criterion (var )
876+ if not isinstance (var , list ):
877+ var = [var ]
878+
879+ if len (var [0 ]) == len (rowsize ):
880+ mask = criterion (* var )
832881 else :
833882 mask = apply_ragged (criterion , var , rowsize )
834883
835884 mask = xr .DataArray (data = mask , dims = [dim_name ]).astype (bool )
836885
837- if not len (var ) == len (mask ):
886+ if not len (var [ 0 ] ) == len (mask ):
838887 raise ValueError (
839888 "The `Callable` function must return a masked array that matches the length of the variable to filter."
840889 )
0 commit comments