Skip to content

Commit c60db7e

Browse files
authored
Merge pull request #108 from pariterre/main
Same again
2 parents 71e3da0 + 70052da commit c60db7e

11 files changed

Lines changed: 905 additions & 24 deletions

biobuddy/components/generic/rigidbody/segment_coordinate_system.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from ...real.biomechanical_model_real import BiomechanicalModelReal
1010
from ...real.rigidbody.axis_real import AxisReal
1111
from ...real.rigidbody.segment_coordinate_system_real import SegmentCoordinateSystemReal
12-
from ....utils.marker_data import MarkerData, DictData
12+
from ....utils.aliases import Points, Point
13+
from ....utils.marker_data import MarkerData
1314
from ....utils.linear_algebra import RotoTransMatrixTimeSeries, RotoTransMatrix
1415
from ....model_modifiers.joint_center_tool import Score, Sara
1516

@@ -327,23 +328,54 @@ def collapse(static_markers: MarkerData, _: BiomechanicalModelReal, visualize: b
327328

328329
return partial(collapse, visualize=visualize)
329330

331+
@staticmethod
332+
def _original_rotation_axis(original_axis_global: Axis, static_markers: MarkerData) -> np.ndarray:
333+
"""
334+
Estimate the original axis of rotation to make sure that the axis is in the right direction.
335+
"""
336+
if not isinstance(original_axis_global.start, Marker) or not isinstance(original_axis_global.end, Marker):
337+
raise NotImplementedError("The original_axis_global should be an Axis with start and end as Markers.")
338+
if (
339+
original_axis_global.start.name not in static_markers.marker_names
340+
or original_axis_global.end.name not in static_markers.marker_names
341+
):
342+
raise NotImplementedError(
343+
f"The markers defining the original_axis_global should be present in the static markers, got start: {original_axis_global.start.name} and end: {original_axis_global.end.name}"
344+
)
345+
346+
start_marker_global = static_markers.mean_marker_position(original_axis_global.start.name)
347+
end_marker_global = static_markers.mean_marker_position(original_axis_global.end.name)
348+
global_axis = end_marker_global - start_marker_global
349+
350+
return global_axis[:3]
351+
330352
@staticmethod
331353
def sara(
332354
name: int,
333355
functional_data: MarkerData,
334356
parent_marker_names: tuple[str, ...] | list[str],
335357
child_marker_names: tuple[str, ...] | list[str],
358+
expected_rotation_axis_orientation: Axis | None = None,
359+
origin_positions_global: Callable | None = None,
336360
visualize: bool = False,
337361
) -> Axis:
338362
"""
339363
Compute the SARA (Symmetrical Axis of Rotation Approach) between two sets of markers
364+
# TODO: SARA should also change the origin of the axis, not only the direction.
340365
341366
Parameters
342367
----------
343368
parent_marker_names
344369
The names of the markers on the parent segment to compute the SARA axis from
345370
child_marker_names
346371
The names of the markers on the child segment to compute the SARA axis from
372+
expected_rotation_axis_orientation: Axis | None
373+
The original axis in the global reference frame, of shape (3,). It is used to reorient the SARA axis if
374+
it points in the opposite direction of the original axis. If None, the original axis is not used and the SARA axis is not reoriented.
375+
origin_positions_global: Callable | None
376+
The function defining the positions in the global reference frame used as a reference for the origin of the axis (3 x FunctionalTrialFrameCount).
377+
The origin_positions_global points are projected onto the computed axis to determine the final origin of the axis; effectively
378+
replacing the computed COR value.
347379
visualize
348380
If True, a 3D visualization of the SARA axis computation will be shown. Plotly is required for this.
349381
@@ -355,7 +387,7 @@ def sara(
355387
sara_cache = {} # We only need to perform SARA once. So we store the result here.
356388

357389
def collapse(
358-
static_markers: MarkerData, _: BiomechanicalModelReal, visualize: bool
390+
static_markers: MarkerData, bio_model: BiomechanicalModelReal, visualize: bool
359391
) -> tuple[np.ndarray, np.ndarray]:
360392
static_markers_hash = _markers_fingerprint(static_markers)
361393

@@ -381,7 +413,21 @@ def collapse(
381413
)
382414

383415
# Compute the SARA axis
384-
_, aor_parent, _, _, cor_parent, _, _, _ = Sara.perform_algorithm(rt_parent_func, rt_child_func)
416+
if expected_rotation_axis_orientation is not None:
417+
original_axis_global = SegmentCoordinateSystemUtils._original_rotation_axis(
418+
expected_rotation_axis_orientation, static_markers
419+
)
420+
else:
421+
original_axis_global = None
422+
origin_positions_global_evaluated = (
423+
origin_positions_global(functional_data, bio_model) if origin_positions_global is not None else None
424+
)
425+
_, aor_parent, _, _, cor_parent, _, _, _ = Sara.perform_algorithm(
426+
rt_parent=rt_parent_func,
427+
rt_child=rt_child_func,
428+
original_axis_global=original_axis_global,
429+
origin_positions_global=origin_positions_global_evaluated,
430+
)
385431
sara_cache[static_markers_hash] = [
386432
rt_parent_static,
387433
rt_parent_func,

biobuddy/components/real/model_dynamics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,7 @@ def animate(
905905
_logger.error("pyorerun is not installed. Cannot animate the model.")
906906
return
907907

908-
animation = pyorerun.LiveModelAnimation(model_path, with_q_charts=True)
908+
animation = pyorerun.LiveModelAnimation.from_file(model_path, with_q_charts=True)
909909
animation.options.set_all_labels(False)
910910
animation.rerun()
911911
return

biobuddy/components/real/rigidbody/axis_real.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def axis(self) -> np.ndarray:
5050
"""
5151
Returns the axis vector
5252
"""
53+
axis = np.ones((4, 1))
5354
start = self.start_point.position
5455
end = self.end_point.position
55-
return end - start
56+
axis[:3, :] = end[:3] - start[:3]
57+
return axis

biobuddy/model_modifiers/joint_center_tool.py

Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
point_from_local_to_global,
2424
get_vector_from_sequence,
2525
get_sequence_from_rotation_vector,
26-
rot2eul,
26+
project_points_on_axes,
2727
)
2828
from ..utils.named_list import NamedList
29+
from ..utils.aliases import Points, Point, points_to_array, point_to_array
2930

3031
_logger = logging.getLogger(__name__)
3132

@@ -297,6 +298,8 @@ def _check_marker_functional_trial_file(self):
297298
Check that the file format is appropriate and that there is a functional movement in the trial (aka the markers really move).
298299
"""
299300
self.marker_names = self._data.marker_names
301+
if self._data.nb_frames == 0:
302+
raise RuntimeError("The functional trial file does not contain any frame. Please check the trial again.")
300303
self.marker_positions = self._data.all_marker_positions[:3, :, :]
301304

302305
# Check that the markers move
@@ -801,6 +804,36 @@ def __init__(
801804
initialize_whole_trial_reconstruction: bool = False,
802805
animate_rt: bool = False,
803806
):
807+
"""
808+
Initialize the SARA (Symmetrical Axis of Rotation Approach) algorithm.
809+
810+
Parameters
811+
----------
812+
functional_trial: MarkerData
813+
The MarkerData containing the functional trial.
814+
parent_name: str
815+
The name of the joint's parent segment.
816+
child_name: str
817+
The name of the joint's child segment.
818+
parent_marker_names: list[str]
819+
The name of the markers in the parent segment to consider during the SARA algorithm.
820+
child_marker_names: list[str]
821+
The name of the markers in the child segment to consider during the SARA algorithm.
822+
joint_center_markers: list[str]
823+
The name of the markers to consider as joint center markers (i.e., markers close to the joint center).
824+
# TODO: should be uniformized with origin_positions_global: Points from SegmentCoordinateSystemUtils.sara
825+
distal_markers: list[str]
826+
The name of the markers to consider as distal markers (i.e., markers close to the distal end of the child segment).
827+
is_longitudinal_axis_from_jcs_to_distal_markers: bool
828+
If True, the longitudinal axis of the child segment is defined from the joint center markers to the distal markers.
829+
If False, the longitudinal axis is defined from the distal markers to the joint center markers.
830+
expected_rotation_axis_orientation: Axis
831+
The expected orientation of the rotation axis (e.g., Axis.X, Axis.Y, or Axis.Z). This is used to make sure the computed axis is in the expected direction.
832+
initialize_whole_trial_reconstruction: bool
833+
If True, the whole trial is reconstructed using whole body inverse kinematics to initialize the segments' rt in the global reference frame.
834+
animate_rt: bool
835+
If True, it animates the segment rt reconstruction using pyorerun.
836+
"""
804837

805838
super(Sara, self).__init__(
806839
functional_trial=functional_trial,
@@ -821,7 +854,8 @@ def __init__(
821854
def perform_algorithm(
822855
rt_parent: RotoTransMatrixTimeSeries,
823856
rt_child: RotoTransMatrixTimeSeries,
824-
original_axis_global: np.ndarray | None = None,
857+
original_axis_global: Point | None = None,
858+
origin_positions_global: Points | None = None,
825859
recursive_outlier_removal: bool = True,
826860
) -> Tuple[
827861
np.ndarray,
@@ -843,8 +877,12 @@ def perform_algorithm(
843877
Homogeneous transformation matrices from the global frame to the parent segment.
844878
rt_child : RotoTransMatrixTimeSeries
845879
Homogeneous transformation matrices from the global frame to the child segment.
846-
original_axis_global: np.ndarray | None
880+
original_axis_global: Points | None
847881
The original rotation axis direction. This axis is used to make sure the new axis is in a similar direction.
882+
origin_positions_global: Points | None
883+
The positions in the global reference frame used as a reference for the origin of the axis (3 x FunctionalTrialFrameCount).
884+
The origin_positions_global points are projected onto the computed axis to determine the final origin of the axis; effectively
885+
replacing the computed COR value.
848886
recursive_outlier_removal : bool
849887
If True, performs 95th percentile residual filtering and recomputes the axis of rotation.
850888
@@ -867,6 +905,7 @@ def perform_algorithm(
867905
rt_child : RotoTransMatrixTimeSeries
868906
Homogeneous transformations of the child segment after outlier removal.
869907
"""
908+
870909
nb_frames = len(rt_parent)
871910
U, S, V, b_valid = get_svd(rt_parent, rt_child)
872911

@@ -878,6 +917,7 @@ def perform_algorithm(
878917
aor_child_local /= np.linalg.norm(aor_child_local)
879918

880919
# Compute pseudo-inverse solution
920+
# cor = V[:, :5] @ np.diag(1.0 / S[:5]) @ U[:, :5].T @ b_valid # TODO: make a breaking PR for this change !!!
881921
cor = V @ np.diag(1.0 / S) @ U.T @ b_valid
882922
cor_parent_local = cor[3:]
883923
cor_child_local = cor[:3]
@@ -895,30 +935,70 @@ def perform_algorithm(
895935
np.dot(aor_parent_global[:, i_frame], aor_child_global[:, i_frame])
896936
/ (np.linalg.norm(aor_parent_global[:, i_frame]) * np.linalg.norm(aor_child_global[:, i_frame]))
897937
)
938+
898939
cor_parent_global[:, i_frame] = (rt_parent[i_frame] @ np.hstack((cor_parent_local, 1)))[:, 0]
899940
cor_child_global[:, i_frame] = (rt_child[i_frame] @ np.hstack((cor_child_local, 1)))[:, 0]
900941

942+
if origin_positions_global is not None:
943+
origins_global = points_to_array(origin_positions_global)
944+
if origins_global.shape[1] == 1:
945+
# If only one is defined (like the mean), repeat for all frames
946+
origins_global = np.repeat(origins_global, nb_frames, axis=1)
947+
if origins_global.shape[1] != nb_frames:
948+
raise RuntimeError(
949+
f"The number of origin positions {len(origin_positions_global)} does not match the number of frames {nb_frames}."
950+
)
951+
901952
if recursive_outlier_removal:
902953
valid = Sara.get_good_frames(residuals, nb_frames)
903954
if not np.all(valid):
904955
rt_parent = RotoTransMatrixTimeSeries.from_closest_rt_matrix(rt_parent.to_numpy()[:, :, valid])
905956
rt_child = RotoTransMatrixTimeSeries.from_closest_rt_matrix(rt_child.to_numpy()[:, :, valid])
957+
if origin_positions_global is not None:
958+
origin_positions_global = origins_global[:, valid]
959+
906960
return Sara.perform_algorithm(
907-
rt_parent, rt_child, original_axis_global, recursive_outlier_removal=False
961+
rt_parent=rt_parent,
962+
rt_child=rt_child,
963+
original_axis_global=original_axis_global,
964+
origin_positions_global=origin_positions_global,
965+
recursive_outlier_removal=False,
908966
)
909967

968+
if origin_positions_global is not None:
969+
origins_parent = np.zeros((4, nb_frames))
970+
origins_child = np.zeros((4, nb_frames))
971+
for i_frame in range(nb_frames):
972+
origins_parent[:, i_frame] = (rt_parent[i_frame].inverse @ origins_global[:, i_frame])[:, 0]
973+
origins_child[:, i_frame] = (rt_child[i_frame].inverse @ origins_global[:, i_frame])[:, 0]
974+
cor_parent_local = project_points_on_axes(
975+
origins_parent.mean(axis=1)[:3], start=cor_parent_local, end=cor_parent_local + aor_parent_local
976+
)
977+
cor_child_local = project_points_on_axes(
978+
origins_child.mean(axis=1)[:3], start=cor_child_local, end=cor_child_local + aor_child_local
979+
)
980+
cor_parent_global = project_points_on_axes(
981+
origins_global, start=cor_parent_global, end=aor_parent_global + cor_parent_global
982+
)
983+
cor_child_global = project_points_on_axes(
984+
origins_global, start=cor_child_global, end=aor_child_global + cor_child_global
985+
)
986+
910987
# Final output
911988
aor_mean_global = 0.5 * (np.mean(aor_parent_global[:3, :], axis=1) + np.mean(aor_child_global[:3, :], axis=1))
912989
cor_mean_global = 0.5 * (np.mean(cor_parent_global[:3, :], axis=1) + np.mean(cor_child_global[:3, :], axis=1))
913990

914-
if original_axis_global is not None and np.dot(aor_mean_global, original_axis_global) < 0:
915-
# The axis is in the wrong direction
916-
aor_mean_global *= -1
917-
aor_parent_local *= -1
918-
aor_child_local *= -1
991+
if original_axis_global is not None:
992+
original_axis_global = point_to_array(original_axis_global)[:3, 0]
993+
if np.dot(aor_mean_global, original_axis_global) < 0:
994+
# The axis is in the wrong direction
995+
aor_mean_global *= -1
996+
aor_parent_local *= -1
997+
aor_child_local *= -1
919998

920999
_logger.info(
921-
f"\nThere is a residual angle between the parent's and the child's AoR of : {np.nanmean(residuals)*180/np.pi} +- {np.nanstd(residuals)*180/np.pi} degrees."
1000+
f"\nThere is a residual angle between the parent's and the child's AoR of : "
1001+
f"{np.nanmean(residuals)*180/np.pi} +- {np.nanstd(residuals)*180/np.pi} degrees."
9221002
)
9231003

9241004
return (
@@ -1047,6 +1127,7 @@ def perform_task(
10471127
new_model: BiomechanicalModelReal,
10481128
parent_rt_init: RotoTransMatrixTimeSeries,
10491129
child_rt_init: RotoTransMatrixTimeSeries,
1130+
origin_positions_global: Points = None,
10501131
):
10511132

10521133
# Reconstruct the trial to identify the orientation of the segments
@@ -1065,7 +1146,11 @@ def perform_task(
10651146
# Identify axis of rotation
10661147
original_axis_global, _ = self._original_rotation_axis(new_model)
10671148
aor_global, _, aor_local_child, _, _, _, rt_parent_valid_frames, _ = self.perform_algorithm(
1068-
rt_parent_functional, rt_child_functional, original_axis_global, recursive_outlier_removal=True
1149+
rt_parent_functional,
1150+
rt_child_functional,
1151+
original_axis_global,
1152+
origin_positions_global=origin_positions_global,
1153+
recursive_outlier_removal=True,
10691154
)
10701155

10711156
# Extract the joint coordinate system

biobuddy/utils/linear_algebra.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,3 +732,53 @@ def local_rt_between_global_rts(
732732

733733
local_rt = parent_rt_in_global.inverse @ child_rt_in_global
734734
return local_rt
735+
736+
737+
def project_points_on_axes(points: Points, start: Points, end: Points) -> Points:
738+
"""
739+
Projects points on axes defined by two points at each column.
740+
741+
Parameters
742+
----------
743+
points
744+
The points to project, of shape (3, N) or (4, N).
745+
start
746+
The start points of the axes, of shape (3, N).
747+
end
748+
The end points of the axes, of shape (3, N).
749+
750+
Returns
751+
-------
752+
projected_point
753+
The projected points, of shape (3, N).
754+
"""
755+
if len(points.shape) == 1:
756+
points = points[:, None]
757+
if len(start.shape) == 1:
758+
start = start[:, None]
759+
if len(end.shape) == 1:
760+
end = end[:, None]
761+
762+
if points.shape[0] != 3 and points.shape[0] != 4:
763+
raise ValueError(f"Expected points of shape (3, N) or (4, N), got shape {points.shape}")
764+
if start.shape[0] != 3 and start.shape[0] != 4:
765+
raise ValueError(f"Expected start points of shape (3, N) or (4, N), got shape {start.shape}")
766+
if end.shape[0] != 3 and end.shape[0] != 4:
767+
raise ValueError(f"Expected end points of shape (3, N) or (4, N), got shape {end.shape}")
768+
if points.shape[0] != start.shape[0] or points.shape[0] != end.shape[0]:
769+
raise ValueError(
770+
f"Expected points, start and end to have the same number of rows. Got {points.shape[0]}, {start.shape[0]}, {end.shape[0]}"
771+
)
772+
if points.shape[1] != start.shape[1] or points.shape[1] != end.shape[1]:
773+
raise ValueError(
774+
f"Expected points, start and end to have the same number of columns. Got {points.shape[1]}, {start.shape[1]}, {end.shape[1]}"
775+
)
776+
777+
start_to_end = end - start
778+
start_to_point = points - start
779+
return (
780+
start
781+
+ np.einsum("ij,ij->j", start_to_point, start_to_end)
782+
/ np.einsum("ij,ij->j", start_to_end, start_to_end)
783+
* start_to_end
784+
)

examples/create_rt_from_functional_trials.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,15 @@ def generate_lower_body_model(visualize: bool = True) -> BiomechanicalModelReal:
115115
functional_data=right_knee_data,
116116
parent_marker_names=["RTHI1", "RTHI2", "RTHI3"],
117117
child_marker_names=["RLEG1", "RLEG2", "RLEG3", "RATT", "RLM", "RSPH"],
118+
expected_rotation_axis_orientation=Axis(
119+
name=Axis.Name.X,
120+
start=Marker("RMFE", is_technical=False, is_anatomical=True),
121+
end=Marker("RLFE", is_technical=False, is_anatomical=True),
122+
),
123+
origin_positions_global=SegmentCoordinateSystemUtils.mean_markers(["RMFE", "RLFE"]),
118124
visualize=visualize,
119125
)
126+
120127
right_ankle_mid = SegmentCoordinateSystemUtils.mean_markers(["RLM", "RSPH"])
121128
model.add_segment(
122129
Segment(

0 commit comments

Comments
 (0)