Skip to content

Commit e197fcb

Browse files
test(sam3): add block-level and manifest tests for class_mapping (#2202)
Adds block-level tests that exercise the full run() method with class_mapping (mocking run_locally), plus manifest validation tests. Refactored to match repo patterns: pytest fixtures, top-level functions, docstrings.
1 parent b8874e7 commit e197fcb

1 file changed

Lines changed: 174 additions & 46 deletions

File tree

Lines changed: 174 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
"""Unit tests for SAM3 v3 block class_mapping feature."""
22

3+
from unittest.mock import MagicMock, patch
4+
35
import numpy as np
6+
import pytest
47
import supervision as sv
58

9+
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
610
from inference.core.workflows.core_steps.models.foundation.segment_anything3.v3 import (
11+
BlockManifest,
712
SegmentAnything3BlockV3,
813
)
14+
from inference.core.workflows.execution_engine.entities.base import (
15+
ImageParentMetadata,
16+
WorkflowImageData,
17+
)
918

1019

1120
def _make_detections(class_names: list[str]) -> sv.Detections:
@@ -21,49 +30,168 @@ def _make_result(class_names: list[str]) -> list[dict]:
2130
return [{"predictions": _make_detections(class_names)}]
2231

2332

24-
class TestApplyClassMapping:
25-
def test_no_mapping_returns_unchanged(self):
26-
result = _make_result(["cat", "dog"])
27-
mapped = SegmentAnything3BlockV3._apply_class_mapping(result, {})
28-
assert list(mapped[0]["predictions"].data["class_name"]) == ["cat", "dog"]
29-
30-
def test_full_mapping(self):
31-
result = _make_result(["cat", "dog"])
32-
mapped = SegmentAnything3BlockV3._apply_class_mapping(
33-
result, {"cat": "gato", "dog": "perro"}
34-
)
35-
assert list(mapped[0]["predictions"].data["class_name"]) == ["gato", "perro"]
36-
37-
def test_partial_mapping(self):
38-
result = _make_result(["cat", "dog", "bird"])
39-
mapped = SegmentAnything3BlockV3._apply_class_mapping(
40-
result, {"cat": "gato"}
41-
)
42-
assert list(mapped[0]["predictions"].data["class_name"]) == [
43-
"gato", "dog", "bird"
44-
]
45-
46-
def test_mapping_with_no_matching_keys(self):
47-
result = _make_result(["cat", "dog"])
48-
mapped = SegmentAnything3BlockV3._apply_class_mapping(
49-
result, {"fish": "pez"}
50-
)
51-
assert list(mapped[0]["predictions"].data["class_name"]) == ["cat", "dog"]
52-
53-
def test_multiple_images(self):
54-
result = [
55-
{"predictions": _make_detections(["cat"])},
56-
{"predictions": _make_detections(["dog"])},
57-
]
58-
mapped = SegmentAnything3BlockV3._apply_class_mapping(
59-
result, {"cat": "gato", "dog": "perro"}
60-
)
61-
assert list(mapped[0]["predictions"].data["class_name"]) == ["gato"]
62-
assert list(mapped[1]["predictions"].data["class_name"]) == ["perro"]
63-
64-
def test_empty_result(self):
65-
result = []
66-
mapped = SegmentAnything3BlockV3._apply_class_mapping(
67-
result, {"cat": "gato"}
68-
)
69-
assert mapped == []
33+
@pytest.fixture
34+
def mock_workflow_image_data():
35+
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
36+
return WorkflowImageData(
37+
parent_metadata=ImageParentMetadata(parent_id="test"),
38+
numpy_image=img,
39+
)
40+
41+
42+
# --- Manifest tests ---
43+
44+
45+
def test_manifest_parsing_with_class_mapping():
46+
"""Test that BlockManifest accepts the class_mapping field."""
47+
data = {
48+
"type": "roboflow_core/sam3@v3",
49+
"name": "my_sam3_step",
50+
"images": "$inputs.image",
51+
"class_names": ["cat", "dog"],
52+
"class_mapping": {"cat": "gato", "dog": "perro"},
53+
}
54+
result = BlockManifest.model_validate(data)
55+
assert result.class_mapping == {"cat": "gato", "dog": "perro"}
56+
57+
58+
def test_manifest_parsing_without_class_mapping():
59+
"""Test that class_mapping is optional and defaults to None."""
60+
data = {
61+
"type": "roboflow_core/sam3@v3",
62+
"name": "my_sam3_step",
63+
"images": "$inputs.image",
64+
"class_names": ["cat", "dog"],
65+
}
66+
result = BlockManifest.model_validate(data)
67+
assert result.class_mapping is None
68+
69+
70+
# --- _apply_class_mapping unit tests ---
71+
72+
73+
def test_apply_class_mapping_full():
74+
"""Test remapping all class names."""
75+
result = _make_result(["cat", "dog"])
76+
mapped = SegmentAnything3BlockV3._apply_class_mapping(
77+
result, {"cat": "gato", "dog": "perro"}
78+
)
79+
assert list(mapped[0]["predictions"].data["class_name"]) == ["gato", "perro"]
80+
81+
82+
def test_apply_class_mapping_partial():
83+
"""Test remapping only some class names, leaving others unchanged."""
84+
result = _make_result(["cat", "dog", "bird"])
85+
mapped = SegmentAnything3BlockV3._apply_class_mapping(result, {"cat": "gato"})
86+
assert list(mapped[0]["predictions"].data["class_name"]) == [
87+
"gato",
88+
"dog",
89+
"bird",
90+
]
91+
92+
93+
def test_apply_class_mapping_no_matching_keys():
94+
"""Test that unmatched mapping keys leave predictions unchanged."""
95+
result = _make_result(["cat", "dog"])
96+
mapped = SegmentAnything3BlockV3._apply_class_mapping(result, {"fish": "pez"})
97+
assert list(mapped[0]["predictions"].data["class_name"]) == ["cat", "dog"]
98+
99+
100+
def test_apply_class_mapping_multiple_images():
101+
"""Test remapping across multiple images in a batch."""
102+
result = [
103+
{"predictions": _make_detections(["cat"])},
104+
{"predictions": _make_detections(["dog"])},
105+
]
106+
mapped = SegmentAnything3BlockV3._apply_class_mapping(
107+
result, {"cat": "gato", "dog": "perro"}
108+
)
109+
assert list(mapped[0]["predictions"].data["class_name"]) == ["gato"]
110+
assert list(mapped[1]["predictions"].data["class_name"]) == ["perro"]
111+
112+
113+
def test_apply_class_mapping_empty_result():
114+
"""Test that an empty result list is handled gracefully."""
115+
result = []
116+
mapped = SegmentAnything3BlockV3._apply_class_mapping(result, {"cat": "gato"})
117+
assert mapped == []
118+
119+
120+
def test_apply_class_mapping_empty_mapping():
121+
"""Test that an empty mapping leaves predictions unchanged."""
122+
result = _make_result(["cat", "dog"])
123+
mapped = SegmentAnything3BlockV3._apply_class_mapping(result, {})
124+
assert list(mapped[0]["predictions"].data["class_name"]) == ["cat", "dog"]
125+
126+
127+
# --- Block-level run() tests ---
128+
129+
130+
@patch.object(SegmentAnything3BlockV3, "run_locally")
131+
def test_run_with_class_mapping_remaps_predictions(
132+
mock_run_locally, mock_workflow_image_data
133+
):
134+
"""Test that block.run() applies class_mapping to predictions from run_locally."""
135+
mock_run_locally.return_value = _make_result(["cat", "dog"])
136+
block = SegmentAnything3BlockV3(
137+
model_manager=MagicMock(),
138+
api_key="test_key",
139+
step_execution_mode=StepExecutionMode.LOCAL,
140+
)
141+
142+
result = block.run(
143+
images=[mock_workflow_image_data],
144+
model_id="sam3/sam3_final",
145+
class_names=["cat", "dog"],
146+
confidence=0.5,
147+
class_mapping={"cat": "gato", "dog": "perro"},
148+
)
149+
150+
assert list(result[0]["predictions"].data["class_name"]) == ["gato", "perro"]
151+
152+
153+
@patch.object(SegmentAnything3BlockV3, "run_locally")
154+
def test_run_without_class_mapping_leaves_predictions_unchanged(
155+
mock_run_locally, mock_workflow_image_data
156+
):
157+
"""Test that block.run() without class_mapping does not alter predictions."""
158+
mock_run_locally.return_value = _make_result(["cat", "dog"])
159+
block = SegmentAnything3BlockV3(
160+
model_manager=MagicMock(),
161+
api_key="test_key",
162+
step_execution_mode=StepExecutionMode.LOCAL,
163+
)
164+
165+
result = block.run(
166+
images=[mock_workflow_image_data],
167+
model_id="sam3/sam3_final",
168+
class_names=["cat", "dog"],
169+
confidence=0.5,
170+
)
171+
172+
assert list(result[0]["predictions"].data["class_name"]) == ["cat", "dog"]
173+
174+
175+
@patch.object(SegmentAnything3BlockV3, "run_locally")
176+
def test_run_with_partial_class_mapping(mock_run_locally, mock_workflow_image_data):
177+
"""Test that block.run() with partial class_mapping only remaps matched classes."""
178+
mock_run_locally.return_value = _make_result(["cat", "dog", "bird"])
179+
block = SegmentAnything3BlockV3(
180+
model_manager=MagicMock(),
181+
api_key="test_key",
182+
step_execution_mode=StepExecutionMode.LOCAL,
183+
)
184+
185+
result = block.run(
186+
images=[mock_workflow_image_data],
187+
model_id="sam3/sam3_final",
188+
class_names=["cat", "dog", "bird"],
189+
confidence=0.5,
190+
class_mapping={"cat": "gato"},
191+
)
192+
193+
assert list(result[0]["predictions"].data["class_name"]) == [
194+
"gato",
195+
"dog",
196+
"bird",
197+
]

0 commit comments

Comments
 (0)