Skip to content

Commit b867488

Browse files
raylimCopilot
andcommitted
fix: address review comments on aggregate_sample_features refactor
- validate per-slide shapes/lengths in aggregate_sample_features, raising informative ValueError with slide index and sample_id - tighten return/variable type annotations to dict[str, tuple[np.ndarray, np.ndarray]] - fix CLI memory regression: group by sample_id first, read H5 files per-sample so peak memory scales with the largest sample, not the full input set - test_multi_slide: assert concatenated content, not just shape - test_with_subsampling: assert reproducibility (same seed → same result) and features/coords alignment after subsampling - add test_aggregate_sample_features_invalid_shapes for the new per-slide validation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 844ea67 commit b867488

3 files changed

Lines changed: 112 additions & 34 deletions

File tree

mussel/cli/aggregate_sample_features.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import logging
23
import os
34
from dataclasses import dataclass, field
@@ -110,24 +111,32 @@ def main(cfg: AggregateSampleFeaturesConfig):
110111
f"sample_ids ({len(sample_ids)}) must have the same length."
111112
)
112113

113-
features_list = []
114-
coords_list = []
115-
for h5_path in patch_features_h5_paths:
116-
with h5py.File(h5_path, "r") as h5:
117-
features_list.append(np.array(h5["features"]))
118-
coords_list.append(h5["coords"][:])
119-
120-
results = aggregate_sample_features(
121-
features_list=features_list,
122-
coords_list=coords_list,
123-
sample_ids=sample_ids,
124-
max_tiles=cfg.max_tiles,
125-
subsampling_strategy=cfg.subsampling_strategy,
126-
seed=cfg.seed,
127-
)
114+
# Group indices by sample_id first so we only hold one sample's slides in
115+
# memory at a time, keeping peak memory proportional to the largest sample
116+
# rather than the entire input set.
117+
groups: dict = collections.OrderedDict()
118+
for idx, sid in enumerate(sample_ids):
119+
groups.setdefault(sid, []).append(idx)
128120

