11import pytest
22
3- from neurotask .tmt .metrics .speed_metrics import InvalidSpeedError , NonMonotonicTimeError
43from neurotask .tmt .segmentation .segmentation import classify_cursor_positions_with_hesitation
54from neurotask .tmt .model .tmt_model import (
65 Coordinate ,
76 CursorInfo ,
8- TMTSubject ,
97 TMTTarget ,
108 TMTTrial ,
119 TrialType ,
@@ -53,48 +51,6 @@ def _build_trial(cursor_trail: list[CursorInfo], targets: list[TMTTarget] = None
5351class TestClassifyCursorPositionsWithHesitation :
5452 """Tests for classify_cursor_positions_with_hesitation function."""
5553
56- def test_raises_invalid_speed_error_when_raise_on_error_true (self ):
57- """
58- Con raise_on_error=True, lanza InvalidSpeedError si velocidad > 8.0 px/ms.
59- """
60- # Movimiento de 100px en 1ms = 100 px/ms (inválido, > 8.0)
61- cursor_trail = _build_cursor_trail ([
62- (0.0 , 0.0 , 0.0 ),
63- (100.0 , 0.0 , 1.0 ), # speed = 100 px/ms (invalid)
64- (102.0 , 0.0 , 2.0 ), # speed = 2 px/ms (valid)
65- ])
66- trial = _build_trial (cursor_trail )
67-
68- with pytest .raises (InvalidSpeedError , match = "exceeds INVALID_SPEED_THRESHOLD" ):
69- classify_cursor_positions_with_hesitation (
70- tmt_trial = trial ,
71- target_radius = 10.0 ,
72- speed_threshold = 2.0 ,
73- consecutive_points = 2 ,
74- raise_on_error = True
75- )
76-
77- def test_raises_non_monotonic_error_when_raise_on_error_true (self ):
78- """
79- Con raise_on_error=True, lanza NonMonotonicTimeError si tiempo retrocede.
80- """
81- # Time: 0 -> 2 -> 1 (retrocede)
82- cursor_trail = _build_cursor_trail ([
83- (0.0 , 0.0 , 0.0 ),
84- (2.0 , 0.0 , 2.0 ),
85- (4.0 , 0.0 , 1.0 ), # time goes backwards
86- ])
87- trial = _build_trial (cursor_trail )
88-
89- with pytest .raises (NonMonotonicTimeError , match = "current_cursor.time must be greater than previous_cursor.time" ):
90- classify_cursor_positions_with_hesitation (
91- tmt_trial = trial ,
92- target_radius = 10.0 ,
93- speed_threshold = 2.0 ,
94- consecutive_points = 2 ,
95- raise_on_error = True
96- )
97-
9854 def test_returns_correct_states_with_valid_data (self ):
9955 """
10056 Verifica que la clasificación de estados sea correcta con datos válidos.
@@ -113,8 +69,7 @@ def test_returns_correct_states_with_valid_data(self):
11369 tmt_trial = trial ,
11470 target_radius = 10.0 ,
11571 speed_threshold = 1.5 ,
116- consecutive_points = 2 ,
117- raise_on_error = True
72+ consecutive_points = 2
11873 )
11974
12075 # Verifica estructura correcta
@@ -125,14 +80,10 @@ def test_returns_correct_states_with_valid_data(self):
12580 assert state in ['Search' , 'Travel' , 'Hesitation' ]
12681 assert isinstance (cursor_info , CursorInfo )
12782
128- @pytest .mark .xfail (reason = "Bug conocido: con raise_on_error=False y datos inválidos, IndexError por len(speeds) < len(cursor_trail) - 1" )
129- def test_with_invalid_data_and_raise_on_error_false (self ):
83+ def test_handles_invalid_speed_gracefully (self ):
13084 """
131- Con raise_on_error=False, no lanza excepción y devuelve clasificación.
132-
133- NOTA: Este test documenta un bug conocido. Cuando hay velocidades inválidas
134- y raise_on_error=False, la lista speeds tiene menos elementos que cursor_trail - 1,
135- causando IndexError en speed_increases_over_consecutive_points().
85+ Con velocidades inválidas, no lanza excepción y devuelve clasificación.
86+ Las velocidades inválidas son marcadas como is_valid=False en SpeedResult.
13687 """
13788 # Trial con velocidad inválida (100 px/ms > 8.0)
13889 cursor_trail = _build_cursor_trail ([
@@ -148,8 +99,7 @@ def test_with_invalid_data_and_raise_on_error_false(self):
14899 tmt_trial = trial ,
149100 target_radius = 10.0 ,
150101 speed_threshold = 1.5 ,
151- consecutive_points = 2 ,
152- raise_on_error = False
102+ consecutive_points = 2
153103 )
154104
155105 # Verifica estructura correcta
@@ -160,3 +110,54 @@ def test_with_invalid_data_and_raise_on_error_false(self):
160110 assert state in ['Search' , 'Travel' , 'Hesitation' ]
161111 assert isinstance (cursor_info , CursorInfo )
162112
113+ def test_handles_non_monotonic_time_gracefully (self ):
114+ """
115+ Con tiempos no monótonos, no lanza excepción y devuelve clasificación.
116+ Los puntos con tiempo no monótono son marcados como is_valid=False en SpeedResult.
117+ """
118+ # Time: 0 -> 2 -> 1 (retrocede)
119+ cursor_trail = _build_cursor_trail ([
120+ (0.0 , 0.0 , 0.0 ),
121+ (2.0 , 0.0 , 2.0 ),
122+ (4.0 , 0.0 , 1.0 ), # time goes backwards
123+ (6.0 , 0.0 , 3.0 ), # valid again
124+ ])
125+ trial = _build_trial (cursor_trail )
126+
127+ # No debe lanzar excepción
128+ result = classify_cursor_positions_with_hesitation (
129+ tmt_trial = trial ,
130+ target_radius = 10.0 ,
131+ speed_threshold = 1.5 ,
132+ consecutive_points = 2
133+ )
134+
135+ # Verifica estructura correcta
136+ assert len (result ) == len (cursor_trail )
137+
138+ # Cada elemento es (estado, CursorInfo)
139+ for state , cursor_info in result :
140+ assert state in ['Search' , 'Travel' , 'Hesitation' ]
141+ assert isinstance (cursor_info , CursorInfo )
142+
143+ def test_first_point_is_search_on_target (self ):
144+ """
145+ Verifica que el primer punto sea Search cuando está sobre el target.
146+ """
147+ # Cursor empieza sobre el target (0,0)
148+ cursor_trail = _build_cursor_trail ([
149+ (0.0 , 0.0 , 0.0 ), # sobre target 1
150+ (2.0 , 0.0 , 1.0 ),
151+ (4.0 , 0.0 , 2.0 ),
152+ ])
153+ trial = _build_trial (cursor_trail )
154+
155+ result = classify_cursor_positions_with_hesitation (
156+ tmt_trial = trial ,
157+ target_radius = 10.0 ,
158+ speed_threshold = 1.5 ,
159+ consecutive_points = 2
160+ )
161+
162+ # El primer punto debe ser Search
163+ assert result [0 ][0 ] == 'Search'
0 commit comments