Skip to content

Commit e9e3414

Browse files
speed
1 parent 2206c85 commit e9e3414

3 files changed

Lines changed: 119 additions & 68 deletions

File tree

neurotask/neurotask/test/test_segmentation.py

Lines changed: 56 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import pytest
22

3-
from neurotask.tmt.metrics.speed_metrics import InvalidSpeedError, NonMonotonicTimeError
43
from neurotask.tmt.segmentation.segmentation import classify_cursor_positions_with_hesitation
54
from neurotask.tmt.model.tmt_model import (
65
Coordinate,
76
CursorInfo,
8-
TMTSubject,
97
TMTTarget,
108
TMTTrial,
119
TrialType,
@@ -53,48 +51,6 @@ def _build_trial(cursor_trail: list[CursorInfo], targets: list[TMTTarget] = None
5351
class TestClassifyCursorPositionsWithHesitation:
5452
"""Tests for classify_cursor_positions_with_hesitation function."""
5553

56-
def test_raises_invalid_speed_error_when_raise_on_error_true(self):
57-
"""
58-
Con raise_on_error=True, lanza InvalidSpeedError si velocidad > 8.0 px/ms.
59-
"""
60-
# Movimiento de 100px en 1ms = 100 px/ms (inválido, > 8.0)
61-
cursor_trail = _build_cursor_trail([
62-
(0.0, 0.0, 0.0),
63-
(100.0, 0.0, 1.0), # speed = 100 px/ms (invalid)
64-
(102.0, 0.0, 2.0), # speed = 2 px/ms (valid)
65-
])
66-
trial = _build_trial(cursor_trail)
67-
68-
with pytest.raises(InvalidSpeedError, match="exceeds INVALID_SPEED_THRESHOLD"):
69-
classify_cursor_positions_with_hesitation(
70-
tmt_trial=trial,
71-
target_radius=10.0,
72-
speed_threshold=2.0,
73-
consecutive_points=2,
74-
raise_on_error=True
75-
)
76-
77-
def test_raises_non_monotonic_error_when_raise_on_error_true(self):
78-
"""
79-
Con raise_on_error=True, lanza NonMonotonicTimeError si tiempo retrocede.
80-
"""
81-
# Time: 0 -> 2 -> 1 (retrocede)
82-
cursor_trail = _build_cursor_trail([
83-
(0.0, 0.0, 0.0),
84-
(2.0, 0.0, 2.0),
85-
(4.0, 0.0, 1.0), # time goes backwards
86-
])
87-
trial = _build_trial(cursor_trail)
88-
89-
with pytest.raises(NonMonotonicTimeError, match="current_cursor.time must be greater than previous_cursor.time"):
90-
classify_cursor_positions_with_hesitation(
91-
tmt_trial=trial,
92-
target_radius=10.0,
93-
speed_threshold=2.0,
94-
consecutive_points=2,
95-
raise_on_error=True
96-
)
97-
9854
def test_returns_correct_states_with_valid_data(self):
9955
"""
10056
Verifica que la clasificación de estados sea correcta con datos válidos.
@@ -113,8 +69,7 @@ def test_returns_correct_states_with_valid_data(self):
11369
tmt_trial=trial,
11470
target_radius=10.0,
11571
speed_threshold=1.5,
116-
consecutive_points=2,
117-
raise_on_error=True
72+
consecutive_points=2
11873
)
11974

12075
# Verifica estructura correcta
@@ -125,14 +80,10 @@ def test_returns_correct_states_with_valid_data(self):
12580
assert state in ['Search', 'Travel', 'Hesitation']
12681
assert isinstance(cursor_info, CursorInfo)
12782

128-
@pytest.mark.xfail(reason="Bug conocido: con raise_on_error=False y datos inválidos, IndexError por len(speeds) < len(cursor_trail) - 1")
129-
def test_with_invalid_data_and_raise_on_error_false(self):
83+
def test_handles_invalid_speed_gracefully(self):
13084
"""
131-
Con raise_on_error=False, no lanza excepción y devuelve clasificación.
132-
133-
NOTA: Este test documenta un bug conocido. Cuando hay velocidades inválidas
134-
y raise_on_error=False, la lista speeds tiene menos elementos que cursor_trail - 1,
135-
causando IndexError en speed_increases_over_consecutive_points().
85+
Con velocidades inválidas, no lanza excepción y devuelve clasificación.
86+
Las velocidades inválidas son marcadas como is_valid=False en SpeedResult.
13687
"""
13788
# Trial con velocidad inválida (100 px/ms > 8.0)
13889
cursor_trail = _build_cursor_trail([
@@ -148,8 +99,7 @@ def test_with_invalid_data_and_raise_on_error_false(self):
14899
tmt_trial=trial,
149100
target_radius=10.0,
150101
speed_threshold=1.5,
151-
consecutive_points=2,
152-
raise_on_error=False
102+
consecutive_points=2
153103
)
154104

155105
# Verifica estructura correcta
@@ -160,3 +110,54 @@ def test_with_invalid_data_and_raise_on_error_false(self):
160110
assert state in ['Search', 'Travel', 'Hesitation']
161111
assert isinstance(cursor_info, CursorInfo)
162112

113+
def test_handles_non_monotonic_time_gracefully(self):
114+
"""
115+
Con tiempos no monótonos, no lanza excepción y devuelve clasificación.
116+
Los puntos con tiempo no monótono son marcados como is_valid=False en SpeedResult.
117+
"""
118+
# Time: 0 -> 2 -> 1 (retrocede)
119+
cursor_trail = _build_cursor_trail([
120+
(0.0, 0.0, 0.0),
121+
(2.0, 0.0, 2.0),
122+
(4.0, 0.0, 1.0), # time goes backwards
123+
(6.0, 0.0, 3.0), # valid again
124+
])
125+
trial = _build_trial(cursor_trail)
126+
127+
# No debe lanzar excepción
128+
result = classify_cursor_positions_with_hesitation(
129+
tmt_trial=trial,
130+
target_radius=10.0,
131+
speed_threshold=1.5,
132+
consecutive_points=2
133+
)
134+
135+
# Verifica estructura correcta
136+
assert len(result) == len(cursor_trail)
137+
138+
# Cada elemento es (estado, CursorInfo)
139+
for state, cursor_info in result:
140+
assert state in ['Search', 'Travel', 'Hesitation']
141+
assert isinstance(cursor_info, CursorInfo)
142+
143+
def test_first_point_is_search_on_target(self):
144+
"""
145+
Verifica que el primer punto sea Search cuando está sobre el target.
146+
"""
147+
# Cursor empieza sobre el target (0,0)
148+
cursor_trail = _build_cursor_trail([
149+
(0.0, 0.0, 0.0), # sobre target 1
150+
(2.0, 0.0, 1.0),
151+
(4.0, 0.0, 2.0),
152+
])
153+
trial = _build_trial(cursor_trail)
154+
155+
result = classify_cursor_positions_with_hesitation(
156+
tmt_trial=trial,
157+
target_radius=10.0,
158+
speed_threshold=1.5,
159+
consecutive_points=2
160+
)
161+
162+
# El primer punto debe ser Search
163+
assert result[0][0] == 'Search'

neurotask/neurotask/tmt/metrics/speed_metrics.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import logging
2-
from typing import Dict, Any, Tuple, List
2+
from typing import Dict, Any, Tuple, List, NamedTuple
33

44
import numpy as np
5+
6+
7+
class SpeedResult(NamedTuple):
8+
"""Result of a speed calculation with validity flag."""
9+
is_valid: bool
10+
value: float
511
from neurotask.tmt.metrics.base_metric import BaseMetricCalculator
612
from neurotask.tmt.metrics.distance_calculation import calculate_distance
713
from neurotask.tmt.model.tmt_model import TMTTrial, CursorInfo, TMTSubject, TMTTarget
@@ -104,6 +110,45 @@ def calculate_speeds_between_cursor_positions(trial: TMTTrial, raise_on_error: b
104110
return calculate_speeds(cursor_trail_from_first_click, raise_on_error)
105111

106112

113+
def calculate_speeds_with_validity(cursor_trail: List[CursorInfo]) -> List[SpeedResult]:
114+
"""
115+
Calculate speeds between consecutive cursor positions with validity flags.
116+
117+
Args:
118+
cursor_trail: List of cursor positions.
119+
120+
Returns:
121+
List of SpeedResult. Always len(result) == len(cursor_trail) - 1.
122+
If is_valid=False, value=0.0.
123+
"""
124+
if len(cursor_trail) < 2:
125+
raise ValueError("At least two points are required to calculate velocity")
126+
127+
results = []
128+
for i in range(1, len(cursor_trail)):
129+
try:
130+
speed = calculate_speed(cursor_trail[i], cursor_trail[i - 1])
131+
results.append(SpeedResult(is_valid=True, value=speed))
132+
except (InvalidSpeedError, NonMonotonicTimeError):
133+
results.append(SpeedResult(is_valid=False, value=0.0))
134+
135+
return results
136+
137+
138+
def calculate_speeds_between_cursor_positions_with_validity(trial: TMTTrial) -> List[SpeedResult]:
139+
"""
140+
Calculate speeds between cursor positions with validity flags for a trial.
141+
142+
Args:
143+
trial: The TMT trial.
144+
145+
Returns:
146+
List of SpeedResult. Always len(result) == len(cursor_trail) - 1.
147+
"""
148+
cursor_trail = trial.get_cursor_trail_from_start()
149+
return calculate_speeds_with_validity(cursor_trail)
150+
151+
107152
def calculate_speeds(cursor_trail: List[CursorInfo], raise_on_error: bool = False) -> List[float]:
108153
if len(cursor_trail) < 2:
109154
raise ValueError("At least two points are required to calculate velocity")

neurotask/neurotask/tmt/segmentation/segmentation.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@
66

77
from neurotask.tmt.metrics.speed_metrics import (
88
calculate_speeds_between_cursor_positions,
9+
calculate_speeds_between_cursor_positions_with_validity,
910
calculate_speeds,
1011
calculate_speed,
1112
InvalidSpeedError,
12-
NonMonotonicTimeError
13+
NonMonotonicTimeError,
14+
SpeedResult
1315
)
1416
from ..metrics.distance_calculation import calculate_distance
1517
from ..model.tmt_model import CursorInfo, TMTTrial, TMTExperiment, Coordinate, TMTSubject, TrialType
1618

1719

1820
def speed_increases_over_consecutive_points(
19-
speeds: List[float],
21+
speeds: List[SpeedResult],
2022
cursor_index: int,
2123
speed_threshold: float,
2224
consecutive_points: int
@@ -25,7 +27,7 @@ def speed_increases_over_consecutive_points(
2527
Determines if the speed has increased over a specified number of consecutive points beyond a given speed threshold.
2628
2729
Parameters:
28-
- speeds: List of speed values between cursor positions.
30+
- speeds: List of SpeedResult values between cursor positions.
2931
- cursor_index: The current index in the cursor trail (starting from 0).
3032
- speed_threshold: The minimum increase in speed between consecutive points to consider.
3133
- consecutive_points: Number of consecutive points over which the speed must increase.
@@ -42,14 +44,16 @@ def speed_increases_over_consecutive_points(
4244
for i in range(speed_index - consecutive_points + 1, speed_index + 1):
4345
if i <= 0:
4446
return False # Not enough data
45-
current_speed = speeds[i]
46-
if current_speed <= speed_threshold:
47+
speed_result = speeds[i]
48+
if not speed_result.is_valid:
49+
return False # Invalid speed, cannot evaluate
50+
if speed_result.value <= speed_threshold:
4751
return False # Speed did not increase sufficiently
4852
return True
4953

5054

5155
def speed_decreases_over_consecutive_points(
52-
speeds: List[float],
56+
speeds: List[SpeedResult],
5357
cursor_index: int,
5458
speed_threshold: float,
5559
consecutive_points: int
@@ -58,7 +62,7 @@ def speed_decreases_over_consecutive_points(
5862
Determines if the speed has decreased over a specified number of consecutive points beyond a given speed threshold.
5963
6064
Parameters:
61-
- speeds: List of speed values between cursor positions.
65+
- speeds: List of SpeedResult values between cursor positions.
6266
- cursor_index: The current index in the cursor trail (starting from 0).
6367
- speed_threshold: The minimum decrease in speed between consecutive points to consider.
6468
- consecutive_points: Number of consecutive points over which the speed must decrease.
@@ -75,8 +79,10 @@ def speed_decreases_over_consecutive_points(
7579
for i in range(speed_index - consecutive_points + 1, speed_index + 1):
7680
if i <= 0:
7781
return False # Not enough data
78-
current_speed = speeds[i]
79-
if current_speed > speed_threshold:
82+
speed_result = speeds[i]
83+
if not speed_result.is_valid:
84+
return False # Invalid speed, cannot evaluate
85+
if speed_result.value > speed_threshold:
8086
return False # Speed did not decrease sufficiently
8187
return True
8288

@@ -85,12 +91,11 @@ def classify_cursor_positions_with_hesitation(
8591
tmt_trial: TMTTrial,
8692
target_radius: float,
8793
speed_threshold,
88-
consecutive_points=5,
89-
raise_on_error: bool = False
94+
consecutive_points=5
9095
) -> List[Tuple[str, CursorInfo]]:
9196
classified_positions = []
9297
cursor_trail = tmt_trial.get_cursor_trail_from_start()
93-
speeds = calculate_speeds_between_cursor_positions(tmt_trial, raise_on_error)
98+
speeds = calculate_speeds_between_cursor_positions_with_validity(tmt_trial)
9499
over_target_flags = calculate_over_targets(cursor_trail, target_radius, tmt_trial.stimuli)
95100

96101
current_state = 'Search'

0 commit comments

Comments
 (0)