Skip to content

Commit de44948

Browse files
committed
SQUASHME: fix PR
1 parent 330371b commit de44948

3 files changed

Lines changed: 97 additions & 92 deletions

File tree

examples/viewer_lib/logic/segmentation/islands_effect_logic.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,8 @@ def _on_apply_clicked(self):
2626
if not self.is_active():
2727
return
2828
if self._typed_state.data.mode == IslandsSegmentationMode.KEEP_LARGEST_ISLAND.value:
29-
self.effect.apply(
30-
segment_id=self.segmentation_editor.get_active_segment_id(),
31-
max_number_of_segments=1,
32-
split=False,
33-
)
29+
self.effect.keep_largest_island()
3430
elif self._typed_state.data.mode == IslandsSegmentationMode.REMOVE_SMALL_ISLANDS.value:
35-
self.effect.apply(
36-
self.segmentation_editor.get_active_segment_id(),
37-
max_number_of_segments=None,
38-
minimum_size=self._typed_state.data.minimum_size,
39-
split=False,
40-
)
31+
self.effect.remove_small_islands(self._typed_state.data.minimum_size)
4132
elif self._typed_state.data.mode == IslandsSegmentationMode.SPLIT_TO_SEGMENTS.value:
42-
self.effect.apply(
43-
segment_id=self.segmentation_editor.get_active_segment_id(),
44-
max_number_of_segments=None,
45-
split=True,
46-
)
33+
self.effect.split_islands_to_segments()

tests/test_segmentation_islands_effect.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,6 @@
77
from trame_slicer.segmentation import SegmentationEffectIslands
88

99

