@@ -949,18 +949,28 @@ def subset(ds: xr.Dataset, criteria: dict) -> xr.Dataset:
949949
950950 Examples
951951 --------
952- Criteria are combined on any data or metadata variables part of the Dataset.
952+ Criteria are combined on any data or metadata variables part of the Dataset. The following examples are based on the GDP dataset.
953953
954- To subset between a range of values:
955- >>> subset(ds, {"lon": (min_lon, max_lon), "lat": (min_lat, max_lat)})
956- >>> subset(ds, {"time": (min_time, max_time)})
954+ Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
955+ >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)})
957956
958- To select multiples values:
959- >>> subset(ds, {"ID": [1, 2, 3]})
960-
961- To select a specific value:
957+ Retrieve drogued trajectory segments:
962958 >>> subset(ds, {"drogue_status": True})
963959
960+ Retrieve trajectory segments with temperature higher than 25°C (303.15K):
961+ >>> subset(ds, {"sst": (303.15, np.inf)})
962+
963+ Retrieve specific drifters from their IDs:
964+ >>> subset(ds, {"ID": [2578, 2582, 2583]})
965+
966+ Retrieve a specific time period:
967+ >>> subset(ds, {"time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))})
968+
969+ Note: To subset time variable, the range has to be defined as a function type of the variable. By default, `xarray` uses `np.datetime64` to represent datetime data. If the datetime data is a `datetime.datetime`, or `pd.Timestamp`, the range would have to be define accordingly.
970+
971+ Those criteria can also be combined:
972+ >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))})
973+
964974 Raises
965975 ------
966976 ValueError
@@ -992,11 +1002,12 @@ def subset(ds: xr.Dataset, criteria: dict) -> xr.Dataset:
9921002 warnings .warn ("No data matches the criteria; returning an empty dataset." )
9931003 return xr .Dataset ()
9941004 else :
995- # update rowsize
996- id_count = np .bincount (ds .ids [mask_obs ])
997- ds ["rowsize" ].values [mask_traj ] = [id_count [i ] for i in ds .ID [mask_traj ]]
9981005 # apply the filtering for both dimensions
999- return ds .isel ({"traj" : mask_traj , "obs" : mask_obs })
1006+ ds_sub = ds .isel ({"traj" : mask_traj , "obs" : mask_obs })
1007+ # update the rowsize
1008+ id_count = np .bincount (ds_sub .ids )
1009+ ds_sub ["rowsize" ].values = np .take (id_count , ds_sub .ID )
1010+ return ds_sub
10001011
10011012
10021013def unpack_ragged (
0 commit comments