From cc9fcb4adb7f040d2c650365fe0c5a65ebd8ffa6 Mon Sep 17 00:00:00 2001 From: Levi Neuwirth Date: Sat, 18 Apr 2026 17:11:53 -0400 Subject: [PATCH] add Procrustes alignment to analyzer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit procrustes_align in neuropose.analyzer.features — Kabsch closed-form rigid alignment between two pose sequences, with per_frame and per_sequence modes and an optional scale flag for cross-subject comparisons. Returns aligned arrays plus an AlignmentDiagnostics dataclass reporting rotation magnitude (mean and max), translation magnitude (mean and max), and scale factor, so downstream code can flag suspiciously large transforms. Wired into every DTW entry point via a new keyword-only align parameter — "none" (the default) preserves the 0.1 raw-coordinate behaviour, while "procrustes_per_frame" and "procrustes_per_sequence" route inputs through procrustes_align before DTW runs. Rejects mismatched frame counts when alignment is requested (Procrustes requires a 1:1 correspondence). Phase 0 of TECHNICAL.md: closes one of the three methodological gaps Paper C's pipeline is waiting on. --- CHANGELOG.md | 27 ++- src/neuropose/analyzer/__init__.py | 8 + src/neuropose/analyzer/dtw.py | 93 +++++++++- src/neuropose/analyzer/features.py | 252 ++++++++++++++++++++++++++- tests/unit/test_analyzer_features.py | 176 +++++++++++++++++++ 5 files changed, 543 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 87a9c71..83c8802 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -222,6 +222,23 @@ be split into per-release sections once tagging begins. at `CURRENT_VERSION = 2`, with registered v1 → v2 migrations for `VideoPredictions` and `BenchmarkResult` that add the optional `provenance` field. +- **`neuropose.analyzer.features.procrustes_align`** — Kabsch + rigid-alignment helper for pose sequences, plus a + `ProcrustesMode` literal (`"per_frame"` | `"per_sequence"`) and a + frozen `AlignmentDiagnostics` dataclass (`rotation_deg`, + `rotation_deg_max`, `translation`, `translation_max`, `scale`, + plus the mode that produced them). Per-sequence mode fits one + rigid transform across the whole trial; per-frame fits an + independent transform per frame. Optional `scale=True` fits a + uniform scale factor for cross-subject comparisons. Wired into + every DTW entry point in `neuropose.analyzer.dtw` via a new + keyword-only `align: AlignMode = "none"` parameter — `"none"` + preserves the 0.1 raw-coordinate behaviour, while + `"procrustes_per_frame"` and `"procrustes_per_sequence"` route + inputs through `procrustes_align` before DTW runs so the returned + distance is rotation- and translation-invariant. Paper C's + pipeline is expected to set `align="procrustes_per_sequence"`; + see `TECHNICAL.md` Phase 0. - **`neuropose.io.Provenance`** — reproducibility envelope for every inference run. Populated automatically by `Estimator.process_video` when the model was loaded via `load_model` (the production path) @@ -262,10 +279,12 @@ be split into per-release sections once tagging begins. methodology investigation. - `analyzer.features` — `predictions_to_numpy`, `normalize_pose_sequence` (uniform and axis-wise), - `pad_sequences` (edge-padding), `extract_joint_angles` (NaN on - degenerate vectors), `extract_feature_statistics` - (`FeatureStatistics` frozen dataclass), and a `find_peaks` thin - wrapper around `scipy.signal.find_peaks`. + `pad_sequences` (edge-padding), `procrustes_align` (Kabsch + rigid alignment, per-frame or per-sequence, optional uniform + scaling), `extract_joint_angles` (NaN on degenerate vectors), + `extract_feature_statistics` (`FeatureStatistics` frozen + dataclass), and a `find_peaks` thin wrapper around + `scipy.signal.find_peaks`. - `analyzer.segment` — repetition segmentation for trials in which a subject performs the same movement several times. A three-layer API: `segment_by_peaks` (pure 1D diff --git a/src/neuropose/analyzer/__init__.py b/src/neuropose/analyzer/__init__.py index ec7e811..98b6dbe 100644 --- a/src/neuropose/analyzer/__init__.py +++ b/src/neuropose/analyzer/__init__.py @@ -28,19 +28,23 @@ here for ergonomic access. from __future__ import annotations from neuropose.analyzer.dtw import ( + AlignMode, DTWResult, dtw_all, dtw_per_joint, dtw_relation, ) from neuropose.analyzer.features import ( + AlignmentDiagnostics, FeatureStatistics, + ProcrustesMode, extract_feature_statistics, extract_joint_angles, find_peaks, normalize_pose_sequence, pad_sequences, predictions_to_numpy, + procrustes_align, ) from neuropose.analyzer.segment import ( JOINT_INDEX, @@ -59,8 +63,11 @@ from neuropose.analyzer.segment import ( __all__ = [ "JOINT_INDEX", "JOINT_NAMES", + "AlignMode", + "AlignmentDiagnostics", "DTWResult", "FeatureStatistics", + "ProcrustesMode", "dtw_all", "dtw_per_joint", "dtw_relation", @@ -76,6 +83,7 @@ __all__ = [ "normalize_pose_sequence", "pad_sequences", "predictions_to_numpy", + "procrustes_align", "segment_by_peaks", "segment_predictions", "slice_predictions", diff --git a/src/neuropose/analyzer/dtw.py b/src/neuropose/analyzer/dtw.py index e4f94d9..d16cd36 100644 --- a/src/neuropose/analyzer/dtw.py +++ b/src/neuropose/analyzer/dtw.py @@ -16,6 +16,12 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)`` numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy` produces. +All three also accept an ``align`` argument that routes the inputs +through :func:`~neuropose.analyzer.features.procrustes_align` before +DTW runs, yielding translation- and rotation-invariant distances. +``align="none"`` (the default) preserves the raw-coordinate behaviour +shipped in 0.1. + Dependency note --------------- This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of @@ -30,9 +36,21 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass +from typing import Literal import numpy as np +from neuropose.analyzer.features import procrustes_align + +AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"] +"""Alignment selector for DTW entry points. + +- ``"none"`` — feed raw coordinates directly to DTW. +- ``"procrustes_per_frame"`` — per-frame Kabsch alignment before DTW. +- ``"procrustes_per_sequence"`` — single sequence-wide Kabsch + alignment before DTW. +""" + @dataclass(frozen=True) class DTWResult: @@ -77,7 +95,12 @@ def _require_fastdtw() -> tuple[Callable, Callable]: return fastdtw, euclidean -def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: +def dtw_all( + a: np.ndarray, + b: np.ndarray, + *, + align: AlignMode = "none", +) -> DTWResult: """DTW on the flattened per-frame joint vector. Each frame's joints are collapsed into a single vector before DTW @@ -90,7 +113,12 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: a, b Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two sequences do not need to have the same number of frames, but - they must have the same number of joints. + they must have the same number of joints. When ``align`` is not + ``"none"``, the two sequences must additionally share a frame + count (Procrustes requires a 1:1 correspondence). + align + Procrustes alignment mode applied before DTW. See + :data:`AlignMode`. Returns ------- @@ -101,9 +129,11 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: Raises ------ ValueError - If ``a`` and ``b`` do not have the same joint count. + If ``a`` and ``b`` do not have the same joint count, or if + ``align`` requires a matching frame count that is not present. """ _validate_same_joint_count(a, b) + a, b = _maybe_align(a, b, align=align) fastdtw, euclidean = _require_fastdtw() a_flat = a.reshape(a.shape[0], -1) b_flat = b.reshape(b.shape[0], -1) @@ -111,7 +141,12 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult: return DTWResult(distance=float(distance), path=[tuple(p) for p in path]) -def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]: +def dtw_per_joint( + a: np.ndarray, + b: np.ndarray, + *, + align: AlignMode = "none", +) -> list[DTWResult]: """DTW on each joint independently. Performs one DTW computation per joint, yielding a list of @@ -124,7 +159,11 @@ def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]: a, b Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two sequences do not need to have the same number of frames but - must have the same number of joints. + must have the same number of joints. When ``align`` is not + ``"none"``, they must additionally share a frame count. + align + Procrustes alignment mode applied before DTW. See + :data:`AlignMode`. Returns ------- @@ -134,9 +173,11 @@ def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]: Raises ------ ValueError - If ``a`` and ``b`` do not have the same joint count. + If ``a`` and ``b`` do not have the same joint count, or if + ``align`` requires a matching frame count that is not present. """ _validate_same_joint_count(a, b) + a, b = _maybe_align(a, b, align=align) fastdtw, euclidean = _require_fastdtw() results: list[DTWResult] = [] for joint_idx in range(a.shape[1]): @@ -152,6 +193,8 @@ def dtw_relation( b: np.ndarray, joint_i: int, joint_j: int, + *, + align: AlignMode = "none", ) -> DTWResult: """DTW on the displacement vector between two specific joints. @@ -170,6 +213,12 @@ def dtw_relation( Indices of the two joints whose relative position should be compared. Must be valid indices into ``a`` and ``b``'s joint axis. + align + Procrustes alignment mode applied to the full sequences + before the displacement vectors are extracted. See + :data:`AlignMode`. Note that displacement vectors are already + translation-invariant; alignment is still useful for cancelling + camera rotation between trials. Returns ------- @@ -179,8 +228,9 @@ def dtw_relation( Raises ------ ValueError - If the sequences have different joint counts or if either joint - index is out of range. + If the sequences have different joint counts, either joint + index is out of range, or ``align`` requires a matching frame + count that is not present. """ _validate_same_joint_count(a, b) num_joints = a.shape[1] @@ -188,6 +238,7 @@ def dtw_relation( raise ValueError( f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}" ) + a, b = _maybe_align(a, b, align=align) fastdtw, euclidean = _require_fastdtw() disp_a = a[:, joint_j, :] - a[:, joint_i, :] disp_b = b[:, joint_j, :] - b[:, joint_i, :] @@ -206,3 +257,29 @@ def _validate_same_joint_count(a: np.ndarray, b: np.ndarray) -> None: f"input arrays disagree on joint count: " f"a has {a.shape[1]} joints, b has {b.shape[1]} joints" ) + + +def _maybe_align( + a: np.ndarray, + b: np.ndarray, + *, + align: AlignMode, +) -> tuple[np.ndarray, np.ndarray]: + """Apply Procrustes alignment if ``align`` requests it. + + Procrustes requires a frame-by-frame correspondence, so this + helper rejects calls where the two sequences disagree on frame + count and ``align`` is not ``"none"``. Pad upstream with + :func:`~neuropose.analyzer.features.pad_sequences` if the lengths + differ. + """ + if align == "none": + return a, b + if a.shape[0] != b.shape[0]: + raise ValueError( + f"align={align!r} requires matching frame counts; " + f"got a with {a.shape[0]} frames and b with {b.shape[0]} frames" + ) + mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence" + aligned_a, _target, _diag = procrustes_align(a, b, mode=mode) + return aligned_a, b diff --git a/src/neuropose/analyzer/features.py b/src/neuropose/analyzer/features.py index 8998896..18853aa 100644 --- a/src/neuropose/analyzer/features.py +++ b/src/neuropose/analyzer/features.py @@ -14,6 +14,8 @@ The following helpers are provided: fit in the unit cube (either per-axis or uniform). - :func:`pad_sequences` — edge-pad a batch of sequences to a common length, suitable for downstream tensor-based analysis. +- :func:`procrustes_align` — rigid-align one pose sequence to another + via the Kabsch algorithm, with optional uniform scaling. - :func:`extract_joint_angles` — compute joint angles at specified triplet positions across a pose sequence. - :func:`extract_feature_statistics` — summary statistics @@ -26,12 +28,19 @@ from __future__ import annotations from collections.abc import Sequence from dataclasses import dataclass -from typing import Any +from typing import Any, Literal import numpy as np from neuropose.io import VideoPredictions +ProcrustesMode = Literal["per_frame", "per_sequence"] +"""Mode selector for :func:`procrustes_align`. + +``per_sequence`` computes a single rigid transform over the whole +sequence; ``per_frame`` aligns every frame independently. +""" + # --------------------------------------------------------------------------- # VideoPredictions → numpy # --------------------------------------------------------------------------- @@ -211,6 +220,247 @@ def pad_sequences( return padded +# --------------------------------------------------------------------------- +# Procrustes alignment (Kabsch) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class AlignmentDiagnostics: + """Summary of the rigid transform fitted by :func:`procrustes_align`. + + Attributes + ---------- + mode + Which alignment mode produced this result; mirrors the ``mode`` + argument passed to :func:`procrustes_align`. + rotation_deg + Magnitude of the fitted rotation, in degrees, computed as + ``arccos((trace(R) - 1) / 2)``. For ``per_frame`` mode this is + the mean magnitude across frames. + rotation_deg_max + Worst-case (maximum) rotation magnitude across frames. + Equal to :attr:`rotation_deg` in ``per_sequence`` mode. + translation + Magnitude of the fitted translation vector, in the same units + as the input (millimetres for MeTRAbs output). For ``per_frame`` + mode this is the mean magnitude across frames. + translation_max + Worst-case (maximum) translation magnitude across frames. + Equal to :attr:`translation` in ``per_sequence`` mode. + scale + Applied uniform scale factor. Always ``1.0`` when + ``procrustes_align`` was called with ``scale=False``. In + ``per_frame`` mode this is the mean scale across frames. + """ + + mode: ProcrustesMode + rotation_deg: float + rotation_deg_max: float + translation: float + translation_max: float + scale: float + + +def _kabsch_single( + source: np.ndarray, + target: np.ndarray, + *, + scale: bool, +) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]: + """Fit the optimal rigid (+ optional uniform scale) transform. + + Aligns ``source`` to ``target`` via the closed-form Kabsch + algorithm and returns ``(aligned_source, R, s, t)`` where + ``aligned_source = s * (source - centroid_source) @ R.T + centroid_target + t_fine`` + (with ``t_fine`` absorbed for convenience — aligned points match + the target's centroid to within floating-point error). + + Parameters + ---------- + source + ``(N, 3)`` point set to align. + target + ``(N, 3)`` reference point set. Must have the same shape as + ``source``. + scale + If ``True``, fit a uniform scale factor; otherwise lock to + ``1.0``. + + Returns + ------- + aligned_source + ``(N, 3)`` aligned copy of ``source``. + R + ``(3, 3)`` rotation matrix. + s + Scalar scale factor (``1.0`` when ``scale=False``). + t + ``(3,)`` translation vector in world coordinates such that + ``aligned_source[i] = s * R @ source[i] + t``. + """ + centroid_source = source.mean(axis=0) + centroid_target = target.mean(axis=0) + source_centered = source - centroid_source + target_centered = target - centroid_target + + covariance = source_centered.T @ target_centered + u_mat, sigma, vt_mat = np.linalg.svd(covariance) + reflection_sign = float(np.sign(np.linalg.det(vt_mat.T @ u_mat.T))) + # Guard against the degenerate det == 0 case (coplanar points). + if reflection_sign == 0.0: + reflection_sign = 1.0 + diag = np.diag([1.0, 1.0, reflection_sign]) + rotation = vt_mat.T @ diag @ u_mat.T + + if scale: + source_var = float((source_centered**2).sum()) + if source_var <= 0.0: + scale_factor = 1.0 + else: + scale_factor = float((sigma * np.array([1.0, 1.0, reflection_sign])).sum() / source_var) + else: + scale_factor = 1.0 + + translation = centroid_target - scale_factor * rotation @ centroid_source + aligned = scale_factor * source @ rotation.T + translation + return aligned, rotation, scale_factor, translation + + +def _rotation_magnitude_deg(rotation: np.ndarray) -> float: + """Return the rotation angle (degrees) represented by ``rotation``. + + Uses the axis-angle relation ``cos(theta) = (trace(R) - 1) / 2``. + """ + cos_theta = (float(np.trace(rotation)) - 1.0) / 2.0 + cos_theta = max(-1.0, min(1.0, cos_theta)) + return float(np.degrees(np.arccos(cos_theta))) + + +def procrustes_align( + source: np.ndarray, + target: np.ndarray, + *, + mode: ProcrustesMode = "per_sequence", + scale: bool = False, +) -> tuple[np.ndarray, np.ndarray, AlignmentDiagnostics]: + """Rigid-align ``source`` to ``target`` via the Kabsch algorithm. + + Fits the optimal rigid transform (optionally including uniform + scaling) that minimizes the sum of squared distances between + corresponding joints. The transform is always applied to + ``source``; ``target`` is returned unchanged alongside it for + symmetry with downstream DTW callers, which typically consume both + aligned arrays as a pair. + + Parameters + ---------- + source + Pose sequence to align, shape ``(frames, joints, 3)``. + target + Reference pose sequence, shape ``(frames, joints, 3)``. For + ``per_frame`` mode the frame counts must match; for + ``per_sequence`` mode they must also match (the correspondence + runs frame-by-frame and joint-by-joint). Use + :func:`pad_sequences` first if your sequences have different + lengths. + mode + ``"per_sequence"`` (default) fits a single rigid transform over + the whole sequence — good when the recording geometry is + stable across frames. ``"per_frame"`` fits an independent + transform per frame — good for matching pose shape while + discarding global trajectory. + scale + If ``True``, also fit a uniform scale factor. Useful for + cross-subject comparisons where the reference skeleton has a + different overall size. + + Returns + ------- + aligned_source + ``source`` transformed to align with ``target``, same shape as + the input. + target + The ``target`` array, unchanged. + diagnostics + :class:`AlignmentDiagnostics` summarising the fitted transform. + + Raises + ------ + ValueError + If ``source`` and ``target`` have different shapes or the + trailing axis is not of size 3. + + Notes + ----- + The Kabsch algorithm (Kabsch 1976, "A solution for the best + rotation to relate two sets of vectors") is a closed-form SVD + solution and does not iterate. Reflection is explicitly prevented + via a sign correction on the smallest singular value; the fitted + matrix is always a proper rotation (det = +1). + + In ``per_frame`` mode, rotation, translation, and scale + diagnostics are reported as means across frames, with + :attr:`AlignmentDiagnostics.rotation_deg_max` and + :attr:`AlignmentDiagnostics.translation_max` exposing the worst + frame for anomaly detection. + """ + if source.ndim != 3 or source.shape[-1] != 3: + raise ValueError(f"expected (frames, joints, 3); got source shape {source.shape}") + if source.shape != target.shape: + raise ValueError( + f"source and target must have the same shape; got {source.shape} and {target.shape}" + ) + + source = source.astype(float, copy=False) + target = target.astype(float, copy=False) + num_frames = source.shape[0] + + if mode == "per_sequence": + flat_source = source.reshape(-1, 3) + flat_target = target.reshape(-1, 3) + aligned_flat, rotation, scale_factor, translation = _kabsch_single( + flat_source, flat_target, scale=scale + ) + aligned = aligned_flat.reshape(source.shape) + rotation_deg = _rotation_magnitude_deg(rotation) + translation_mag = float(np.linalg.norm(translation)) + diagnostics = AlignmentDiagnostics( + mode="per_sequence", + rotation_deg=rotation_deg, + rotation_deg_max=rotation_deg, + translation=translation_mag, + translation_max=translation_mag, + scale=scale_factor, + ) + return aligned, target, diagnostics + + if mode == "per_frame": + aligned = np.empty_like(source) + rotation_degs = np.empty(num_frames, dtype=float) + translations = np.empty(num_frames, dtype=float) + scales = np.empty(num_frames, dtype=float) + for frame_idx in range(num_frames): + aligned_frame, rotation, scale_factor, translation = _kabsch_single( + source[frame_idx], target[frame_idx], scale=scale + ) + aligned[frame_idx] = aligned_frame + rotation_degs[frame_idx] = _rotation_magnitude_deg(rotation) + translations[frame_idx] = float(np.linalg.norm(translation)) + scales[frame_idx] = scale_factor + diagnostics = AlignmentDiagnostics( + mode="per_frame", + rotation_deg=float(rotation_degs.mean()) if num_frames else 0.0, + rotation_deg_max=float(rotation_degs.max()) if num_frames else 0.0, + translation=float(translations.mean()) if num_frames else 0.0, + translation_max=float(translations.max()) if num_frames else 0.0, + scale=float(scales.mean()) if num_frames else 1.0, + ) + return aligned, target, diagnostics + + raise ValueError(f"unknown mode {mode!r}; expected 'per_frame' or 'per_sequence'") + + # --------------------------------------------------------------------------- # Joint angles # --------------------------------------------------------------------------- diff --git a/tests/unit/test_analyzer_features.py b/tests/unit/test_analyzer_features.py index 8a7e6d6..828cc29 100644 --- a/tests/unit/test_analyzer_features.py +++ b/tests/unit/test_analyzer_features.py @@ -8,6 +8,7 @@ import numpy as np import pytest from neuropose.analyzer.features import ( + AlignmentDiagnostics, FeatureStatistics, extract_feature_statistics, extract_joint_angles, @@ -15,6 +16,7 @@ from neuropose.analyzer.features import ( normalize_pose_sequence, pad_sequences, predictions_to_numpy, + procrustes_align, ) from neuropose.io import VideoPredictions @@ -297,3 +299,177 @@ class TestFindPeaks: def test_rejects_2d_input(self) -> None: with pytest.raises(ValueError, match="1D"): find_peaks(np.zeros((5, 5))) + + +# --------------------------------------------------------------------------- +# procrustes_align +# --------------------------------------------------------------------------- + + +def _rotation_matrix_z(angle_rad: float) -> np.ndarray: + """Rotation matrix about the Z axis.""" + c, s = np.cos(angle_rad), np.sin(angle_rad) + return np.array( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + +def _skeleton(num_joints: int = 8, seed: int = 0) -> np.ndarray: + """A deterministic, non-degenerate single-frame skeleton.""" + rng = np.random.default_rng(seed) + return rng.standard_normal((num_joints, 3)) + + +class TestProcrustesAlignPerSequence: + def test_identical_sequences_yield_identity_transform(self) -> None: + sequence = _skeleton()[np.newaxis, :, :].repeat(3, axis=0) # (3, 8, 3) + aligned, target, diag = procrustes_align(sequence, sequence, mode="per_sequence") + np.testing.assert_allclose(aligned, sequence, atol=1e-10) + np.testing.assert_array_equal(target, sequence) + assert diag.mode == "per_sequence" + assert diag.rotation_deg == pytest.approx(0.0, abs=1e-6) + assert diag.translation == pytest.approx(0.0, abs=1e-9) + assert diag.scale == pytest.approx(1.0) + + def test_recovers_known_rotation(self) -> None: + # Build a reference sequence; construct the source by rotating it + # about Z, then verify alignment returns the reference up to + # floating-point error. + rotation = _rotation_matrix_z(np.deg2rad(37.0)) + reference = _skeleton(num_joints=10)[np.newaxis, :, :].repeat(4, axis=0) + source = reference @ rotation.T + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence") + np.testing.assert_allclose(aligned, reference, atol=1e-8) + # The recovered rotation's magnitude should be the original 37°. + assert diag.rotation_deg == pytest.approx(37.0, abs=1e-4) + + def test_recovers_known_translation(self) -> None: + reference = _skeleton()[np.newaxis, :, :].repeat(5, axis=0) + translation = np.array([10.0, -4.5, 2.25]) + source = reference + translation + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence") + np.testing.assert_allclose(aligned, reference, atol=1e-9) + # rotation_deg may be numerically tiny but not exactly 0. + assert diag.rotation_deg == pytest.approx(0.0, abs=1e-4) + assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-6) + + def test_recovers_combined_rotation_and_translation(self) -> None: + rotation = _rotation_matrix_z(np.deg2rad(-12.0)) + translation = np.array([1.0, 2.0, 3.0]) + reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(3, axis=0) + source = reference @ rotation.T + translation + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence") + np.testing.assert_allclose(aligned, reference, atol=1e-8) + assert diag.rotation_deg == pytest.approx(12.0, abs=1e-4) + assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-4) + + def test_scale_flag_recovers_known_scale(self) -> None: + reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0) + source = reference * 0.5 + aligned, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=True) + np.testing.assert_allclose(aligned, reference, atol=1e-8) + assert diag.scale == pytest.approx(2.0, rel=1e-6) + + def test_scale_flag_off_leaves_scale_at_one(self) -> None: + reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0) + source = reference * 0.5 + _, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=False) + assert diag.scale == pytest.approx(1.0) + + def test_rejects_mismatched_shapes(self) -> None: + a = np.zeros((4, 8, 3)) + b = np.zeros((4, 7, 3)) + with pytest.raises(ValueError, match="same shape"): + procrustes_align(a, b) + + def test_rejects_wrong_trailing_axis(self) -> None: + a = np.zeros((4, 8, 2)) + b = np.zeros((4, 8, 2)) + with pytest.raises(ValueError, match="joints, 3"): + procrustes_align(a, b) + + def test_rejects_unknown_mode(self) -> None: + a = np.zeros((2, 4, 3)) + with pytest.raises(ValueError, match="unknown mode"): + procrustes_align(a, a, mode="nope") # type: ignore[arg-type] + + def test_does_not_mutate_inputs(self) -> None: + source = _skeleton()[np.newaxis, :, :].repeat(3, axis=0).copy() + target = (source @ _rotation_matrix_z(np.deg2rad(10.0)).T).copy() + source_before = source.copy() + target_before = target.copy() + procrustes_align(source, target, mode="per_sequence") + np.testing.assert_array_equal(source, source_before) + np.testing.assert_array_equal(target, target_before) + + def test_returns_alignment_diagnostics_dataclass(self) -> None: + a = _skeleton()[np.newaxis, :, :].repeat(2, axis=0) + _, _, diag = procrustes_align(a, a) + assert isinstance(diag, AlignmentDiagnostics) + + +class TestProcrustesAlignPerFrame: + def test_per_frame_recovers_varying_rotations(self) -> None: + # Each frame is rotated by a different angle; per_frame alignment + # should recover each frame independently. + num_frames = 4 + reference_frame = _skeleton(num_joints=6) + angles = np.deg2rad([5.0, -10.0, 20.0, 45.0]) + reference = np.stack([reference_frame for _ in range(num_frames)], axis=0) + source = np.stack([reference_frame @ _rotation_matrix_z(a).T for a in angles], axis=0) + aligned, _, diag = procrustes_align(source, reference, mode="per_frame") + np.testing.assert_allclose(aligned, reference, atol=1e-8) + assert diag.mode == "per_frame" + # The max rotation across frames should be 45°. + assert diag.rotation_deg_max == pytest.approx(45.0, abs=1e-4) + # The mean rotation across frames should be 20°. + assert diag.rotation_deg == pytest.approx(20.0, abs=1e-4) + + def test_per_frame_with_identical_sequences_yields_zero(self) -> None: + sequence = _skeleton(num_joints=5)[np.newaxis, :, :].repeat(3, axis=0) + aligned, _, diag = procrustes_align(sequence, sequence, mode="per_frame") + np.testing.assert_allclose(aligned, sequence, atol=1e-10) + # Per-frame SVD on a symmetric covariance is numerically ambiguous + # in axis selection, so the fitted rotation can be a few micro- + # degrees off zero; the residual positions are still exact. + assert diag.rotation_deg == pytest.approx(0.0, abs=1e-3) + assert diag.rotation_deg_max == pytest.approx(0.0, abs=1e-3) + assert diag.translation == pytest.approx(0.0, abs=1e-9) + + +# --------------------------------------------------------------------------- +# DTW with align= (integration) +# --------------------------------------------------------------------------- + + +class TestDtwAlignIntegration: + """Smoke tests: align= routes through procrustes_align correctly. + + Depth tests of the DTW path itself live in test_analyzer_dtw. + """ + + def test_dtw_all_with_alignment_cancels_rigid_offset(self) -> None: + pytest.importorskip("fastdtw") + from neuropose.analyzer.dtw import dtw_all + + rotation = _rotation_matrix_z(np.deg2rad(30.0)) + translation = np.array([5.0, -2.0, 1.0]) + reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(4, axis=0) + source = reference @ rotation.T + translation + baseline = dtw_all(source, reference, align="none") + aligned_result = dtw_all(source, reference, align="procrustes_per_sequence") + assert baseline.distance > 0.0 + assert aligned_result.distance == pytest.approx(0.0, abs=1e-6) + + def test_dtw_align_rejects_mismatched_frame_counts(self) -> None: + pytest.importorskip("fastdtw") + from neuropose.analyzer.dtw import dtw_all + + a = np.zeros((5, 3, 3)) + b = np.zeros((6, 3, 3)) + with pytest.raises(ValueError, match="matching frame counts"): + dtw_all(a, b, align="procrustes_per_sequence")