diff --git a/CHANGELOG.md b/CHANGELOG.md index a887ff7..adde901 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -239,6 +239,24 @@ be split into per-release sections once tagging begins. distance is rotation- and translation-invariant. Paper C's pipeline is expected to set `align="procrustes_per_sequence"`; see `TECHNICAL.md` Phase 0. +- **`neuropose.analyzer.dtw.Representation`** and + **`neuropose.analyzer.dtw.NanPolicy`** — two new Literal types + exposing orthogonal DTW preprocessing knobs on every entry point. + `representation` (on `dtw_all` and `dtw_per_joint`) switches the + per-frame feature vector between `"coords"` (the 0.1 default) and + `"angles"`, which runs `extract_joint_angles` on the supplied + `angle_triplets` first — yielding distances that are translation-, + rotation-, and scale-invariant by construction, and directly + interpretable in clinical terms. `nan_policy` (on all three entry + points) selects `"propagate"` (surface fastdtw's ValueError on + NaN — the default), `"interpolate"` (linear fill per feature + column), or `"drop"` (remove NaN frames before DTW); the + policy is applied consistently whether NaN originated from the + angles pipeline or from corrupted upstream coordinates. + `dtw_relation` stays a standalone convenience entry point for + two-joint displacement DTW; users who prefer a unified API can + express the same computation via `dtw_all` with an appropriate + pair of angle triplets or run `dtw_relation` directly. - **`neuropose.analyzer.segment.segment_gait_cycles`** and **`segment_gait_cycles_bilateral`** — clinical convenience wrappers over `segment_predictions` that pre-fill a `joint_axis` @@ -313,7 +331,9 @@ be split into per-release sections once tagging begins. imports for the heavy dependencies: - `analyzer.dtw` — three DTW entry points (`dtw_all`, `dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen - `DTWResult` dataclass. See `RESEARCH.md` for the ongoing + `DTWResult` dataclass and three orthogonal preprocessing knobs + (`align`, `representation`, `nan_policy`). See `RESEARCH.md` + for the ongoing methodology investigation. - `analyzer.features` — `predictions_to_numpy`, `normalize_pose_sequence` (uniform and axis-wise), diff --git a/src/neuropose/analyzer/__init__.py b/src/neuropose/analyzer/__init__.py index fb0c97d..2d04464 100644 --- a/src/neuropose/analyzer/__init__.py +++ b/src/neuropose/analyzer/__init__.py @@ -30,6 +30,8 @@ from __future__ import annotations from neuropose.analyzer.dtw import ( AlignMode, DTWResult, + NanPolicy, + Representation, dtw_all, dtw_per_joint, dtw_relation, @@ -71,7 +73,9 @@ __all__ = [ "AxisLetter", "DTWResult", "FeatureStatistics", + "NanPolicy", "ProcrustesMode", + "Representation", "dtw_all", "dtw_per_joint", "dtw_relation", diff --git a/src/neuropose/analyzer/dtw.py b/src/neuropose/analyzer/dtw.py index d16cd36..9375549 100644 --- a/src/neuropose/analyzer/dtw.py +++ b/src/neuropose/analyzer/dtw.py @@ -2,10 +2,12 @@ Three entry points, ordered by increasing precision (and increasing cost): -- :func:`dtw_all` — DTW on the flattened per-frame joint vector. Fast but - coarse; collapses every joint axis into a single per-frame vector. -- :func:`dtw_per_joint` — DTW on each joint independently. Preserves - per-joint temporal alignment at the cost of one DTW call per joint. +- :func:`dtw_all` — DTW on the flattened per-frame feature vector. Fast + but coarse; collapses every joint axis (or every angle triplet) into + a single per-frame vector. +- :func:`dtw_per_joint` — DTW on each joint (or angle triplet) + independently. Preserves per-unit temporal alignment at the cost of + one DTW call per unit. - :func:`dtw_relation` — DTW on the displacement vector between two specific joints. This is the right tool when the research question is about the *relative* motion of a specific pair of joints (e.g. the @@ -16,11 +18,23 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)`` numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy` produces. -All three also accept an ``align`` argument that routes the inputs -through :func:`~neuropose.analyzer.features.procrustes_align` before -DTW runs, yielding translation- and rotation-invariant distances. -``align="none"`` (the default) preserves the raw-coordinate behaviour -shipped in 0.1. +Three orthogonal preprocessing knobs are available on the entry points: + +- **``align``** routes the inputs through + :func:`~neuropose.analyzer.features.procrustes_align` before DTW runs, + yielding translation- and rotation-invariant distances. + ``align="none"`` (the default) preserves the raw-coordinate behaviour + shipped in 0.1. +- **``representation``** (on :func:`dtw_all` and :func:`dtw_per_joint`) + selects what each frame is reduced to before DTW. ``"coords"`` uses + the raw joint coordinates; ``"angles"`` replaces them with joint + angles computed at caller-supplied triplets via + :func:`~neuropose.analyzer.features.extract_joint_angles`, giving + DTW distances that are directly interpretable as clinical joint-range + comparisons. +- **``nan_policy``** decides how the DTW path handles non-finite values + in its input — typically a concern only for the angle representation, + where degenerate (zero-length) vectors produce NaN. See :data:`NanPolicy`. Dependency note --------------- @@ -34,13 +48,13 @@ called. from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Literal import numpy as np -from neuropose.analyzer.features import procrustes_align +from neuropose.analyzer.features import extract_joint_angles, procrustes_align AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"] """Alignment selector for DTW entry points. @@ -51,6 +65,41 @@ AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"] alignment before DTW. """ +Representation = Literal["coords", "angles"] +"""Per-frame feature representation for :func:`dtw_all` and :func:`dtw_per_joint`. + +- ``"coords"`` — use the raw joint coordinates (the input's last two + axes). Preserves the 0.1 behaviour. +- ``"angles"`` — replace joints with joint angles at caller-supplied + triplets. Translation- and rotation-invariant by construction, + scale-invariant modulo the upstream normalization, and directly + interpretable in clinical terms ("knee flexion during swing phase"). + The ``angle_triplets`` keyword becomes mandatory in this mode. +""" + +NanPolicy = Literal["propagate", "interpolate", "drop"] +"""Per-feature NaN handling for the DTW input. + +NaN typically appears when ``representation="angles"`` encounters a +degenerate (zero-length) vector — the angle is undefined and +:func:`extract_joint_angles` propagates NaN rather than quietly returning +a stand-in value. + +- ``"propagate"`` (default) — pass NaN straight through to the DTW + engine. fastdtw validates its input via + :func:`numpy.asarray_chkfinite` and raises :class:`ValueError` + the moment a NaN appears, which is the safest default because it + makes the problem visible instead of quietly corrupting a + distance. +- ``"interpolate"`` — linearly interpolate NaN frames along each + feature column using neighbouring finite values. Reasonable when a + small number of frames are corrupted and the surrounding motion is + smooth; inappropriate when long stretches are missing. +- ``"drop"`` — remove any frame where *any* feature is NaN before DTW + runs. Simple, but compresses the time axis, so warping-path indices + refer to the *compacted* sequence rather than the original. +""" + @dataclass(frozen=True) class DTWResult: @@ -100,13 +149,18 @@ def dtw_all( b: np.ndarray, *, align: AlignMode = "none", + representation: Representation = "coords", + angle_triplets: Sequence[tuple[int, int, int]] | None = None, + nan_policy: NanPolicy = "propagate", ) -> DTWResult: - """DTW on the flattened per-frame joint vector. + """DTW on the flattened per-frame feature vector. - Each frame's joints are collapsed into a single vector before DTW - is applied. This is fast — one DTW call regardless of the joint - count — but loses per-joint temporal structure, so a small + Under the default ``representation="coords"`` each frame's joints + are collapsed into a single vector before DTW is applied — fast + (one DTW call regardless of joint count) but coarse, since a small timing mismatch on one joint can dominate the distance metric. + Switching to ``representation="angles"`` computes joint angles at + the supplied triplets first and flattens those instead. Parameters ---------- @@ -119,6 +173,15 @@ def dtw_all( align Procrustes alignment mode applied before DTW. See :data:`AlignMode`. + representation + Per-frame feature representation. See :data:`Representation`. + angle_triplets + Required when ``representation="angles"``. Sequence of + ``(a, b, c)`` joint-index triplets passed through to + :func:`~neuropose.analyzer.features.extract_joint_angles`. + Ignored otherwise. + nan_policy + How to handle NaN values in the DTW input. See :data:`NanPolicy`. Returns ------- @@ -129,15 +192,20 @@ def dtw_all( Raises ------ ValueError - If ``a`` and ``b`` do not have the same joint count, or if - ``align`` requires a matching frame count that is not present. + If ``a`` and ``b`` do not have the same joint count, if + ``align`` requires a matching frame count that is not present, + if ``representation="angles"`` is requested without + ``angle_triplets``, or if ``nan_policy="interpolate"`` + encounters an all-NaN column. """ _validate_same_joint_count(a, b) a, b = _maybe_align(a, b, align=align) + feat_a = _apply_representation(a, representation, angle_triplets=angle_triplets) + feat_b = _apply_representation(b, representation, angle_triplets=angle_triplets) + feat_a = _apply_nan_policy(feat_a, nan_policy) + feat_b = _apply_nan_policy(feat_b, nan_policy) fastdtw, euclidean = _require_fastdtw() - a_flat = a.reshape(a.shape[0], -1) - b_flat = b.reshape(b.shape[0], -1) - distance, path = fastdtw(a_flat, b_flat, dist=euclidean) + distance, path = fastdtw(feat_a, feat_b, dist=euclidean) return DTWResult(distance=float(distance), path=[tuple(p) for p in path]) @@ -146,13 +214,21 @@ def dtw_per_joint( b: np.ndarray, *, align: AlignMode = "none", + representation: Representation = "coords", + angle_triplets: Sequence[tuple[int, int, int]] | None = None, + nan_policy: NanPolicy = "propagate", ) -> list[DTWResult]: - """DTW on each joint independently. + """DTW on each joint (or angle triplet) independently. - Performs one DTW computation per joint, yielding a list of - :class:`DTWResult` objects in joint-index order. More precise than - :func:`dtw_all` because each joint's temporal alignment is optimised - separately, at the cost of J times more DTW calls for J joints. + Performs one DTW computation per unit, yielding a list of + :class:`DTWResult` objects in input order. More precise than + :func:`dtw_all` because each unit's temporal alignment is optimised + separately, at the cost of J times more DTW calls for J units. + + Under the default ``representation="coords"`` a "unit" is one of + the input's joints (xyz treated jointly). Under + ``representation="angles"`` a "unit" is one scalar angle column + computed from one ``angle_triplets`` entry. Parameters ---------- @@ -164,26 +240,54 @@ def dtw_per_joint( align Procrustes alignment mode applied before DTW. See :data:`AlignMode`. + representation + Per-frame feature representation. See :data:`Representation`. + angle_triplets + Required when ``representation="angles"``; see + :func:`dtw_all` for details. + nan_policy + How to handle NaN values in the DTW input. See :data:`NanPolicy`. Returns ------- list[DTWResult] - One DTW result per joint, in index order. + One DTW result per joint or per angle triplet, in input order. Raises ------ ValueError - If ``a`` and ``b`` do not have the same joint count, or if - ``align`` requires a matching frame count that is not present. + Same conditions as :func:`dtw_all`. """ _validate_same_joint_count(a, b) a, b = _maybe_align(a, b, align=align) + + if representation == "coords": + feat_a = a + feat_b = b + # (frames, joints, 3) — one DTW per joint over its (frames, 3) slice. + num_units = feat_a.shape[1] + slicers: list[Callable[[np.ndarray], np.ndarray]] = [ + (lambda arr, idx=i: arr[:, idx, :]) for i in range(num_units) + ] + else: # "angles" + if angle_triplets is None: + raise ValueError("representation='angles' requires angle_triplets") + feat_a = extract_joint_angles(a, angle_triplets) # (frames, num_triplets) + feat_b = extract_joint_angles(b, angle_triplets) + num_units = feat_a.shape[1] + slicers = [ + # Scalar columns become 2D for DTW (fastdtw expects a + # sequence of vectors, not a sequence of scalars). + (lambda arr, idx=i: arr[:, idx : idx + 1]) + for i in range(num_units) + ] + fastdtw, euclidean = _require_fastdtw() results: list[DTWResult] = [] - for joint_idx in range(a.shape[1]): - a_joint = a[:, joint_idx, :] - b_joint = b[:, joint_idx, :] - distance, path = fastdtw(a_joint, b_joint, dist=euclidean) + for slicer in slicers: + unit_a = _apply_nan_policy(slicer(feat_a), nan_policy) + unit_b = _apply_nan_policy(slicer(feat_b), nan_policy) + distance, path = fastdtw(unit_a, unit_b, dist=euclidean) results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path])) return results @@ -195,6 +299,7 @@ def dtw_relation( joint_j: int, *, align: AlignMode = "none", + nan_policy: NanPolicy = "propagate", ) -> DTWResult: """DTW on the displacement vector between two specific joints. @@ -219,6 +324,8 @@ def dtw_relation( :data:`AlignMode`. Note that displacement vectors are already translation-invariant; alignment is still useful for cancelling camera rotation between trials. + nan_policy + How to handle NaN values in the DTW input. See :data:`NanPolicy`. Returns ------- @@ -239,9 +346,9 @@ def dtw_relation( f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}" ) a, b = _maybe_align(a, b, align=align) + disp_a = _apply_nan_policy(a[:, joint_j, :] - a[:, joint_i, :], nan_policy) + disp_b = _apply_nan_policy(b[:, joint_j, :] - b[:, joint_i, :], nan_policy) fastdtw, euclidean = _require_fastdtw() - disp_a = a[:, joint_j, :] - a[:, joint_i, :] - disp_b = b[:, joint_j, :] - b[:, joint_i, :] distance, path = fastdtw(disp_a, disp_b, dist=euclidean) return DTWResult(distance=float(distance), path=[tuple(p) for p in path]) @@ -283,3 +390,69 @@ def _maybe_align( mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence" aligned_a, _target, _diag = procrustes_align(a, b, mode=mode) return aligned_a, b + + +def _apply_representation( + sequence: np.ndarray, + representation: Representation, + *, + angle_triplets: Sequence[tuple[int, int, int]] | None, +) -> np.ndarray: + """Reduce a ``(frames, joints, 3)`` sequence to DTW-ready 2D features. + + ``"coords"`` reshapes to ``(frames, joints * 3)``; ``"angles"`` + runs :func:`extract_joint_angles` to produce + ``(frames, len(angle_triplets))``. + """ + if representation == "coords": + return sequence.reshape(sequence.shape[0], -1) + if representation == "angles": + if angle_triplets is None: + raise ValueError("representation='angles' requires angle_triplets") + return extract_joint_angles(sequence, angle_triplets) + raise ValueError(f"unknown representation {representation!r}") + + +def _apply_nan_policy(features: np.ndarray, policy: NanPolicy) -> np.ndarray: + """Handle NaN values in a ``(frames, features)`` array per ``policy``. + + ``"propagate"`` is a no-op. ``"interpolate"`` runs 1D linear + interpolation along the frame axis within each feature column, + leaving finite data untouched. ``"drop"`` removes any frame where + *any* feature is NaN. + + Raises + ------ + ValueError + If ``"interpolate"`` encounters a column that is entirely NaN + (no finite anchors to interpolate between), or if ``"drop"`` + leaves an empty sequence. + """ + if policy == "propagate": + return features + if features.ndim == 1: + features = features.reshape(-1, 1) + if policy == "drop": + keep = np.isfinite(features).all(axis=1) + dropped = features[keep] + if dropped.shape[0] == 0: + raise ValueError( + "nan_policy='drop' removed every frame; DTW needs a non-empty sequence" + ) + return dropped + if policy == "interpolate": + out = features.astype(float, copy=True) + num_frames = out.shape[0] + indices = np.arange(num_frames, dtype=float) + for col in range(out.shape[1]): + column = out[:, col] + finite = np.isfinite(column) + if finite.all(): + continue + if not finite.any(): + raise ValueError( + f"nan_policy='interpolate' cannot fill column {col}: all values are NaN" + ) + out[:, col] = np.interp(indices, indices[finite], column[finite]) + return out + raise ValueError(f"unknown nan_policy {policy!r}") diff --git a/tests/unit/test_analyzer_dtw.py b/tests/unit/test_analyzer_dtw.py index d312e6b..24b14bf 100644 --- a/tests/unit/test_analyzer_dtw.py +++ b/tests/unit/test_analyzer_dtw.py @@ -131,3 +131,250 @@ class TestDtwRelation: b = np.zeros((3, 2, 3)) with pytest.raises(ValueError, match="joint count"): dtw_relation(a, b, joint_i=0, joint_j=1) + + +# --------------------------------------------------------------------------- +# representation="angles" +# --------------------------------------------------------------------------- + + +def _rotation_matrix_z(angle_rad: float) -> np.ndarray: + c, s = np.cos(angle_rad), np.sin(angle_rad) + return np.array( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + +def _three_joint_arm(num_frames: int = 6) -> np.ndarray: + """A three-joint arm opening from a right angle to straight. + + Joints laid out as [shoulder, elbow, wrist], forming an angle at + the elbow that linearly opens from pi/2 to pi across ``num_frames``. + """ + sequence = np.zeros((num_frames, 3, 3)) + angles = np.linspace(np.pi / 2, np.pi, num_frames) + for i, theta in enumerate(angles): + sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder + sequence[i, 1] = [0.0, 0.0, 0.0] # elbow + sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist + return sequence + + +class TestDtwAllAngles: + def test_angles_identical_sequences_distance_zero(self) -> None: + seq = _three_joint_arm() + result = dtw_all( + seq, + seq, + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + assert result.distance == pytest.approx(0.0, abs=1e-9) + + def test_angles_invariant_to_global_rotation(self) -> None: + """Angle-space DTW must not change under a global rotation.""" + seq = _three_joint_arm() + rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T + baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)]) + under_rotation = dtw_all( + seq, + rotated, + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6) + + def test_angles_translation_invariant(self) -> None: + seq = _three_joint_arm() + translated = seq + np.array([10.0, -5.0, 2.0]) + result = dtw_all( + seq, + translated, + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + assert result.distance == pytest.approx(0.0, abs=1e-9) + + def test_angles_detects_different_motion(self) -> None: + # A sequence whose angle is constant vs. one that opens. + constant = np.zeros((6, 3, 3)) + constant[:, 0] = [-1.0, 0.0, 0.0] + constant[:, 1] = [0.0, 0.0, 0.0] + constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout + opening = _three_joint_arm() + result = dtw_all( + constant, + opening, + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + assert result.distance > 0.0 + + def test_angles_without_triplets_rejected(self) -> None: + seq = _three_joint_arm() + with pytest.raises(ValueError, match="angle_triplets"): + dtw_all(seq, seq, representation="angles") + + +class TestDtwPerJointAngles: + def test_returns_one_result_per_triplet(self) -> None: + seq = _three_joint_arm() + triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose + results = dtw_per_joint( + seq, + seq, + representation="angles", + angle_triplets=triplets, + ) + assert len(results) == 2 + for result in results: + assert result.distance == pytest.approx(0.0, abs=1e-9) + + def test_per_triplet_distinct_paths(self) -> None: + # Two triplets covering different angles; with different motion + # per triplet, the per-unit results should differ. + seq_a = np.zeros((5, 4, 3)) + seq_b = np.zeros((5, 4, 3)) + # joint 0: pivot, joint 1/2/3: arm endpoints + for i in range(5): + seq_a[i, 0] = [0.0, 0.0, 0.0] + seq_a[i, 1] = [1.0, 0.0, 0.0] + seq_a[i, 2] = [0.0, 1.0, 0.0] + seq_a[i, 3] = [0.0, 0.0, 1.0] + seq_b[i, 0] = [0.0, 0.0, 0.0] + seq_b[i, 1] = [1.0, 0.0, 0.0] + seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating + seq_b[i, 3] = [0.0, 0.0, 1.0] + results = dtw_per_joint( + seq_a, + seq_b, + representation="angles", + angle_triplets=[(1, 0, 2), (1, 0, 3)], + ) + assert len(results) == 2 + # First triplet tracks the rotation, second is stationary. + assert results[0].distance > 0.0 + assert results[1].distance == pytest.approx(0.0, abs=1e-9) + + +# --------------------------------------------------------------------------- +# nan_policy +# --------------------------------------------------------------------------- + + +def _collinear_sequence(num_frames: int = 4) -> np.ndarray: + """Three collinear joints — the angle at the middle joint is degenerate.""" + seq = np.zeros((num_frames, 3, 3)) + seq[:, 0] = [-1.0, 0.0, 0.0] + # Middle joint at (0,0,0); but because the outer joints are collinear + # through the origin, we need one joint overlapping with the middle + # to force a zero-length vector. Place joint 2 AT joint 1 to trigger + # the degenerate case in extract_joint_angles. + seq[:, 1] = [0.0, 0.0, 0.0] + seq[:, 2] = [0.0, 0.0, 0.0] + return seq + + +class TestNanPolicy: + def test_propagate_surfaces_error(self) -> None: + # Degenerate triplet produces NaN angles for every frame. + # With nan_policy="propagate" the NaN reaches fastdtw, which + # validates via numpy.asarray_chkfinite and raises ValueError — + # the intended behaviour ("make the problem visible"). + seq = _collinear_sequence(num_frames=4) + other = _three_joint_arm(num_frames=4) + with pytest.raises(ValueError, match="infs or NaNs"): + dtw_all( + seq, + other, + representation="angles", + angle_triplets=[(0, 1, 2)], + nan_policy="propagate", + ) + + def test_interpolate_fills_isolated_nan(self) -> None: + # One bad frame in a 5-frame sequence — the other four are + # finite anchors to interpolate between. + good = _three_joint_arm(num_frames=5) + # Inject a degenerate middle frame. + good[2, 2] = good[2, 1] # force zero-length vector → NaN angle + # Reference is the same arm without injection. + reference = _three_joint_arm(num_frames=5) + result = dtw_all( + good, + reference, + representation="angles", + angle_triplets=[(0, 1, 2)], + nan_policy="interpolate", + ) + assert not np.isnan(result.distance) + + def test_interpolate_all_nan_column_rejected(self) -> None: + seq = _collinear_sequence(num_frames=5) + other = _three_joint_arm(num_frames=5) + with pytest.raises(ValueError, match="all values are NaN"): + dtw_all( + seq, + other, + representation="angles", + angle_triplets=[(0, 1, 2)], + nan_policy="interpolate", + ) + + def test_drop_removes_nan_frames(self) -> None: + good = _three_joint_arm(num_frames=6) + good[2, 2] = good[2, 1] # inject NaN at frame 2 + good[4, 2] = good[4, 1] # inject NaN at frame 4 + reference = _three_joint_arm(num_frames=6) + result = dtw_all( + good, + reference, + representation="angles", + angle_triplets=[(0, 1, 2)], + nan_policy="drop", + ) + # The 4 remaining finite frames should align cleanly with + # their counterparts in the reference. + assert not np.isnan(result.distance) + + def test_drop_empties_sequence_rejected(self) -> None: + seq = _collinear_sequence(num_frames=5) + other = _three_joint_arm(num_frames=5) + with pytest.raises(ValueError, match="every frame"): + dtw_all( + seq, + other, + representation="angles", + angle_triplets=[(0, 1, 2)], + nan_policy="drop", + ) + + +# --------------------------------------------------------------------------- +# align + representation composition +# --------------------------------------------------------------------------- + + +class TestAlignWithAngles: + def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None: + """Procrustes on angle-space DTW should be redundant but safe.""" + seq = _three_joint_arm() + rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T + with_align = dtw_all( + seq, + rotated, + align="procrustes_per_sequence", + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + without_align = dtw_all( + seq, + rotated, + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)