add joint-angle representation and nan_policy to DTW

representation: Literal["coords", "angles"] on dtw_all and
dtw_per_joint. "coords" preserves the 0.1 behaviour; "angles" runs
extract_joint_angles on caller-supplied angle_triplets before DTW,
giving translation-, rotation-, and scale-invariant distances that
are directly interpretable as clinical joint-range comparisons. Under
dtw_per_joint the "unit" becomes one angle column per triplet.

nan_policy: Literal["propagate", "interpolate", "drop"] on all three
entry points. "propagate" (default) lets NaN hit fastdtw, which raises
ValueError via numpy.asarray_chkfinite — the safest default because it
surfaces degenerate-vector problems rather than silently corrupting a
distance. "interpolate" runs 1D linear interpolation per feature
column; "drop" removes NaN frames before DTW.

dtw_relation stays a standalone convenience entry point. Paper C's
typical call becomes dtw_all(representation="angles",
align="procrustes_per_sequence"); see TECHNICAL.md Phase 0.
This commit is contained in:
Levi Neuwirth 2026-04-18 18:02:13 -04:00
parent a1c495b2fd
commit 87461a17d0
4 changed files with 479 additions and 35 deletions

View File

@ -239,6 +239,24 @@ be split into per-release sections once tagging begins.
distance is rotation- and translation-invariant. Paper C's
pipeline is expected to set `align="procrustes_per_sequence"`;
see `TECHNICAL.md` Phase 0.
- **`neuropose.analyzer.dtw.Representation`** and
**`neuropose.analyzer.dtw.NanPolicy`** — two new Literal types
exposing orthogonal DTW preprocessing knobs on every entry point.
`representation` (on `dtw_all` and `dtw_per_joint`) switches the
per-frame feature vector between `"coords"` (the 0.1 default) and
`"angles"`, which runs `extract_joint_angles` on the supplied
`angle_triplets` first — yielding distances that are translation-,
rotation-, and scale-invariant by construction, and directly
interpretable in clinical terms. `nan_policy` (on all three entry
points) selects `"propagate"` (surface fastdtw's ValueError on
NaN — the default), `"interpolate"` (linear fill per feature
column), or `"drop"` (remove NaN frames before DTW); the
policy is applied consistently whether NaN originated from the
angles pipeline or from corrupted upstream coordinates.
`dtw_relation` stays a standalone convenience entry point for
two-joint displacement DTW; users who prefer a unified API can
express the same computation via `dtw_all` with an appropriate
pair of angle triplets or run `dtw_relation` directly.
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
**`segment_gait_cycles_bilateral`** — clinical convenience
wrappers over `segment_predictions` that pre-fill a `joint_axis`
@ -313,7 +331,9 @@ be split into per-release sections once tagging begins.
imports for the heavy dependencies:
- `analyzer.dtw` — three DTW entry points (`dtw_all`,
`dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
`DTWResult` dataclass. See `RESEARCH.md` for the ongoing
`DTWResult` dataclass and three orthogonal preprocessing knobs
(`align`, `representation`, `nan_policy`). See `RESEARCH.md`
for the ongoing
methodology investigation.
- `analyzer.features``predictions_to_numpy`,
`normalize_pose_sequence` (uniform and axis-wise),

View File

@ -30,6 +30,8 @@ from __future__ import annotations
from neuropose.analyzer.dtw import (
AlignMode,
DTWResult,
NanPolicy,
Representation,
dtw_all,
dtw_per_joint,
dtw_relation,
@ -71,7 +73,9 @@ __all__ = [
"AxisLetter",
"DTWResult",
"FeatureStatistics",
"NanPolicy",
"ProcrustesMode",
"Representation",
"dtw_all",
"dtw_per_joint",
"dtw_relation",

View File

@ -2,10 +2,12 @@
Three entry points, ordered by increasing precision (and increasing cost):
- :func:`dtw_all` DTW on the flattened per-frame joint vector. Fast but
coarse; collapses every joint axis into a single per-frame vector.
- :func:`dtw_per_joint` DTW on each joint independently. Preserves
per-joint temporal alignment at the cost of one DTW call per joint.
- :func:`dtw_all` DTW on the flattened per-frame feature vector. Fast
but coarse; collapses every joint axis (or every angle triplet) into
a single per-frame vector.
- :func:`dtw_per_joint` DTW on each joint (or angle triplet)
independently. Preserves per-unit temporal alignment at the cost of
one DTW call per unit.
- :func:`dtw_relation` DTW on the displacement vector between two
specific joints. This is the right tool when the research question is
about the *relative* motion of a specific pair of joints (e.g. the
@ -16,11 +18,23 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)``
numpy arrays the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
produces.
All three also accept an ``align`` argument that routes the inputs
through :func:`~neuropose.analyzer.features.procrustes_align` before
DTW runs, yielding translation- and rotation-invariant distances.
``align="none"`` (the default) preserves the raw-coordinate behaviour
shipped in 0.1.
Three orthogonal preprocessing knobs are available on the entry points:
- **``align``** routes the inputs through
:func:`~neuropose.analyzer.features.procrustes_align` before DTW runs,
yielding translation- and rotation-invariant distances.
``align="none"`` (the default) preserves the raw-coordinate behaviour
shipped in 0.1.
- **``representation``** (on :func:`dtw_all` and :func:`dtw_per_joint`)
selects what each frame is reduced to before DTW. ``"coords"`` uses
the raw joint coordinates; ``"angles"`` replaces them with joint
angles computed at caller-supplied triplets via
:func:`~neuropose.analyzer.features.extract_joint_angles`, giving
DTW distances that are directly interpretable as clinical joint-range
comparisons.
- **``nan_policy``** decides how the DTW path handles non-finite values
in its input typically a concern only for the angle representation,
where degenerate (zero-length) vectors produce NaN. See :data:`NanPolicy`.
Dependency note
---------------
@ -34,13 +48,13 @@ called.
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Literal
import numpy as np
from neuropose.analyzer.features import procrustes_align
from neuropose.analyzer.features import extract_joint_angles, procrustes_align
AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
"""Alignment selector for DTW entry points.
@ -51,6 +65,41 @@ AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
alignment before DTW.
"""
Representation = Literal["coords", "angles"]
"""Per-frame feature representation for :func:`dtw_all` and :func:`dtw_per_joint`.
- ``"coords"`` use the raw joint coordinates (the input's last two
axes). Preserves the 0.1 behaviour.
- ``"angles"`` replace joints with joint angles at caller-supplied
triplets. Translation- and rotation-invariant by construction,
scale-invariant modulo the upstream normalization, and directly
interpretable in clinical terms ("knee flexion during swing phase").
The ``angle_triplets`` keyword becomes mandatory in this mode.
"""
NanPolicy = Literal["propagate", "interpolate", "drop"]
"""Per-feature NaN handling for the DTW input.
NaN typically appears when ``representation="angles"`` encounters a
degenerate (zero-length) vector the angle is undefined and
:func:`extract_joint_angles` propagates NaN rather than quietly returning
a stand-in value.
- ``"propagate"`` (default) pass NaN straight through to the DTW
engine. fastdtw validates its input via
:func:`numpy.asarray_chkfinite` and raises :class:`ValueError`
the moment a NaN appears, which is the safest default because it
makes the problem visible instead of quietly corrupting a
distance.
- ``"interpolate"`` linearly interpolate NaN frames along each
feature column using neighbouring finite values. Reasonable when a
small number of frames are corrupted and the surrounding motion is
smooth; inappropriate when long stretches are missing.
- ``"drop"`` remove any frame where *any* feature is NaN before DTW
runs. Simple, but compresses the time axis, so warping-path indices
refer to the *compacted* sequence rather than the original.
"""
@dataclass(frozen=True)
class DTWResult:
@ -100,13 +149,18 @@ def dtw_all(
b: np.ndarray,
*,
align: AlignMode = "none",
representation: Representation = "coords",
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
nan_policy: NanPolicy = "propagate",
) -> DTWResult:
"""DTW on the flattened per-frame joint vector.
"""DTW on the flattened per-frame feature vector.
Each frame's joints are collapsed into a single vector before DTW
is applied. This is fast one DTW call regardless of the joint
count but loses per-joint temporal structure, so a small
Under the default ``representation="coords"`` each frame's joints
are collapsed into a single vector before DTW is applied fast
(one DTW call regardless of joint count) but coarse, since a small
timing mismatch on one joint can dominate the distance metric.
Switching to ``representation="angles"`` computes joint angles at
the supplied triplets first and flattens those instead.
Parameters
----------
@ -119,6 +173,15 @@ def dtw_all(
align
Procrustes alignment mode applied before DTW. See
:data:`AlignMode`.
representation
Per-frame feature representation. See :data:`Representation`.
angle_triplets
Required when ``representation="angles"``. Sequence of
``(a, b, c)`` joint-index triplets passed through to
:func:`~neuropose.analyzer.features.extract_joint_angles`.
Ignored otherwise.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns
-------
@ -129,15 +192,20 @@ def dtw_all(
Raises
------
ValueError
If ``a`` and ``b`` do not have the same joint count, or if
``align`` requires a matching frame count that is not present.
If ``a`` and ``b`` do not have the same joint count, if
``align`` requires a matching frame count that is not present,
if ``representation="angles"`` is requested without
``angle_triplets``, or if ``nan_policy="interpolate"``
encounters an all-NaN column.
"""
_validate_same_joint_count(a, b)
a, b = _maybe_align(a, b, align=align)
feat_a = _apply_representation(a, representation, angle_triplets=angle_triplets)
feat_b = _apply_representation(b, representation, angle_triplets=angle_triplets)
feat_a = _apply_nan_policy(feat_a, nan_policy)
feat_b = _apply_nan_policy(feat_b, nan_policy)
fastdtw, euclidean = _require_fastdtw()
a_flat = a.reshape(a.shape[0], -1)
b_flat = b.reshape(b.shape[0], -1)
distance, path = fastdtw(a_flat, b_flat, dist=euclidean)
distance, path = fastdtw(feat_a, feat_b, dist=euclidean)
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
@ -146,13 +214,21 @@ def dtw_per_joint(
b: np.ndarray,
*,
align: AlignMode = "none",
representation: Representation = "coords",
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
nan_policy: NanPolicy = "propagate",
) -> list[DTWResult]:
"""DTW on each joint independently.
"""DTW on each joint (or angle triplet) independently.
Performs one DTW computation per joint, yielding a list of
:class:`DTWResult` objects in joint-index order. More precise than
:func:`dtw_all` because each joint's temporal alignment is optimised
separately, at the cost of J times more DTW calls for J joints.
Performs one DTW computation per unit, yielding a list of
:class:`DTWResult` objects in input order. More precise than
:func:`dtw_all` because each unit's temporal alignment is optimised
separately, at the cost of J times more DTW calls for J units.
Under the default ``representation="coords"`` a "unit" is one of
the input's joints (xyz treated jointly). Under
``representation="angles"`` a "unit" is one scalar angle column
computed from one ``angle_triplets`` entry.
Parameters
----------
@ -164,26 +240,54 @@ def dtw_per_joint(
align
Procrustes alignment mode applied before DTW. See
:data:`AlignMode`.
representation
Per-frame feature representation. See :data:`Representation`.
angle_triplets
Required when ``representation="angles"``; see
:func:`dtw_all` for details.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns
-------
list[DTWResult]
One DTW result per joint, in index order.
One DTW result per joint or per angle triplet, in input order.
Raises
------
ValueError
If ``a`` and ``b`` do not have the same joint count, or if
``align`` requires a matching frame count that is not present.
Same conditions as :func:`dtw_all`.
"""
_validate_same_joint_count(a, b)
a, b = _maybe_align(a, b, align=align)
if representation == "coords":
feat_a = a
feat_b = b
# (frames, joints, 3) — one DTW per joint over its (frames, 3) slice.
num_units = feat_a.shape[1]
slicers: list[Callable[[np.ndarray], np.ndarray]] = [
(lambda arr, idx=i: arr[:, idx, :]) for i in range(num_units)
]
else: # "angles"
if angle_triplets is None:
raise ValueError("representation='angles' requires angle_triplets")
feat_a = extract_joint_angles(a, angle_triplets) # (frames, num_triplets)
feat_b = extract_joint_angles(b, angle_triplets)
num_units = feat_a.shape[1]
slicers = [
# Scalar columns become 2D for DTW (fastdtw expects a
# sequence of vectors, not a sequence of scalars).
(lambda arr, idx=i: arr[:, idx : idx + 1])
for i in range(num_units)
]
fastdtw, euclidean = _require_fastdtw()
results: list[DTWResult] = []
for joint_idx in range(a.shape[1]):
a_joint = a[:, joint_idx, :]
b_joint = b[:, joint_idx, :]
distance, path = fastdtw(a_joint, b_joint, dist=euclidean)
for slicer in slicers:
unit_a = _apply_nan_policy(slicer(feat_a), nan_policy)
unit_b = _apply_nan_policy(slicer(feat_b), nan_policy)
distance, path = fastdtw(unit_a, unit_b, dist=euclidean)
results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path]))
return results
@ -195,6 +299,7 @@ def dtw_relation(
joint_j: int,
*,
align: AlignMode = "none",
nan_policy: NanPolicy = "propagate",
) -> DTWResult:
"""DTW on the displacement vector between two specific joints.
@ -219,6 +324,8 @@ def dtw_relation(
:data:`AlignMode`. Note that displacement vectors are already
translation-invariant; alignment is still useful for cancelling
camera rotation between trials.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns
-------
@ -239,9 +346,9 @@ def dtw_relation(
f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}"
)
a, b = _maybe_align(a, b, align=align)
disp_a = _apply_nan_policy(a[:, joint_j, :] - a[:, joint_i, :], nan_policy)
disp_b = _apply_nan_policy(b[:, joint_j, :] - b[:, joint_i, :], nan_policy)
fastdtw, euclidean = _require_fastdtw()
disp_a = a[:, joint_j, :] - a[:, joint_i, :]
disp_b = b[:, joint_j, :] - b[:, joint_i, :]
distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
@ -283,3 +390,69 @@ def _maybe_align(
mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence"
aligned_a, _target, _diag = procrustes_align(a, b, mode=mode)
return aligned_a, b
def _apply_representation(
sequence: np.ndarray,
representation: Representation,
*,
angle_triplets: Sequence[tuple[int, int, int]] | None,
) -> np.ndarray:
"""Reduce a ``(frames, joints, 3)`` sequence to DTW-ready 2D features.
``"coords"`` reshapes to ``(frames, joints * 3)``; ``"angles"``
runs :func:`extract_joint_angles` to produce
``(frames, len(angle_triplets))``.
"""
if representation == "coords":
return sequence.reshape(sequence.shape[0], -1)
if representation == "angles":
if angle_triplets is None:
raise ValueError("representation='angles' requires angle_triplets")
return extract_joint_angles(sequence, angle_triplets)
raise ValueError(f"unknown representation {representation!r}")
def _apply_nan_policy(features: np.ndarray, policy: NanPolicy) -> np.ndarray:
"""Handle NaN values in a ``(frames, features)`` array per ``policy``.
``"propagate"`` is a no-op. ``"interpolate"`` runs 1D linear
interpolation along the frame axis within each feature column,
leaving finite data untouched. ``"drop"`` removes any frame where
*any* feature is NaN.
Raises
------
ValueError
If ``"interpolate"`` encounters a column that is entirely NaN
(no finite anchors to interpolate between), or if ``"drop"``
leaves an empty sequence.
"""
if policy == "propagate":
return features
if features.ndim == 1:
features = features.reshape(-1, 1)
if policy == "drop":
keep = np.isfinite(features).all(axis=1)
dropped = features[keep]
if dropped.shape[0] == 0:
raise ValueError(
"nan_policy='drop' removed every frame; DTW needs a non-empty sequence"
)
return dropped
if policy == "interpolate":
out = features.astype(float, copy=True)
num_frames = out.shape[0]
indices = np.arange(num_frames, dtype=float)
for col in range(out.shape[1]):
column = out[:, col]
finite = np.isfinite(column)
if finite.all():
continue
if not finite.any():
raise ValueError(
f"nan_policy='interpolate' cannot fill column {col}: all values are NaN"
)
out[:, col] = np.interp(indices, indices[finite], column[finite])
return out
raise ValueError(f"unknown nan_policy {policy!r}")

View File

@ -131,3 +131,250 @@ class TestDtwRelation:
b = np.zeros((3, 2, 3))
with pytest.raises(ValueError, match="joint count"):
dtw_relation(a, b, joint_i=0, joint_j=1)
# ---------------------------------------------------------------------------
# representation="angles"
# ---------------------------------------------------------------------------
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
c, s = np.cos(angle_rad), np.sin(angle_rad)
return np.array(
[
[c, -s, 0.0],
[s, c, 0.0],
[0.0, 0.0, 1.0],
]
)
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
"""A three-joint arm opening from a right angle to straight.
Joints laid out as [shoulder, elbow, wrist], forming an angle at
the elbow that linearly opens from pi/2 to pi across ``num_frames``.
"""
sequence = np.zeros((num_frames, 3, 3))
angles = np.linspace(np.pi / 2, np.pi, num_frames)
for i, theta in enumerate(angles):
sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder
sequence[i, 1] = [0.0, 0.0, 0.0] # elbow
sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist
return sequence
class TestDtwAllAngles:
def test_angles_identical_sequences_distance_zero(self) -> None:
seq = _three_joint_arm()
result = dtw_all(
seq,
seq,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_angles_invariant_to_global_rotation(self) -> None:
"""Angle-space DTW must not change under a global rotation."""
seq = _three_joint_arm()
rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T
baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)])
under_rotation = dtw_all(
seq,
rotated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6)
def test_angles_translation_invariant(self) -> None:
seq = _three_joint_arm()
translated = seq + np.array([10.0, -5.0, 2.0])
result = dtw_all(
seq,
translated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_angles_detects_different_motion(self) -> None:
# A sequence whose angle is constant vs. one that opens.
constant = np.zeros((6, 3, 3))
constant[:, 0] = [-1.0, 0.0, 0.0]
constant[:, 1] = [0.0, 0.0, 0.0]
constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout
opening = _three_joint_arm()
result = dtw_all(
constant,
opening,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance > 0.0
def test_angles_without_triplets_rejected(self) -> None:
seq = _three_joint_arm()
with pytest.raises(ValueError, match="angle_triplets"):
dtw_all(seq, seq, representation="angles")
class TestDtwPerJointAngles:
def test_returns_one_result_per_triplet(self) -> None:
seq = _three_joint_arm()
triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose
results = dtw_per_joint(
seq,
seq,
representation="angles",
angle_triplets=triplets,
)
assert len(results) == 2
for result in results:
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_per_triplet_distinct_paths(self) -> None:
# Two triplets covering different angles; with different motion
# per triplet, the per-unit results should differ.
seq_a = np.zeros((5, 4, 3))
seq_b = np.zeros((5, 4, 3))
# joint 0: pivot, joint 1/2/3: arm endpoints
for i in range(5):
seq_a[i, 0] = [0.0, 0.0, 0.0]
seq_a[i, 1] = [1.0, 0.0, 0.0]
seq_a[i, 2] = [0.0, 1.0, 0.0]
seq_a[i, 3] = [0.0, 0.0, 1.0]
seq_b[i, 0] = [0.0, 0.0, 0.0]
seq_b[i, 1] = [1.0, 0.0, 0.0]
seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating
seq_b[i, 3] = [0.0, 0.0, 1.0]
results = dtw_per_joint(
seq_a,
seq_b,
representation="angles",
angle_triplets=[(1, 0, 2), (1, 0, 3)],
)
assert len(results) == 2
# First triplet tracks the rotation, second is stationary.
assert results[0].distance > 0.0
assert results[1].distance == pytest.approx(0.0, abs=1e-9)
# ---------------------------------------------------------------------------
# nan_policy
# ---------------------------------------------------------------------------
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
"""Three collinear joints — the angle at the middle joint is degenerate."""
seq = np.zeros((num_frames, 3, 3))
seq[:, 0] = [-1.0, 0.0, 0.0]
# Middle joint at (0,0,0); but because the outer joints are collinear
# through the origin, we need one joint overlapping with the middle
# to force a zero-length vector. Place joint 2 AT joint 1 to trigger
# the degenerate case in extract_joint_angles.
seq[:, 1] = [0.0, 0.0, 0.0]
seq[:, 2] = [0.0, 0.0, 0.0]
return seq
class TestNanPolicy:
def test_propagate_surfaces_error(self) -> None:
# Degenerate triplet produces NaN angles for every frame.
# With nan_policy="propagate" the NaN reaches fastdtw, which
# validates via numpy.asarray_chkfinite and raises ValueError —
# the intended behaviour ("make the problem visible").
seq = _collinear_sequence(num_frames=4)
other = _three_joint_arm(num_frames=4)
with pytest.raises(ValueError, match="infs or NaNs"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="propagate",
)
def test_interpolate_fills_isolated_nan(self) -> None:
# One bad frame in a 5-frame sequence — the other four are
# finite anchors to interpolate between.
good = _three_joint_arm(num_frames=5)
# Inject a degenerate middle frame.
good[2, 2] = good[2, 1] # force zero-length vector → NaN angle
# Reference is the same arm without injection.
reference = _three_joint_arm(num_frames=5)
result = dtw_all(
good,
reference,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="interpolate",
)
assert not np.isnan(result.distance)
def test_interpolate_all_nan_column_rejected(self) -> None:
seq = _collinear_sequence(num_frames=5)
other = _three_joint_arm(num_frames=5)
with pytest.raises(ValueError, match="all values are NaN"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="interpolate",
)
def test_drop_removes_nan_frames(self) -> None:
good = _three_joint_arm(num_frames=6)
good[2, 2] = good[2, 1] # inject NaN at frame 2
good[4, 2] = good[4, 1] # inject NaN at frame 4
reference = _three_joint_arm(num_frames=6)
result = dtw_all(
good,
reference,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="drop",
)
# The 4 remaining finite frames should align cleanly with
# their counterparts in the reference.
assert not np.isnan(result.distance)
def test_drop_empties_sequence_rejected(self) -> None:
seq = _collinear_sequence(num_frames=5)
other = _three_joint_arm(num_frames=5)
with pytest.raises(ValueError, match="every frame"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="drop",
)
# ---------------------------------------------------------------------------
# align + representation composition
# ---------------------------------------------------------------------------
class TestAlignWithAngles:
def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
"""Procrustes on angle-space DTW should be redundant but safe."""
seq = _three_joint_arm()
rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T
with_align = dtw_all(
seq,
rotated,
align="procrustes_per_sequence",
representation="angles",
angle_triplets=[(0, 1, 2)],
)
without_align = dtw_all(
seq,
rotated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)