Skip to content

Commit 6006d7c

Browse files
Add `full_trajectories` option to `subset` to return complete trajectories (#332)
1 parent 94a5efc commit 6006d7c

File tree

5 files changed

+55
-6
lines changed

5 files changed

+55
-6
lines changed

.codecov.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ coverage:
88
status:
99
project:
1010
default:
11-
informational: true
11+
threshold: 10%
1212
patch:
1313
default:
14-
informational: true
14+
threshold: 10%
1515

1616
comment: off
1717

.github/workflows/ci.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,16 @@ jobs:
2626
with:
2727
environment-file: environment.yml
2828
environment-name: clouddrift
29-
- name: Testing
29+
- name: Run unit tests
3030
shell: bash -l {0}
3131
run: |
32-
pip install coverage
33-
pip install matplotlib-base cartopy
32+
micromamba install coverage matplotlib-base cartopy
3433
coverage run -m unittest discover -s tests -p "*.py"
34+
- name: Create coverage report
35+
shell: bash -l {0}
36+
run: |
3537
coverage xml
36-
- name: Upload coverage reports to Codecov
38+
- name: Upload coverage report to Codecov
3739
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.9'
3840
uses: codecov/codecov-action@v3
3941
with:

clouddrift/plotting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,8 @@ def plot_ragged(
212212
start = end
213213
lines.append(line)
214214

215+
# set axis limits
216+
ax.set_xlim([np.min(longitude), np.max(longitude)])
217+
ax.set_ylim([np.min(latitude), np.max(latitude)])
218+
215219
return lines

clouddrift/ragged.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ def subset(
543543
rowsize_var_name: str = "rowsize",
544544
traj_dim_name: str = "traj",
545545
obs_dim_name: str = "obs",
546+
full_trajectories=False,
546547
) -> xr.Dataset:
547548
"""Subset the dataset as a function of one or many criteria. The criteria are
548549
passed as a dictionary, where a variable to subset is assigned to either a
@@ -567,6 +568,10 @@ def subset(
567568
Name of the trajectory dimension (default is "traj")
568569
obs_dim_name : str, optional
569570
Name of the observation dimension (default is "obs")
571+
full_trajectories : bool, optional
572+
If True, it returns the complete trajectories where at least one observation
573+
matches the criteria, rather than just the segments where the criteria are satisfied.
574+
Default is False.
570575
571576
Returns
572577
-------
@@ -582,6 +587,10 @@ def subset(
582587
583588
>>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)})
584589
590+
The parameter `full_trajectories` can be used to retrieve trajectories passing through a region, for example, all trajectories passing through the Gulf of Mexico:
591+
592+
>>> subset(ds, {"lat": (21, 31), "lon": (-98, -78)}, full_trajectories=True)
593+
585594
Retrieve drogued trajectory segments:
586595
587596
>>> subset(ds, {"drogue_status": True})
@@ -641,6 +650,14 @@ def subset(
641650
mask_traj, np.in1d(ds[id_var_name], np.unique(ids_with_mask_obs))
642651
)
643652

653+
# set mask_obs to True for every observation of each selected trajectory to keep complete trajectories
654+
if full_trajectories:
655+
for i in np.where(mask_traj)[0]:
656+
mask_obs[slice(traj_idx[i], traj_idx[i + 1])] = True
657+
ids_with_mask_obs = np.repeat(
658+
ds[id_var_name].values, ds[rowsize_var_name].values
659+
)[mask_obs]
660+
644661
if not any(mask_traj):
645662
warnings.warn("No data matches the criteria; returning an empty dataset.")
646663
return xr.Dataset()

tests/ragged_tests.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,32 @@ def test_arraylike_criterion(self):
667667
ds_sub = subset(self.ds, {"ID": self.ds["ID"][:2].values})
668668
self.assertTrue(ds_sub["ID"].size == 2)
669669

670+
def test_full_trajectories(self):
671+
ds_id_rowsize = {
672+
i: j for i, j in zip(self.ds.ID.values, self.ds.rowsize.values)
673+
}
674+
675+
ds_sub = subset(self.ds, {"lon": (-125, -111)}, full_trajectories=True)
676+
self.assertTrue(all(ds_sub.lon == [-121, -111, 51, 61, 71]))
677+
678+
ds_sub_id_rowsize = {
679+
i: j for i, j in zip(ds_sub.ID.values, ds_sub.rowsize.values)
680+
}
681+
for k, v in ds_sub_id_rowsize.items():
682+
self.assertTrue(ds_id_rowsize[k] == v)
683+
684+
ds_sub = subset(self.ds, {"lat": (30, 40)}, full_trajectories=True)
685+
self.assertTrue(all(ds_sub.lat == [10, 20, 30, 40]))
686+
687+
ds_sub_id_rowsize = {
688+
i: j for i, j in zip(ds_sub.ID.values, ds_sub.rowsize.values)
689+
}
690+
for k, v in ds_sub_id_rowsize.items():
691+
self.assertTrue(ds_id_rowsize[k] == v)
692+
693+
ds_sub = subset(self.ds, {"time": (4, 5)}, full_trajectories=True)
694+
xr.testing.assert_equal(self.ds, ds_sub)
695+
670696

671697
class unpack_tests(unittest.TestCase):
672698
def test_unpack(self):

0 commit comments

Comments
 (0)