11"""Unit tests for SAM3 v3 block class_mapping feature."""
22
3+ from unittest .mock import MagicMock , patch
4+
35import numpy as np
6+ import pytest
47import supervision as sv
58
9+ from inference .core .workflows .core_steps .common .entities import StepExecutionMode
610from inference .core .workflows .core_steps .models .foundation .segment_anything3 .v3 import (
11+ BlockManifest ,
712 SegmentAnything3BlockV3 ,
813)
14+ from inference .core .workflows .execution_engine .entities .base import (
15+ ImageParentMetadata ,
16+ WorkflowImageData ,
17+ )
918
1019
1120def _make_detections (class_names : list [str ]) -> sv .Detections :
@@ -21,49 +30,168 @@ def _make_result(class_names: list[str]) -> list[dict]:
2130 return [{"predictions" : _make_detections (class_names )}]
2231
2332
24- class TestApplyClassMapping :
25- def test_no_mapping_returns_unchanged (self ):
26- result = _make_result (["cat" , "dog" ])
27- mapped = SegmentAnything3BlockV3 ._apply_class_mapping (result , {})
28- assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["cat" , "dog" ]
29-
30- def test_full_mapping (self ):
31- result = _make_result (["cat" , "dog" ])
32- mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
33- result , {"cat" : "gato" , "dog" : "perro" }
34- )
35- assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["gato" , "perro" ]
36-
37- def test_partial_mapping (self ):
38- result = _make_result (["cat" , "dog" , "bird" ])
39- mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
40- result , {"cat" : "gato" }
41- )
42- assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == [
43- "gato" , "dog" , "bird"
44- ]
45-
46- def test_mapping_with_no_matching_keys (self ):
47- result = _make_result (["cat" , "dog" ])
48- mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
49- result , {"fish" : "pez" }
50- )
51- assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["cat" , "dog" ]
52-
53- def test_multiple_images (self ):
54- result = [
55- {"predictions" : _make_detections (["cat" ])},
56- {"predictions" : _make_detections (["dog" ])},
57- ]
58- mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
59- result , {"cat" : "gato" , "dog" : "perro" }
60- )
61- assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["gato" ]
62- assert list (mapped [1 ]["predictions" ].data ["class_name" ]) == ["perro" ]
63-
64- def test_empty_result (self ):
65- result = []
66- mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
67- result , {"cat" : "gato" }
68- )
69- assert mapped == []
33+ @pytest .fixture
34+ def mock_workflow_image_data ():
35+ img = np .random .randint (0 , 255 , (100 , 100 , 3 ), dtype = np .uint8 )
36+ return WorkflowImageData (
37+ parent_metadata = ImageParentMetadata (parent_id = "test" ),
38+ numpy_image = img ,
39+ )
40+
41+
42+ # --- Manifest tests ---
43+
44+
45+ def test_manifest_parsing_with_class_mapping ():
46+ """Test that BlockManifest accepts the class_mapping field."""
47+ data = {
48+ "type" : "roboflow_core/sam3@v3" ,
49+ "name" : "my_sam3_step" ,
50+ "images" : "$inputs.image" ,
51+ "class_names" : ["cat" , "dog" ],
52+ "class_mapping" : {"cat" : "gato" , "dog" : "perro" },
53+ }
54+ result = BlockManifest .model_validate (data )
55+ assert result .class_mapping == {"cat" : "gato" , "dog" : "perro" }
56+
57+
58+ def test_manifest_parsing_without_class_mapping ():
59+ """Test that class_mapping is optional and defaults to None."""
60+ data = {
61+ "type" : "roboflow_core/sam3@v3" ,
62+ "name" : "my_sam3_step" ,
63+ "images" : "$inputs.image" ,
64+ "class_names" : ["cat" , "dog" ],
65+ }
66+ result = BlockManifest .model_validate (data )
67+ assert result .class_mapping is None
68+
69+
70+ # --- _apply_class_mapping unit tests ---
71+
72+
73+ def test_apply_class_mapping_full ():
74+ """Test remapping all class names."""
75+ result = _make_result (["cat" , "dog" ])
76+ mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
77+ result , {"cat" : "gato" , "dog" : "perro" }
78+ )
79+ assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["gato" , "perro" ]
80+
81+
82+ def test_apply_class_mapping_partial ():
83+ """Test remapping only some class names, leaving others unchanged."""
84+ result = _make_result (["cat" , "dog" , "bird" ])
85+ mapped = SegmentAnything3BlockV3 ._apply_class_mapping (result , {"cat" : "gato" })
86+ assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == [
87+ "gato" ,
88+ "dog" ,
89+ "bird" ,
90+ ]
91+
92+
93+ def test_apply_class_mapping_no_matching_keys ():
94+ """Test that unmatched mapping keys leave predictions unchanged."""
95+ result = _make_result (["cat" , "dog" ])
96+ mapped = SegmentAnything3BlockV3 ._apply_class_mapping (result , {"fish" : "pez" })
97+ assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["cat" , "dog" ]
98+
99+
100+ def test_apply_class_mapping_multiple_images ():
101+ """Test remapping across multiple images in a batch."""
102+ result = [
103+ {"predictions" : _make_detections (["cat" ])},
104+ {"predictions" : _make_detections (["dog" ])},
105+ ]
106+ mapped = SegmentAnything3BlockV3 ._apply_class_mapping (
107+ result , {"cat" : "gato" , "dog" : "perro" }
108+ )
109+ assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["gato" ]
110+ assert list (mapped [1 ]["predictions" ].data ["class_name" ]) == ["perro" ]
111+
112+
113+ def test_apply_class_mapping_empty_result ():
114+ """Test that an empty result list is handled gracefully."""
115+ result = []
116+ mapped = SegmentAnything3BlockV3 ._apply_class_mapping (result , {"cat" : "gato" })
117+ assert mapped == []
118+
119+
120+ def test_apply_class_mapping_empty_mapping ():
121+ """Test that an empty mapping leaves predictions unchanged."""
122+ result = _make_result (["cat" , "dog" ])
123+ mapped = SegmentAnything3BlockV3 ._apply_class_mapping (result , {})
124+ assert list (mapped [0 ]["predictions" ].data ["class_name" ]) == ["cat" , "dog" ]
125+
126+
127+ # --- Block-level run() tests ---
128+
129+
130+ @patch .object (SegmentAnything3BlockV3 , "run_locally" )
131+ def test_run_with_class_mapping_remaps_predictions (
132+ mock_run_locally , mock_workflow_image_data
133+ ):
134+ """Test that block.run() applies class_mapping to predictions from run_locally."""
135+ mock_run_locally .return_value = _make_result (["cat" , "dog" ])
136+ block = SegmentAnything3BlockV3 (
137+ model_manager = MagicMock (),
138+ api_key = "test_key" ,
139+ step_execution_mode = StepExecutionMode .LOCAL ,
140+ )
141+
142+ result = block .run (
143+ images = [mock_workflow_image_data ],
144+ model_id = "sam3/sam3_final" ,
145+ class_names = ["cat" , "dog" ],
146+ confidence = 0.5 ,
147+ class_mapping = {"cat" : "gato" , "dog" : "perro" },
148+ )
149+
150+ assert list (result [0 ]["predictions" ].data ["class_name" ]) == ["gato" , "perro" ]
151+
152+
153+ @patch .object (SegmentAnything3BlockV3 , "run_locally" )
154+ def test_run_without_class_mapping_leaves_predictions_unchanged (
155+ mock_run_locally , mock_workflow_image_data
156+ ):
157+ """Test that block.run() without class_mapping does not alter predictions."""
158+ mock_run_locally .return_value = _make_result (["cat" , "dog" ])
159+ block = SegmentAnything3BlockV3 (
160+ model_manager = MagicMock (),
161+ api_key = "test_key" ,
162+ step_execution_mode = StepExecutionMode .LOCAL ,
163+ )
164+
165+ result = block .run (
166+ images = [mock_workflow_image_data ],
167+ model_id = "sam3/sam3_final" ,
168+ class_names = ["cat" , "dog" ],
169+ confidence = 0.5 ,
170+ )
171+
172+ assert list (result [0 ]["predictions" ].data ["class_name" ]) == ["cat" , "dog" ]
173+
174+
175+ @patch .object (SegmentAnything3BlockV3 , "run_locally" )
176+ def test_run_with_partial_class_mapping (mock_run_locally , mock_workflow_image_data ):
177+ """Test that block.run() with partial class_mapping only remaps matched classes."""
178+ mock_run_locally .return_value = _make_result (["cat" , "dog" , "bird" ])
179+ block = SegmentAnything3BlockV3 (
180+ model_manager = MagicMock (),
181+ api_key = "test_key" ,
182+ step_execution_mode = StepExecutionMode .LOCAL ,
183+ )
184+
185+ result = block .run (
186+ images = [mock_workflow_image_data ],
187+ model_id = "sam3/sam3_final" ,
188+ class_names = ["cat" , "dog" , "bird" ],
189+ confidence = 0.5 ,
190+ class_mapping = {"cat" : "gato" },
191+ )
192+
193+ assert list (result [0 ]["predictions" ].data ["class_name" ]) == [
194+ "gato" ,
195+ "dog" ,
196+ "bird" ,
197+ ]
0 commit comments