Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions examples/viewer_lib/logic/segmentation/islands_effect_logic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from trame_server import Server

from trame_slicer.core import SlicerApp
from trame_slicer.segmentation import SegmentationEffectIslands
from trame_slicer.segmentation import SegmentationEffectIslands, SegmentationIslandsMode

from ...ui import (
IslandsEffectUI,
IslandsSegmentationMode,
IslandsState,
SegmentEditorUI,
)
Expand All @@ -16,18 +15,40 @@ class IslandsEffectLogic(BaseEffectLogic[IslandsState, SegmentationEffectIslands
def __init__(self, server: Server, slicer_app: SlicerApp):
super().__init__(server, slicer_app, IslandsState, SegmentationEffectIslands)

def set_ui(self, ui: SegmentEditorUI):
self.bind_changes(
{
self.name.mode: self._set_mode,
self.name.minimum_size: self._set_minimum_island_size,
}
)

def set_ui(self, ui: SegmentEditorUI) -> None:
self.set_effect_ui(ui.get_effect_ui(SegmentationEffectIslands))

def set_effect_ui(self, islands_ui: IslandsEffectUI):
def set_effect_ui(self, islands_ui: IslandsEffectUI) -> None:
islands_ui.apply_clicked.connect(self._on_apply_clicked)

def _on_apply_clicked(self):
def _set_mode(self, island_segmentation_mode: SegmentationIslandsMode) -> None:
if not self.is_active():
return
self.effect.set_island_mode(island_segmentation_mode)

def _set_minimum_island_size(self, minimum_island_size: int) -> None:
if not self.is_active():
return
self.effect.set_minimum_island_size(minimum_island_size)

def _on_apply_clicked(self) -> None:
if not self.is_active():
return
self.effect.apply()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan of generic apply.
I think it's much better to keep the effect API as specialized to make intent clearer for consuming code (suggestions in the effect.apply section of the PR)


def _on_effect_parameters_changed(self) -> None:
with self.state:
self.data.mode = self.effect.get_island_mode()
self.data.minimum_size = self.effect.get_minimum_island_size()

def _on_effect_changed(self, _effect_name: str) -> None:
if not self.is_active():
return
if self._typed_state.data.mode == IslandsSegmentationMode.KEEP_LARGEST_ISLAND:
self.effect.keep_largest_island()
elif self._typed_state.data.mode == IslandsSegmentationMode.REMOVE_SMALL_ISLANDS:
self.effect.remove_small_islands(self._typed_state.data.minimum_size)
elif self._typed_state.data.mode == IslandsSegmentationMode.SPLIT_TO_SEGMENTS:
self.effect.split_islands_to_segments()
self.effect.parameters_changed.connect(self._on_effect_parameters_changed)
2 changes: 0 additions & 2 deletions examples/viewer_lib/ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .mpr_interaction_button import MprInteractionButton, MprInteractionButtonState
from .segmentation import (
IslandsEffectUI,
IslandsSegmentationMode,
IslandsState,
PaintEffectState,
PaintEffectUI,
Expand Down Expand Up @@ -35,7 +34,6 @@
"ControlButton",
"FlexContainer",
"IslandsEffectUI",
"IslandsSegmentationMode",
"IslandsState",
"LayoutButton",
"LayoutButtonState",
Expand Down
3 changes: 1 addition & 2 deletions examples/viewer_lib/ui/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .islands_effect_ui import IslandsEffectUI, IslandsSegmentationMode, IslandsState
from .islands_effect_ui import IslandsEffectUI, IslandsState
from .paint_effect_ui import PaintEffectState, PaintEffectUI
from .segment_display_ui import SegmentDisplayState, SegmentDisplayUI
from .segment_edit_ui import (
Expand All @@ -17,7 +17,6 @@

__all__ = [
"IslandsEffectUI",
"IslandsSegmentationMode",
"IslandsState",
"PaintEffectState",
"PaintEffectUI",
Expand Down
45 changes: 28 additions & 17 deletions examples/viewer_lib/ui/segmentation/islands_effect_ui.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
from dataclasses import dataclass
from enum import Enum, auto

from trame_server.utils.typed_state import TypedState
from trame_vuetify.widgets.vuetify3 import VBtn, VBtnToggle, VNumberInput, VSpacer
from undo_stack import Signal

from ..flex_container import FlexContainer

from trame_slicer.segmentation import SegmentationIslandsMode

class IslandsSegmentationMode(Enum):
KEEP_LARGEST_ISLAND = auto()
REMOVE_SMALL_ISLANDS = auto()
SPLIT_TO_SEGMENTS = auto()
from ..flex_container import FlexContainer


@dataclass
class IslandsState:
mode: IslandsSegmentationMode = IslandsSegmentationMode.KEEP_LARGEST_ISLAND
mode: SegmentationIslandsMode = SegmentationIslandsMode.KEEP_LARGEST_ISLAND
minimum_size: int = 1000


Expand All @@ -28,20 +23,27 @@ def __init__(self, **kwargs):
self._typed_state = TypedState(self.state, IslandsState)

self.labels = {
IslandsSegmentationMode.KEEP_LARGEST_ISLAND: "Keep largest",
IslandsSegmentationMode.REMOVE_SMALL_ISLANDS: "Remove small",
IslandsSegmentationMode.SPLIT_TO_SEGMENTS: "Split",
SegmentationIslandsMode.KEEP_LARGEST_ISLAND: "Keep largest",
SegmentationIslandsMode.REMOVE_SMALL_ISLANDS: "Remove small",
SegmentationIslandsMode.SPLIT_TO_SEGMENTS: "Split",
SegmentationIslandsMode.KEEP_SELECTED: "Keep selected",
SegmentationIslandsMode.REMOVE_SELECTED: "Remove selected",
SegmentationIslandsMode.ADD_SELECTED: "Add selected",
}

with self:
with VBtnToggle(v_model=(self._typed_state.name.mode,), mandatory=True, style="align-self: center;"):
for mode in IslandsSegmentationMode:
with VBtnToggle(
v_model=(self._typed_state.name.mode,),
mandatory=True,
style="align-self: center; height: fit-content; display: grid; grid-template-columns: repeat(2, 1fr); gap: 8px;",
):
for mode in SegmentationIslandsMode:
self._create_mode_button(mode)

with FlexContainer(row=True, align="start", classes="mt-2"):
VNumberInput(
v_if=(
f"{self._typed_state.name.mode} === {self._typed_state.encode(IslandsSegmentationMode.REMOVE_SMALL_ISLANDS)}",
f"{self._typed_state.name.mode} === {self._typed_state.encode(SegmentationIslandsMode.REMOVE_SMALL_ISLANDS)}",
),
v_model=self._typed_state.name.minimum_size,
control_variant="stacked",
Expand All @@ -54,7 +56,16 @@ def __init__(self, **kwargs):
density="compact",
)
VSpacer()
VBtn(text="Apply", prepend_icon="mdi-check", variant="tonal", click=self.apply_clicked)
VBtn(
text="Apply",
prepend_icon="mdi-check",
variant="tonal",
click=self.apply_clicked,
disabled=(
f"{self._typed_state.encode(SegmentationIslandsMode.get_interactive_modes())}.includes({self._typed_state.name.mode})",
),
style="margin-top: 16px !important;",
)

def _create_mode_button(self, mode: IslandsSegmentationMode):
VBtn(text=self.labels[mode], value=(self._typed_state.encode(mode),), size="small")
def _create_mode_button(self, mode: SegmentationIslandsMode):
VBtn(text=self.labels[mode], value=(self._typed_state.encode(mode),), size="small", style="min-height: 30px;")
6 changes: 3 additions & 3 deletions tests/examples/test_segment_islands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from examples.viewer_lib.logic import IslandsEffectLogic
from examples.viewer_lib.ui import (
IslandsEffectUI,
IslandsSegmentationMode,
ViewerLayout,
)
from trame_slicer.segmentation import SegmentationIslandsMode


@pytest.fixture
Expand All @@ -21,7 +21,7 @@ def effect_logic(a_server, a_slicer_app, effect_ui):
return logic


@pytest.mark.parametrize("island_mode", list(IslandsSegmentationMode))
@pytest.mark.parametrize("island_mode", list(SegmentationIslandsMode))
def test_can_apply_island_effect(
effect_logic,
effect_ui,
Expand All @@ -34,5 +34,5 @@ def test_can_apply_island_effect(
segmentation_node = a_slicer_app.io_manager.load_segmentation(a_segmentation_nifti_file_path)
a_segmentation_editor.set_active_segmentation(segmentation_node, a_volume_node)
effect_logic.set_active()
effect_ui._typed_state.data.mode = island_mode
effect_logic.effect.set_island_mode(island_mode)
effect_ui.apply_clicked()
44 changes: 44 additions & 0 deletions tests/test_segmentation_islands_effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from trame_slicer.segmentation import SegmentationEffectIslands

SPHERE_CENTER = [0, 40, 3]


@pytest.fixture
def a_segmentation_spheres_file_path(a_data_folder) -> Path:
Expand Down Expand Up @@ -64,3 +66,45 @@ def test_with_max_min_voxel_size_remove_small_islands_removes_all_islands(a_segm
effect.remove_small_islands(int(1e15))
segment_array = get_segment_array(a_segmentation_editor, segment_id)
assert np.array_equal(segment_array, np.zeros_like(segment_array))


def test_keep_selected(a_segmentation_editor, effect, segment_id):
assert effect.is_active
assert len(effect._get_label_values()) == 3
segment_array = get_segment_array(a_segmentation_editor, segment_id)

effect.keep_island_at_position([0, 0, 0])
assert np.sum(get_segment_array(a_segmentation_editor, segment_id)) == np.sum(segment_array)
assert len(effect._get_label_values()) == 3

effect.keep_island_at_position(SPHERE_CENTER)
assert np.sum(get_segment_array(a_segmentation_editor, segment_id)) < np.sum(segment_array)
assert len(effect._get_label_values()) == 1


def test_remove_selected(a_segmentation_editor, effect, segment_id):
assert effect.is_active
assert len(effect._get_label_values()) == 3
segment_array = get_segment_array(a_segmentation_editor, segment_id)

effect.remove_island_at_position([0, 0, 0])
assert np.sum(get_segment_array(a_segmentation_editor, segment_id)) == np.sum(segment_array)
assert len(effect._get_label_values()) == 3

effect.remove_island_at_position(SPHERE_CENTER)
assert np.sum(get_segment_array(a_segmentation_editor, segment_id)) < np.sum(segment_array)
assert len(effect._get_label_values()) == 2


def test_add_selected(a_segmentation_editor, effect):
assert effect.is_active
new_segment_id = a_segmentation_editor.add_empty_segment()
assert len(effect._get_label_values()) == 0

effect.add_island_at_position([0, 0, 0])
assert np.sum(get_segment_array(a_segmentation_editor, new_segment_id)) == 0
assert len(effect._get_label_values()) == 0

effect.add_island_at_position(SPHERE_CENTER)
assert np.sum(get_segment_array(a_segmentation_editor, new_segment_id)) > 0
assert len(effect._get_label_values()) == 1
8 changes: 7 additions & 1 deletion trame_slicer/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from .segmentation import Segmentation
from .segmentation_display import SegmentationDisplay, SegmentationOpacityEnum
from .segmentation_effect import SegmentationEffect
from .segmentation_effect_islands import SegmentationEffectIslands
from .segmentation_effect_islands import (
SegmentationEffectIslands,
SegmentationIslandsMode,
)
from .segmentation_effect_no_tool import SegmentationEffectNoTool
from .segmentation_effect_paint_erase import (
SegmentationEffectErase,
Expand All @@ -28,6 +31,7 @@
SegmentationThresholdPipeline2D,
ThresholdParameters,
)
from .segmentation_islands_pipeline import SegmentationIslandsPipeline
from .segmentation_paint_pipeline import (
SegmentationPaintPipeline2D,
SegmentationPaintPipeline3D,
Expand Down Expand Up @@ -59,6 +63,8 @@
"SegmentationEffectPipeline",
"SegmentationEffectScissors",
"SegmentationEffectThreshold",
"SegmentationIslandsMode",
"SegmentationIslandsPipeline",
"SegmentationOpacityEnum",
"SegmentationPaintPipeline2D",
"SegmentationPaintPipeline3D",
Expand Down
4 changes: 2 additions & 2 deletions trame_slicer/segmentation/segment_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import enum
import logging
import math
from collections.abc import Generator
from contextlib import AbstractContextManager
from enum import auto

from numpy.typing import NDArray
Expand Down Expand Up @@ -382,5 +382,5 @@ def is_source_volume_intensity_mask_enabled(self) -> bool:
return False
return self.segment_editor_node.GetSourceVolumeIntensityMask()

def group_undo_commands(self, text: str = "") -> Generator:
def group_undo_commands(self, text: str = "") -> AbstractContextManager[None]:
return self.segmentation.group_undo_commands(text)
3 changes: 2 additions & 1 deletion trame_slicer/segmentation/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

from numpy.typing import NDArray
from slicer import (
Expand Down Expand Up @@ -277,7 +278,7 @@ def get_display(self) -> SegmentationDisplay | None:
return SegmentationDisplay(self._segmentation_node.GetDisplayNode()) if self._segmentation_node else None

@contextmanager
def group_undo_commands(self, text: str = "") -> Generator:
def group_undo_commands(self, text: str = "") -> Generator[None, Any, None]:
if not self.undo_stack:
yield
return
Expand Down
5 changes: 4 additions & 1 deletion trame_slicer/segmentation/segmentation_effect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from weakref import ref

from slicer import (
Expand All @@ -13,7 +14,9 @@
from undo_stack import Signal

from .segment_modifier import ModificationMode, SegmentModifier
from .segmentation_effect_pipeline import SegmentationEffectPipeline

if TYPE_CHECKING:
from .segmentation_effect_pipeline import SegmentationEffectPipeline


class SegmentationEffect(ABC):
Expand Down
Loading
Loading