129121
os.makedirs(cfg.output_dir, exist_ok=True)
130-
for sample_id, (features, coords) in results.items():
122+
for sample_id, indices in groups.items():
123+
features_list = []
124+
coords_list = []
125+
for i in indices:
126+
with h5py.File(patch_features_h5_paths[i], "r") as h5:
127+
features_list.append(np.array(h5["features"]))
128+
coords_list.append(h5["coords"][:])
129+
130+
result = aggregate_sample_features(
131+
features_list=features_list,
132+
coords_list=coords_list,
133+
sample_ids=[sample_id] * len(indices),
134+
max_tiles=cfg.max_tiles,
135+
subsampling_strategy=cfg.subsampling_strategy,
136+
seed=cfg.seed,
137+
)
138+
139+
features, coords = result[sample_id]
131140
out_path = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}")
132141
save_hdf5(out_path, {"features": features, "coords": coords}, mode="w")
133142
logger.info(

mussel/utils/feature_extract.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,7 +1937,7 @@ def aggregate_sample_features(
19371937
max_tiles: Optional[int] = None,
19381938
subsampling_strategy: str = "random",
19391939
seed: int = 42,
1940-
) -> dict:
1940+
) -> dict[str, tuple[np.ndarray, np.ndarray]]:
19411941
"""Concatenate per-slide patch features into one array per sample.
19421942
19431943
Groups slides by ``sample_id``, concatenates their features and coordinates
@@ -1968,11 +1968,29 @@ def aggregate_sample_features(
19681968
f"and sample_ids ({len(sample_ids)}) must all have the same length."
19691969
)
19701970

1971-
groups: dict = collections.OrderedDict()
1971+
for i, (feats, coords) in enumerate(zip(features_list, coords_list)):
1972+
sid = sample_ids[i]
1973+
if feats.ndim != 2:
1974+
raise ValueError(
1975+
f"features_list[{i}] (sample_id={sid!r}) must be 2-D, "
1976+
f"got shape {feats.shape}."
1977+
)
1978+
if coords.ndim != 2 or coords.shape[1] != 2:
1979+
raise ValueError(
1980+
f"coords_list[{i}] (sample_id={sid!r}) must have shape (N, 2), "
1981+
f"got shape {coords.shape}."
1982+
)
1983+
if len(feats) != len(coords):
1984+
raise ValueError(
1985+
f"features_list[{i}] and coords_list[{i}] (sample_id={sid!r}) "
1986+
f"have different lengths: {len(feats)} vs {len(coords)}."
1987+
)
1988+
1989+
groups: dict[str, list[int]] = collections.OrderedDict()
19721990
for idx, sid in enumerate(sample_ids):
19731991
groups.setdefault(sid, []).append(idx)
19741992

1975-
results: dict = collections.OrderedDict()
1993+
results: dict[str, tuple[np.ndarray, np.ndarray]] = collections.OrderedDict()
19761994

19771995
for sample_id, indices in groups.items():
19781996
logger.info("Aggregating sample %s from %d slide(s)", sample_id, len(indices))

tests/mussel/cli/test_aggregate_sample_features.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,35 @@ def test_subsample_tiles_invalid_strategy():
123123
aggregate_sample_features as _aggregate_sample_features
124124

125125

126+
def test_aggregate_sample_features_invalid_shapes():
127+
"""Per-slide validation raises informative ValueError on bad input."""
128+
feats, coords = _make_data(10)
129+
130+
# 1-D features array
131+
with pytest.raises(ValueError, match="2-D"):
132+
_aggregate_sample_features(
133+
features_list=[feats.ravel()],
134+
coords_list=[coords],
135+
sample_ids=["s"],
136+
)
137+
138+
# coords wrong second dim
139+
with pytest.raises(ValueError, match=r"\(N, 2\)"):
140+
_aggregate_sample_features(
141+
features_list=[feats],
142+
coords_list=[coords[:, :1]],
143+
sample_ids=["s"],
144+
)
145+
146+
# mismatched lengths
147+
with pytest.raises(ValueError, match="different lengths"):
148+
_aggregate_sample_features(
149+
features_list=[feats],
150+
coords_list=[coords[:5]],
151+
sample_ids=["s"],
152+
)
153+
154+
126155
def test_aggregate_sample_features_single_slide(tmp_path):
127156
"""One slide per sample — output equals input."""
128157
feats_a, coords_a = _make_data(30)
@@ -158,6 +187,12 @@ def test_aggregate_sample_features_multi_slide(tmp_path):
158187

159188
assert results["sampleX"][0].shape == (35, 4)
160189
assert results["sampleX"][1].shape == (35, 2)
190+
np.testing.assert_array_equal(
191+
results["sampleX"][0], np.concatenate([feats_a, feats_b], axis=0)
192+
)
193+
np.testing.assert_array_equal(
194+
results["sampleX"][1], np.concatenate([coords_a, coords_b], axis=0)
195+
)
161196

162197

163198
def test_aggregate_sample_features_two_samples(tmp_path):
@@ -179,23 +214,39 @@ def test_aggregate_sample_features_two_samples(tmp_path):
179214

180215

181216
def test_aggregate_sample_features_with_subsampling(tmp_path):
182-
"""Subsampling reduces output to max_tiles."""
183-
rng = np.random.default_rng(0)
184-
feats_a = rng.random((80, 4)).astype(np.float32)
185-
coords_a = rng.integers(0, 1000, (80, 2))
186-
feats_b = rng.random((60, 4)).astype(np.float32)
187-
coords_b = rng.integers(0, 1000, (60, 2))
217+
"""Subsampling reduces output to max_tiles, is reproducible, and keeps features/coords aligned."""
218+
# Use identifiable rows: feature row i has value i in all dims, coord row i
219+
# is (i, i). After subsampling, each selected feature row must equal its
220+
# corresponding coord row, proving the two arrays stay in sync.
221+
n_a, n_b = 80, 60
222+
feats_a = np.tile(np.arange(n_a, dtype=np.float32)[:, None], (1, 4))
223+
coords_a = np.tile(np.arange(n_a)[:, None], (1, 2))
224+
feats_b = np.tile(np.arange(n_a, n_a + n_b, dtype=np.float32)[:, None], (1, 4))
225+
coords_b = np.tile(np.arange(n_a, n_a + n_b)[:, None], (1, 2))
226+
227+
def run():
228+
return _aggregate_sample_features(
229+
features_list=[feats_a, feats_b],
230+
coords_list=[coords_a, coords_b],
231+
sample_ids=["big", "big"],
232+
max_tiles=50,
233+
subsampling_strategy="random",
234+
seed=99,
235+
)
188236

189-
results = _aggregate_sample_features(
190-
features_list=[feats_a, feats_b],
191-
coords_list=[coords_a, coords_b],
192-
sample_ids=["big", "big"],
193-
max_tiles=50,
194-
subsampling_strategy="random",
195-
seed=99,
196-
)
237+
r1 = run()
238+
r2 = run()
239+
f_out, c_out = r1["big"]
240+
241+
assert f_out.shape[0] == 50
242+
assert c_out.shape[0] == 50
243+
244+
# Reproducible with same seed
245+
np.testing.assert_array_equal(r1["big"][0], r2["big"][0])
246+
np.testing.assert_array_equal(r1["big"][1], r2["big"][1])
197247

198-
assert results["big"][0].shape[0] == 50
248+
# features and coords remain aligned: feature value == coord value for each row
249+
np.testing.assert_array_equal(f_out[:, 0].astype(np.int64), c_out[:, 0])
199250

200251

201252
# =============================================================================

0 commit comments

Comments
 (0)