Skip to content

Commit aef0360

Browse files
philippemironPhilippe Miron
andauthored
Subset_fix (#261)
* make test fail * fix subset * cleanup sample_ragged_array creation * lint --------- Co-authored-by: Philippe Miron <[email protected]>
1 parent 5de2e50 commit aef0360

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

clouddrift/analysis.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,9 +1045,10 @@ def subset(
10451045
else:
10461046
# apply the filtering for both dimensions
10471047
ds_sub = ds.isel({traj_dim_name: mask_traj, obs_dim_name: mask_obs})
1048-
_, ds_sub[rowsize_var_name].values = np.unique(
1049-
ids_with_mask_obs, return_counts=True
1048+
_, unique_idx, sorted_rowsize = np.unique(
1049+
ids_with_mask_obs, return_index=True, return_counts=True
10501050
)
1051+
ds_sub[rowsize_var_name].values = sorted_rowsize[np.argsort(unique_idx)]
10511052
return ds_sub
10521053

10531054

tests/analysis_tests.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@
2525

2626

2727
def sample_ragged_array() -> RaggedArray:
28-
drifter_id = [1, 2, 3]
29-
rowsize = [5, 4, 2]
30-
longitude = [[-121, -111, 51, 61, 71], [12, 22, 32, 42], [103, 113]]
31-
latitude = [[-90, -45, 45, 90, 0], [10, 20, 30, 40], [10, 20]]
32-
t = [[1, 2, 3, 4, 5], [2, 3, 4, 5], [4, 5]]
33-
ids = [[1, 1, 1, 1, 1], [2, 2, 2, 2], [3, 3]]
28+
drifter_id = [1, 3, 2]
29+
longitude = [[-121, -111, 51, 61, 71], [103, 113], [12, 22, 32, 42]]
30+
latitude = [[-90, -45, 45, 90, 0], [10, 20], [10, 20, 30, 40]]
31+
t = [[1, 2, 3, 4, 5], [4, 5], [2, 3, 4, 5]]
3432
test = [
3533
[True, True, True, False, False],
36-
[True, False, False, False],
3734
[False, False],
35+
[True, False, False, False],
3836
]
37+
rowsize = [len(x) for x in longitude]
38+
ids = [[d] * rowsize[i] for i, d in enumerate(drifter_id)]
3939
nb_obs = np.sum(rowsize)
4040
nb_traj = len(drifter_id)
4141
attrs_global = {
@@ -796,14 +796,17 @@ def test_select(self):
796796
def test_range(self):
797797
# positive
798798
ds_sub = subset(self.ds, {"lon": (0, 180)})
799+
print(ds_sub)
799800
traj_idx = np.insert(np.cumsum(ds_sub["rowsize"].values), 0, 0)
800801
self.assertTrue(
801802
all(ds_sub.lon[slice(traj_idx[0], traj_idx[1])] == [51, 61, 71])
802803
)
804+
805+
self.assertTrue(all(ds_sub.lon[slice(traj_idx[1], traj_idx[2])] == [103, 113]))
806+
803807
self.assertTrue(
804-
all(ds_sub.lon[slice(traj_idx[1], traj_idx[2])] == [12, 22, 32, 42])
808+
all(ds_sub.lon[slice(traj_idx[2], traj_idx[3])] == [12, 22, 32, 42])
805809
)
806-
self.assertTrue(all(ds_sub.lon[slice(traj_idx[2], traj_idx[3])] == [103, 113]))
807810

808811
# negative range
809812
ds_sub = subset(self.ds, {"lon": (-180, 0)})

0 commit comments

Comments
 (0)