10-
@pytest.fixture
11-
def undo_stack(a_segmentation_editor):
12-
undo_stack = UndoStack()
13-
a_segmentation_editor.set_undo_stack(undo_stack)
14-
return undo_stack
15-
16-
1710
@pytest.fixture
1811
def a_segmentation_spheres_file_path(a_data_folder) -> Path:
1912
return a_data_folder.joinpath("segmentation_spheres.nii.gz")
@@ -35,6 +28,8 @@ def set_up(a_slicer_app, a_volume_node, a_segmentation_editor, a_segmentation_sp
3528
a_slicer_app.display_manager.show_volume(a_volume_node, vr_preset="MR-Default")
3629
segmentation_node = a_slicer_app.io_manager.load_segmentation(a_segmentation_spheres_file_path)
3730
a_segmentation_editor.set_active_segmentation(segmentation_node, a_volume_node)
31+
undo_stack = UndoStack()
32+
a_segmentation_editor.set_undo_stack(undo_stack)
3833

3934

4035
def get_segment_array(
@@ -44,7 +39,7 @@ def get_segment_array(
4439
return a_segmentation_editor.get_segment_labelmap(segment_id, as_numpy_array=True)
4540

4641

47-
def test_do_nothing(
42+
def test_keep_biggest_island(
4843
a_segmentation_editor,
4944
effect,
5045
segment_id,
@@ -54,60 +49,49 @@ def test_do_nothing(
5449
a_segmentation_editor,
5550
segment_id,
5651
)
57-
effect.apply(segment_id, None, minimum_size=0, split=False)
58-
assert np.array_equal(
59-
get_segment_array(
60-
a_segmentation_editor,
61-
segment_id,
62-
),
63-
source_array,
64-
)
65-
66-
67-
def test_keep_zero_islands(
68-
a_segmentation_editor,
69-
effect,
70-
segment_id,
71-
):
72-
assert effect.is_active
73-
effect.apply(segment_id, 0, minimum_size=0, split=False)
52+
effect.keep_largest_island()
7453
segment_array = get_segment_array(
7554
a_segmentation_editor,
7655
segment_id,
7756
)
78-
assert np.array_equal(
79-
segment_array,
80-
np.zeros_like(segment_array),
81-
)
57+
# Assert that application created new zeros
58+
assert np.count_nonzero(source_array) > np.count_nonzero(segment_array)
8259

8360

84-
def test_split(
61+
def test_split_islands_to_segments(
8562
a_segmentation_editor,
8663
effect,
87-
segment_id,
8864
):
8965
assert effect.is_active
90-
effect.apply(segment_id, None, minimum_size=0, split=True)
66+
effect.split_islands_to_segments()
9167
assert len(a_segmentation_editor.get_segment_ids()) == 3
9268

9369

94-
def test_pixel_islands(
70+
def test_with_0_min_voxel_size_remove_small_islands_does_nothing(
9571
a_segmentation_editor,
9672
effect,
9773
segment_id,
9874
):
9975
assert effect.is_active
100-
effect.apply(segment_id, None, minimum_size=1, split=True)
101-
assert len(a_segmentation_editor.get_segment_ids()) == 3
76+
source_array = get_segment_array(
77+
a_segmentation_editor,
78+
segment_id,
79+
)
80+
effect.remove_small_islands(1)
81+
segment_array = get_segment_array(
82+
a_segmentation_editor,
83+
segment_id,
84+
)
85+
assert np.array_equal(source_array, segment_array)
10286

10387

104-
def test_remove_all_islands(
88+
def test_with_max_min_voxel_size_remove_small_islands_removes_all_islands(
10589
a_segmentation_editor,
10690
effect,
10791
segment_id,
10892
):
10993
assert effect.is_active
110-
effect.apply(segment_id, None, minimum_size=int(1e15), split=True)
94+
effect.remove_small_islands(int(1e15))
11195
segment_array = get_segment_array(
11296
a_segmentation_editor,
11397
segment_id,

trame_slicer/segmentation/segmentation_effect_islands.py

Lines changed: 73 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from . import ModificationMode
1717
from .segmentation_effect import SegmentationEffect
18+
from .segmentation_undo_command import SegmentationLabelMapUndoCommand
1819

1920

2021
class SegmentationEffectIslands(SegmentationEffect):
@@ -26,62 +27,86 @@ def _create_pipeline(self, _view_node: vtkMRMLAbstractViewNode, _parameter: vtkM
2627
# Islands effect does not require a pipeline
2728
return None
2829

29-
def apply(
30-
self, segment_id: str, max_number_of_segments: int | None = None, minimum_size: int = 0, split: bool = True
31-
):
30+
def _remove_small_islands(self, minimum_size: int) -> None:
3231
if not self.is_active:
3332
return
3433

35-
island_image = self.get_island_labelmap(segment_id, minimum_size)
34+
island_image = self.get_island_labelmap(minimum_size)
35+
self.modifier.apply_labelmap(island_image)
3636

37+
def _keep_n_biggest_islands(self, number_of_islands: int):
38+
if not self.is_active:
39+
return
40+
41+
modifier_image = self.modifier.create_modifier_labelmap()
42+
if number_of_islands <= 0:
43+
self.modifier.apply_labelmap(modifier_image)
44+
return
45+
46+
island_image = self.get_island_labelmap()
3747
label_values = vtkIntArray()
3848
vtkSlicerSegmentationsModuleLogic.GetAllLabelValues(label_values, island_image)
3949
max_label_value = label_values.GetNumberOfTuples()
40-
if max_number_of_segments is not None:
41-
max_label_value = min(max_label_value, max_number_of_segments)
42-
if not split and max_number_of_segments is None:
43-
# Overwrite selected segment's labelmap with filtered labelmap
44-
self.modifier.apply_labelmap(island_image)
45-
else:
46-
modifier_image = self.modifier.create_modifier_labelmap()
47-
threshold = vtkImageThreshold()
48-
threshold.SetInputData(island_image)
49-
50-
kept_labels = [int(label_values.GetTuple1(i)) for i in range(1, max_label_value)]
51-
52-
for i in range(max_label_value):
53-
label_value = int(label_values.GetTuple1(i))
54-
if i == 0:
55-
# Replace selected segment's labelmap by first island
56-
threshold.ThresholdBetween(label_value, label_value)
57-
threshold.SetInValue(1)
58-
threshold.SetOutValue(0)
59-
threshold.Update()
60-
modifier_image.DeepCopy(threshold.GetOutput())
61-
# Remove non-kept labels
62-
if label_value in kept_labels:
63-
continue
64-
threshold.ReplaceOutOff()
65-
threshold.ThresholdBetween(label_value, label_value)
66-
threshold.SetInValue(0)
67-
threshold.Update()
68-
threshold.SetInputData(threshold.GetOutput())
50+
if number_of_islands >= max_label_value:
51+
return
6952

70-
self.modifier.apply_labelmap(modifier_image)
53+
modifier_image = self.modifier.create_modifier_labelmap()
54+
threshold = vtkImageThreshold()
55+
threshold.SetInputData(island_image)
56+
threshold.ReplaceOutOff()
57+
58+
for i in range(number_of_islands, max_label_value):
59+
label_value = int(label_values.GetTuple1(i))
60+
threshold.ThresholdBetween(label_value, label_value)
61+
threshold.SetInValue(0)
62+
threshold.Update()
63+
threshold.SetInputData(threshold.GetOutput())
64+
65+
modifier_image.DeepCopy(threshold.GetOutput())
66+
self.modifier.apply_labelmap(modifier_image)
67+
68+
def _split_islands_to_segments(self): # TODO: Make it undoable
69+
if not self.is_active:
70+
return
71+
72+
island_image = self.get_island_labelmap()
73+
label_values = vtkIntArray()
74+
vtkSlicerSegmentationsModuleLogic.GetAllLabelValues(label_values, island_image)
75+
76+
modifier_image = self.modifier.create_modifier_labelmap()
77+
78+
threshold = vtkImageThreshold()
79+
threshold.SetInputData(island_image)
7180

72-
if max_label_value > 1:
81+
# Replace selected segment's labelmap by first island
82+
initial_label_value = int(label_values.GetTuple1(0))
83+
threshold.ThresholdBetween(initial_label_value, initial_label_value)
84+
threshold.SetInValue(initial_label_value)
85+
threshold.SetOutValue(0)
86+
threshold.Update()
87+
modifier_image.DeepCopy(threshold.GetOutput())
88+
89+
# Create labelmap without first segment
90+
threshold.SetInValue(0)
91+
threshold.ReplaceOutOff()
92+
threshold.ThresholdBetween(initial_label_value, initial_label_value)
93+
threshold.Update()
94+
95+
with self.modifier.group_undo_commands(f"{__class__} - Split {self.modifier.active_segment_id}"):
96+
self.modifier.apply_labelmap(modifier_image)
97+
with SegmentationLabelMapUndoCommand.push_state_change(self.modifier.segmentation):
7398
modifier_image.DeepCopy(threshold.GetOutput())
7499
vtkSlicerSegmentationsModuleLogic.ImportLabelmapToSegmentationNode(
75100
modifier_image,
76101
self.modifier.segmentation.segmentation_node,
77-
self.modifier.segmentation.get_segment(segment_id).GetName(),
102+
self.modifier.segmentation.get_segment(self.modifier.active_segment_id).GetName(),
78103
)
79-
self.modifier.segmentation.segmentation_modified.emit()
104+
self.modifier.segmentation.segmentation_modified.emit()
80105

81-
def get_island_labelmap(self, segment_id: str, minimum_size: int = 0) -> vtkOrientedImageData:
106+
def get_island_labelmap(self, minimum_size: int = 0) -> vtkOrientedImageData:
82107
source_image_data = self.modifier.get_source_image_data()
83108

84-
segment_labelmap = self.modifier.get_segment_labelmap(segment_id)
109+
segment_labelmap = self.modifier.get_segment_labelmap(self.modifier.active_segment_id)
85110
cast_in = vtkImageCast()
86111
cast_in.SetInputData(segment_labelmap)
87112
cast_in.SetOutputScalarTypeToUnsignedInt()
@@ -100,3 +125,12 @@ def get_island_labelmap(self, segment_id: str, minimum_size: int = 0) -> vtkOrie
100125
island_image.SetImageToWorldMatrix(image_to_world_matrix)
101126

102127
return island_image
128+
129+
def keep_largest_island(self) -> None:
130+
self._keep_n_biggest_islands(1) # TODO: test, pass to 1
131+
132+
def remove_small_islands(self, min_voxel_size: int) -> None:
133+
self._remove_small_islands(min_voxel_size)
134+
135+
def split_islands_to_segments(self) -> None:
136+
self._split_islands_to_segments()

0 commit comments

Comments
 (0)