diff --git a/CHANGELOG.md b/CHANGELOG.md index 87a9c71..83c8802 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -222,6 +222,23 @@ be split into per-release sections once tagging begins. at `CURRENT_VERSION = 2`, with registered v1 → v2 migrations for `VideoPredictions` and `BenchmarkResult` that add the optional `provenance` field. +- **`neuropose.analyzer.features.procrustes_align`** — Kabsch + rigid-alignment helper for pose sequences, plus a + `ProcrustesMode` literal (`"per_frame"` | `"per_sequence"`) and a + frozen `AlignmentDiagnostics` dataclass (`rotation_deg`, + `rotation_deg_max`, `translation`, `translation_max`, `scale`, + plus the mode that produced them). Per-sequence mode fits one + rigid transform across the whole trial; per-frame fits an + independent transform per frame. Optional `scale=True` fits a + uniform scale factor for cross-subject comparisons. Wired into + every DTW entry point in `neuropose.analyzer.dtw` via a new + keyword-only `align: AlignMode = "none"` parameter — `"none"` + preserves the 0.1 raw-coordinate behaviour, while + `"procrustes_per_frame"` and `"procrustes_per_sequence"` route + inputs through `procrustes_align` before DTW runs so the returned + distance is rotation- and translation-invariant. Paper C's + pipeline is expected to set `align="procrustes_per_sequence"`; + see `TECHNICAL.md` Phase 0. - **`neuropose.io.Provenance`** — reproducibility envelope for every inference run. Populated automatically by `Estimator.process_video` when the model was loaded via `load_model` (the production path) @@ -262,10 +279,12 @@ be split into per-release sections once tagging begins. methodology investigation. - `analyzer.features` — `predictions_to_numpy`, `normalize_pose_sequence` (uniform and axis-wise), - `pad_sequences` (edge-padding), `extract_joint_angles` (NaN on - degenerate vectors), `extract_feature_statistics` - (`FeatureStatistics` frozen dataclass), and a `find_peaks` thin - wrapper around `scipy.signal.find_peaks`. + `pad_sequences` (edge-padding), `procrustes_align` (Kabsch + rigid alignment, per-frame or per-sequence, optional uniform + scaling), `extract_joint_angles` (NaN on degenerate vectors), + `extract_feature_statistics` (`FeatureStatistics` frozen + dataclass), and a `find_peaks` thin wrapper around + `scipy.signal.find_peaks`. - `analyzer.segment` — repetition segmentation for trials in which a subject performs the same movement several times. A three-layer API: `segment_by_peaks` (pure 1D diff --git a/src/neuropose/analyzer/__init__.py b/src/neuropose/analyzer/__init__.py index ec7e811..98b6dbe 100644 --- a/src/neuropose/analyzer/__init__.py +++ b/src/neuropose/analyzer/__init__.py @@ -28,19 +28,23 @@ here for ergonomic access. from __future__ import annotations from neuropose.analyzer.dtw import ( + AlignMode, DTWResult, dtw_all, dtw_per_joint, dtw_relation, ) from neuropose.analyzer.features import ( + AlignmentDiagnostics, FeatureStatistics, + ProcrustesMode, extract_feature_statistics, extract_joint_angles, find_peaks, normalize_pose_sequence, pad_sequences, predictions_to_numpy, + procrustes_align, ) from neuropose.analyzer.segment import ( JOINT_INDEX, @@ -59,8 +63,11 @@ from neuropose.analyzer.segment import ( __all__ = [ "JOINT_INDEX", "JOINT_NAMES", + "AlignMode", + "AlignmentDiagnostics", "DTWResult", "FeatureStatistics", + "ProcrustesMode", "dtw_all", "dtw_per_joint", "dtw_relation", @@ -76,6 +83,7 @@ __all__ = [ "normalize_pose_sequence", "pad_sequences", "predictions_to_numpy", + "procrustes_align", "segment_by_peaks", "segment_predictions", "slice_predictions", diff --git a/src/neuropose/analyzer/dtw.py b/src/neuropose/analyzer/dtw.py index e4f94d9..d16cd36 100644 --- a/src/neuropose/analyzer/dtw.py +++ b/src/neuropose/analyzer/dtw.py @@ -16,6 +16,12 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)`` numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy` produces. +All three also accept an ``align`` argument that routes the inputs +through :func:`~neuropose.analyzer.features.procrustes_align` before +DTW runs, yielding translation- and rotation-invariant distances. +``align="none"`` (the default) preserves the raw-coordinate behaviour +shipped in 0.1. + Dependency note --------------- This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of @@ -30,9 +36,21 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass +from typing import Literal import numpy as np +from neuropose.analyzer.features import procrustes_align + +AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"] +"""Alignment selector for DTW entry points. + +- ``"none"`` — feed raw coordinates directly to DTW. +- ``"procrustes_per_frame"`` — per-frame Kabsch alignment before DTW. +- ``"procrustes_per_sequence"`` — single sequence-wide Kabsch + alignment before DTW. +""" + @dataclass(frozen=True) class DTWResult: @@ -77,7 +95,12 @@ def _require_fastdtw() -> tuple[Callable, Callable]: return fastdtw, euclidean -def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: +def dtw_all( + a: np.ndarray, + b: np.ndarray, + *, + align: AlignMode = "none", +) -> DTWResult: """DTW on the flattened per-frame joint vector. Each frame's joints are collapsed into a single vector before DTW @@ -90,7 +113,12 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: a, b Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two sequences do not need to have the same number of frames, but - they must have the same number of joints. + they must have the same number of joints. When ``align`` is not + ``"none"``, the two sequences must additionally share a frame + count (Procrustes requires a 1:1 correspondence). + align + Procrustes alignment mode applied before DTW. See + :data:`AlignMode`. Returns ------- @@ -101,9 +129,11 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: Raises ------ ValueError - If ``a`` and ``b`` do not have the same joint count. + If ``a`` and ``b`` do not have the same joint count, or if + ``align`` requires a matching frame count that is not present. """ _validate_same_joint_count(a, b) + a, b = _maybe_align(a, b, align=align) fastdtw, euclidean = _require_fastdtw() a_flat = a.reshape(a.shape[0], -1) b_flat = b.reshape(b.shape[0], -1) @@ -111,7 +141,12 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: return DTWResult(distance=float(distance), path=[tuple(p) for p in path]) -def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]: +def dtw_per_joint( + a: np.ndarray, + b: np.ndarray, + *, + align: AlignMode = "none", +) -> list[DTWResult]: """DTW on each joint independently. Performs one DTW computation per joint, yielding a list of @@ -124,7 +159,11 @@ def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]: a, b Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two sequences do not need to have the same number of frames but - must have the same number of joints. + must have the same number of joints. When ``align`` is not + ``"none"``, they must additionally share a frame count. + align + Procrustes alignment mode applied before DTW. See + :data:`AlignMode`. Returns ------- @@ -134,9 +173,11 @@ def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]: Raises ------ ValueError - If ``a`` and ``b`` do not have the same joint count. + If ``a`` and ``b`` do not have the same joint count, or if + ``align`` requires a matching frame count that is not present. """ _validate_same_joint_count(a, b) + a, b = _maybe_align(a, b, align=align) fastdtw, euclidean = _require_fastdtw() results: list[DTWResult] = [] for joint_idx in range(a.shape[1]): @@ -152,6 +193,8 @@ def dtw_relation( b: np.ndarray, joint_i: int, joint_j: int, + *, + align: AlignMode = "none", ) -> DTWResult: """DTW on the displacement vector between two specific joints. @@ -170,6 +213,12 @@ def dtw_relation( Indices of the two joints whose relative position should be compared. Must be valid indices into ``a`` and ``b``'s joint axis. + align + Procrustes alignment mode applied to the full sequences + before the displacement vectors are extracted. See + :data:`AlignMode`. Note that displacement vectors are already + translation-invariant; alignment is still useful for cancelling + camera rotation between trials. Returns ------- @@ -179,8 +228,9 @@ def dtw_relation( Raises ------ ValueError - If the sequences have different joint counts or if either joint - index is out of range. + If the sequences have different joint counts, either joint + index is out of range, or ``align`` requires a matching frame + count that is not present. """ _validate_same_joint_count(a, b) num_joints = a.shape[1] @@ -188,6 +238,7 @@ def dtw_relation( raise ValueError( f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}" ) + a, b = _maybe_align(a, b, align=align) fastdtw, euclidean = _require_fastdtw() disp_a = a[:, joint_j, :] - a[:, joint_i, :] disp_b = b[:, joint_j, :] - b[:, joint_i, :] @@ -206,3 +257,29 @@ def _validate_same_joint_count(a: np.ndarray, b: np.ndarray) -> None: f"input arrays disagree on joint count: " f"a has {a.shape[1]} joints, b has {b.shape[1]} joints" ) + + +def _maybe_align( + a: np.ndarray, + b: np.ndarray, + *, + align: AlignMode, +) -> tuple[np.ndarray, np.ndarray]: + """Apply Procrustes alignment if ``align`` requests it. + + Procrustes requires a frame-by-frame correspondence, so this + helper rejects calls where the two sequences disagree on frame + count and ``align`` is not ``"none"``. Pad upstream with + :func:`~neuropose.analyzer.features.pad_sequences` if the lengths + differ. + """ + if align == "none": + return a, b + if a.shape[0] != b.shape[0]: + raise ValueError( + f"align={align!r} requires matching frame counts; " + f"got a with {a.shape[0]} frames and b with {b.shape[0]} frames" + ) + mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence" + aligned_a, _target, _diag = procrustes_align(a, b, mode=mode) + return aligned_a, b diff --git a/src/neuropose/analyzer/features.py b/src/neuropose/analyzer/features.py index 8998896..18853aa 100644 --- a/src/neuropose/analyzer/features.py +++ b/src/neuropose/analyzer/features.py @@ -14,6 +14,8 @@ The following helpers are provided: fit in the unit cube (either per-axis or uniform). - :func:`pad_sequences` — edge-pad a batch of sequences to a common length, suitable for downstream tensor-based analysis. +- :func:`procrustes_align` — rigid-align one pose sequence to another + via the Kabsch algorithm, with optional uniform scaling. - :func:`extract_joint_angles` — compute joint angles at specified triplet positions across a pose sequence. - :func:`extract_feature_statistics` — summary statistics @@ -26,12 +28,19 @@ from __future__ import annotations from collections.abc import Sequence from dataclasses import dataclass -from typing import Any +from typing import Any, Literal import numpy as np from neuropose.io import VideoPredictions +ProcrustesMode = Literal["per_frame", "per_sequence"] +"""Mode selector for :func:`procrustes_align`. + +``per_sequence`` computes a single rigid transform over the whole +sequence; ``per_frame`` aligns every frame independently. +""" + # --------------------------------------------------------------------------- # VideoPredictions → numpy # --------------------------------------------------------------------------- @@ -211,6 +220,247 @@ def pad_sequences( return padded +# --------------------------------------------------------------------------- +# Procrustes alignment (Kabsch) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class AlignmentDiagnostics: + """Summary of the rigid transform fitted by :func:`procrustes_align`. + + Attributes + ---------- + mode + Which alignment mode produced this result; mirrors the ``mode`` + argument passed to :func:`procrustes_align`. + rotation_deg + Magnitude of the fitted rotation, in degrees, computed as + ``arccos((trace(R) - 1) / 2)``. For ``per_frame`` mode this is + the mean magnitude across frames. + rotation_deg_max + Worst-case (maximum) rotation magnitude across frames. + Equal to :attr:`rotation_deg` in ``per_sequence`` mode. + translation + Magnitude of the fitted translation vector, in the same units + as the input (millimetres for MeTRAbs output). For ``per_frame`` + mode this is the mean magnitude across frames. + translation_max + Worst-case (maximum) translation magnitude across frames. + Equal to :attr:`translation` in ``per_sequence`` mode. + scale + Applied uniform scale factor. Always ``1.0`` when + ``procrustes_align`` was called with ``scale=False``. In + ``per_frame`` mode this is the mean scale across frames. + """ + + mode: ProcrustesMode + rotation_deg: float + rotation_deg_max: float + translation: float + translation_max: float + scale: float + + +def _kabsch_single( + source: np.ndarray, + target: np.ndarray, + *, + scale: bool, +) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]: + """Fit the optimal rigid (+ optional uniform scale) transform. + + Aligns ``source`` to ``target`` via the closed-form Kabsch + algorithm and returns ``(aligned_source, R, s, t)`` where + ``aligned_source = s * (source - centroid_source) @ R.T + centroid_target + t_fine`` + (with ``t_fine`` absorbed for convenience — aligned points match + the target's centroid to within floating-point error). + + Parameters + ---------- + source + ``(N, 3)`` point set to align. + target + ``(N, 3)`` reference point set. Must have the same shape as + ``source``. + scale + If ``True``, fit a uniform scale factor; otherwise lock to + ``1.0``. + + Returns + ------- + aligned_source + ``(N, 3)`` aligned copy of ``source``. + R + ``(3, 3)`` rotation matrix. + s + Scalar scale factor (``1.0`` when ``scale=False``). + t + ``(3,)`` translation vector in world coordinates such that + ``aligned_source[i] = s * R @ source[i] + t``. + """ + centroid_source = source.mean(axis=0) + centroid_target = target.mean(axis=0) + source_centered = source - centroid_source + target_centered = target - centroid_target + + covariance = source_centered.T @ target_centered + u_mat, sigma, vt_mat = np.linalg.svd(covariance) + reflection_sign = float(np.sign(np.linalg.det(vt_mat.T @ u_mat.T))) + # Guard against the degenerate det == 0 case (coplanar points). + if reflection_sign == 0.0: + reflection_sign = 1.0 + diag = np.diag([1.0, 1.0, reflection_sign]) + rotation = vt_mat.T @ diag @ u_mat.T + + if scale: + source_var = float((source_centered**2).sum()) + if source_var <= 0.0: + scale_factor = 1.0 + else: + scale_factor = float((sigma * np.array([1.0, 1.0, reflection_sign])).sum() / source_var) + else: + scale_factor = 1.0 + + translation = centroid_target - scale_factor * rotation @ centroid_source + aligned = scale_factor * source @ rotation.T + translation + return aligned, rotation, scale_factor, translation + + +def _rotation_magnitude_deg(rotation: np.ndarray) -> float: + """Return the rotation angle (degrees) represented by ``rotation``. + + Uses the axis-angle relation ``cos(theta) = (trace(R) - 1) / 2``. + """ + cos_theta = (float(np.trace(rotation)) - 1.0) / 2.0 + cos_theta = max(-1.0, min(1.0, cos_theta)) + return float(np.degrees(np.arccos(cos_theta))) + + +def procrustes_align( + source: np.ndarray, + target: np.ndarray, + *, + mode: ProcrustesMode = "per_sequence", + scale: bool = False, +) -> tuple[np.ndarray, np.ndarray, AlignmentDiagnostics]: + """Rigid-align ``source`` to ``target`` via the Kabsch algorithm. + + Fits the optimal rigid transform (optionally including uniform + scaling) that minimizes the sum of squared distances between + corresponding joints. The transform is always applied to + ``source``; ``target`` is returned unchanged alongside it for + symmetry with downstream DTW callers, which typically consume both + aligned arrays as a pair. + + Parameters + ---------- + source + Pose sequence to align, shape ``(frames, joints, 3)``. + target + Reference pose sequence, shape ``(frames, joints, 3)``. For + ``per_frame`` mode the frame counts must match; for + ``per_sequence`` mode they must also match (the correspondence + runs frame-by-frame and joint-by-joint). Use + :func:`pad_sequences` first if your sequences have different + lengths. + mode + ``"per_sequence"`` (default) fits a single rigid transform over + the whole sequence — good when the recording geometry is + stable across frames. ``"per_frame"`` fits an independent + transform per frame — good for matching pose shape while + discarding global trajectory. + scale + If ``True``, also fit a uniform scale factor. Useful for + cross-subject comparisons where the reference skeleton has a + different overall size. + + Returns + ------- + aligned_source + ``source`` transformed to align with ``target``, same shape as + the input. + target + The ``target`` array, unchanged. + diagnostics + :class:`AlignmentDiagnostics` summarising the fitted transform. + + Raises + ------ + ValueError + If ``source`` and ``target`` have different shapes or the + trailing axis is not of size 3. + + Notes + ----- + The Kabsch algorithm (Kabsch 1976, "A solution for the best + rotation to relate two sets of vectors") is a closed-form SVD + solution and does not iterate. Reflection is explicitly prevented + via a sign correction on the smallest singular value; the fitted + matrix is always a proper rotation (det = +1). + + In ``per_frame`` mode, rotation, translation, and scale + diagnostics are reported as means across frames, with + :attr:`AlignmentDiagnostics.rotation_deg_max` and + :attr:`AlignmentDiagnostics.translation_max` exposing the worst + frame for anomaly detection. + """ + if source.ndim != 3 or source.shape[-1] != 3: + raise ValueError(f"expected (frames, joints, 3); got source shape {source.shape}") + if source.shape != target.shape: + raise ValueError( + f"source and target must have the same shape; got {source.shape} and {target.shape}" + ) + + source = source.astype(float, copy=False) + target = target.astype(float, copy=False) + num_frames = source.shape[0] + + if mode == "per_sequence": + flat_source = source.reshape(-1, 3) + flat_target = target.reshape(-1, 3) + aligned_flat, rotation, scale_factor, translation = _kabsch_single( + flat_source, flat_target, scale=scale + ) + aligned = aligned_flat.reshape(source.shape) + rotation_deg = _rotation_magnitude_deg(rotation) + translation_mag = float(np.linalg.norm(translation)) + diagnostics = AlignmentDiagnostics( + mode="per_sequence", + rotation_deg=rotation_deg, + rotation_deg_max=rotation_deg, + translation=translation_mag, + translation_max=translation_mag, + scale=scale_factor, + ) + return aligned, target, diagnostics + + if mode == "per_frame": + aligned = np.empty_like(source) + rotation_degs = np.empty(num_frames, dtype=float) + translations = np.empty(num_frames, dtype=float) + scales = np.empty(num_frames, dtype=float) + for frame_idx in range(num_frames): + aligned_frame, rotation, scale_factor, translation = _kabsch_single( + source[frame_idx], target[frame_idx], scale=scale + ) + aligned[frame_idx] = aligned_frame + rotation_degs[frame_idx] = _rotation_magnitude_deg(rotation) + translations[frame_idx] = float(np.linalg.norm(translation)) + scales[frame_idx] = scale_factor + diagnostics = AlignmentDiagnostics( + mode="per_frame", + rotation_deg=float(rotation_degs.mean()) if num_frames else 0.0, + rotation_deg_max=float(rotation_degs.max()) if num_frames else 0.0, + translation=float(translations.mean()) if num_frames else 0.0, + translation_max=float(translations.max()) if num_frames else 0.0, + scale=float(scales.mean()) if num_frames else 1.0, + ) + return aligned, target, diagnostics + + raise ValueError(f"unknown mode {mode!r}; expected 'per_frame' or 'per_sequence'") + + # --------------------------------------------------------------------------- # Joint angles # --------------------------------------------------------------------------- diff --git a/tests/unit/test_analyzer_features.py b/tests/unit/test_analyzer_features.py index 8a7e6d6..828cc29 100644 --- a/tests/unit/test_analyzer_features.py +++ b/tests/unit/test_analyzer_features.py @@ -8,6 +8,7 @@ import numpy as np import pytest from neuropose.analyzer.features import ( + AlignmentDiagnostics, FeatureStatistics, extract_feature_statistics, extract_joint_angles, @@ -15,6 +16,7 @@ from neuropose.analyzer.features import ( normalize_pose_sequence, pad_sequences, predictions_to_numpy, + procrustes_align, ) from neuropose.io import VideoPredictions @@ -297,3 +299,177 @@ class TestFindPeaks: def test_rejects_2d_input(self) -> None: with pytest.raises(ValueError, match="1D"): find_peaks(np.zeros((5, 5))) + + +# --------------------------------------------------------------------------- +# procrustes_align +# --------------------------------------------------------------------------- + + +def _rotation_matrix_z(angle_rad: float) -> np.ndarray: + """Rotation matrix about the Z axis.""" + c, s = np.cos(angle_rad), np.sin(angle_rad) + return np.array( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + +def _skeleton(num_joints: int = 8, seed: int = 0) -> np.ndarray: + """A deterministic, non-degenerate single-frame skeleton.""" + rng = np.random.default_rng(seed) + return rng.standard_normal((num_joints, 3)) + + +class TestProcrustesAlignPerSequence: + def test_identical_sequences_yield_identity_transform(self) -> None: + sequence = _skeleton()[np.newaxis, :, :].repeat(3, axis=0) # (3, 8, 3) + aligned, target, diag = procrustes_align(sequence, sequence, mode="per_sequence") + np.testing.assert_allclose(aligned, sequence, atol=1e-10) + np.testing.assert_array_equal(target, sequence) + assert diag.mode == "per_sequence" + assert diag.rotation_deg == pytest.approx(0.0, abs=1e-6) + assert diag.translation == pytest.approx(0.0, abs=1e-9) + assert diag.scale == pytest.approx(1.0) + + def test_recovers_known_rotation(self) -> None: + # Build a reference sequence; construct the source by rotating it + # about Z, then verify alignment returns the reference up to + # floating-point error. + rotation = _rotation_matrix_z(np.deg2rad(37.0)) + reference = _skeleton(num_joints=10)[np.newaxis, :, :].repeat(4, axis=0) + source = reference @ rotation.T + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence") + np.testing.assert_allclose(aligned, reference, atol=1e-8) + # The recovered rotation's magnitude should be the original 37°. + assert diag.rotation_deg == pytest.approx(37.0, abs=1e-4) + + def test_recovers_known_translation(self) -> None: + reference = _skeleton()[np.newaxis, :, :].repeat(5, axis=0) + translation = np.array([10.0, -4.5, 2.25]) + source = reference + translation + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence") + np.testing.assert_allclose(aligned, reference, atol=1e-9) + # rotation_deg may be numerically tiny but not exactly 0. + assert diag.rotation_deg == pytest.approx(0.0, abs=1e-4) + assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-6) + + def test_recovers_combined_rotation_and_translation(self) -> None: + rotation = _rotation_matrix_z(np.deg2rad(-12.0)) + translation = np.array([1.0, 2.0, 3.0]) + reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(3, axis=0) + source = reference @ rotation.T + translation + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence") + np.testing.assert_allclose(aligned, reference, atol=1e-8) + assert diag.rotation_deg == pytest.approx(12.0, abs=1e-4) + assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-4) + + def test_scale_flag_recovers_known_scale(self) -> None: + reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0) + source = reference * 0.5 + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=True) + np.testing.assert_allclose(aligned, reference, atol=1e-8) + assert diag.scale == pytest.approx(2.0, rel=1e-6) + + def test_scale_flag_off_leaves_scale_at_one(self) -> None: + reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0) + source = reference * 0.5 + _, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=False) + assert diag.scale == pytest.approx(1.0) + + def test_rejects_mismatched_shapes(self) -> None: + a = np.zeros((4, 8, 3)) + b = np.zeros((4, 7, 3)) + with pytest.raises(ValueError, match="same shape"): + procrustes_align(a, b) + + def test_rejects_wrong_trailing_axis(self) -> None: + a = np.zeros((4, 8, 2)) + b = np.zeros((4, 8, 2)) + with pytest.raises(ValueError, match="joints, 3"): + procrustes_align(a, b) + + def test_rejects_unknown_mode(self) -> None: + a = np.zeros((2, 4, 3)) + with pytest.raises(ValueError, match="unknown mode"): + procrustes_align(a, a, mode="nope") # type: ignore[arg-type] + + def test_does_not_mutate_inputs(self) -> None: + source = _skeleton()[np.newaxis, :, :].repeat(3, axis=0).copy() + target = (source @ _rotation_matrix_z(np.deg2rad(10.0)).T).copy() + source_before = source.copy() + target_before = target.copy() + procrustes_align(source, target, mode="per_sequence") + np.testing.assert_array_equal(source, source_before) + np.testing.assert_array_equal(target, target_before) + + def test_returns_alignment_diagnostics_dataclass(self) -> None: + a = _skeleton()[np.newaxis, :, :].repeat(2, axis=0) + _, _, diag = procrustes_align(a, a) + assert isinstance(diag, AlignmentDiagnostics) + + +class TestProcrustesAlignPerFrame: + def test_per_frame_recovers_varying_rotations(self) -> None: + # Each frame is rotated by a different angle; per_frame alignment + # should recover each frame independently. + num_frames = 4 + reference_frame = _skeleton(num_joints=6) + angles = np.deg2rad([5.0, -10.0, 20.0, 45.0]) + reference = np.stack([reference_frame for _ in range(num_frames)], axis=0) + source = np.stack([reference_frame @ _rotation_matrix_z(a).T for a in angles], axis=0) + aligned, _, diag = procrustes_align(source, reference, mode="per_frame") + np.testing.assert_allclose(aligned, reference, atol=1e-8) + assert diag.mode == "per_frame" + # The max rotation across frames should be 45°. + assert diag.rotation_deg_max == pytest.approx(45.0, abs=1e-4) + # The mean rotation across frames should be 20°. + assert diag.rotation_deg == pytest.approx(20.0, abs=1e-4) + + def test_per_frame_with_identical_sequences_yields_zero(self) -> None: + sequence = _skeleton(num_joints=5)[np.newaxis, :, :].repeat(3, axis=0) + aligned, _, diag = procrustes_align(sequence, sequence, mode="per_frame") + np.testing.assert_allclose(aligned, sequence, atol=1e-10) + # Per-frame SVD on a symmetric covariance is numerically ambiguous + # in axis selection, so the fitted rotation can be a few micro- + # degrees off zero; the residual positions are still exact. + assert diag.rotation_deg == pytest.approx(0.0, abs=1e-3) + assert diag.rotation_deg_max == pytest.approx(0.0, abs=1e-3) + assert diag.translation == pytest.approx(0.0, abs=1e-9) + + +# --------------------------------------------------------------------------- +# DTW with align= (integration) +# --------------------------------------------------------------------------- + + +class TestDtwAlignIntegration: + """Smoke tests: align= routes through procrustes_align correctly. + + Depth tests of the DTW path itself live in test_analyzer_dtw. + """ + + def test_dtw_all_with_alignment_cancels_rigid_offset(self) -> None: + pytest.importorskip("fastdtw") + from neuropose.analyzer.dtw import dtw_all + + rotation = _rotation_matrix_z(np.deg2rad(30.0)) + translation = np.array([5.0, -2.0, 1.0]) + reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(4, axis=0) + source = reference @ rotation.T + translation + baseline = dtw_all(source, reference, align="none") + aligned_result = dtw_all(source, reference, align="procrustes_per_sequence") + assert baseline.distance > 0.0 + assert aligned_result.distance == pytest.approx(0.0, abs=1e-6) + + def test_dtw_align_rejects_mismatched_frame_counts(self) -> None: + pytest.importorskip("fastdtw") + from neuropose.analyzer.dtw import dtw_all + + a = np.zeros((5, 3, 3)) + b = np.zeros((6, 3, 3)) + with pytest.raises(ValueError, match="matching frame counts"): + dtw_all(a, b, align="procrustes_per_sequence")