add joint-angle representation and nan_policy to DTW

representation: Literal["coords", "angles"] on dtw_all and
dtw_per_joint. "coords" preserves the 0.1 behaviour; "angles" runs
extract_joint_angles on caller-supplied angle_triplets before DTW,
giving translation-, rotation-, and scale-invariant distances that
are directly interpretable as clinical joint-range comparisons. Under
dtw_per_joint the "unit" becomes one angle column per triplet.

nan_policy: Literal["propagate", "interpolate", "drop"] on all three
entry points. "propagate" (default) lets NaN hit fastdtw, which raises
ValueError via numpy.asarray_chkfinite — the safest default because it
surfaces degenerate-vector problems rather than silently corrupting a
distance. "interpolate" runs 1D linear interpolation per feature
column; "drop" removes NaN frames before DTW.

dtw_relation stays a standalone convenience entry point. Paper C's
typical call becomes dtw_all(representation="angles",
align="procrustes_per_sequence"); see TECHNICAL.md Phase 0.
This commit is contained in:
Levi Neuwirth 2026-04-18 18:02:13 -04:00
parent a1c495b2fd
commit 87461a17d0
4 changed files with 479 additions and 35 deletions

View File

@ -239,6 +239,24 @@ be split into per-release sections once tagging begins.
distance is rotation- and translation-invariant. Paper C's distance is rotation- and translation-invariant. Paper C's
pipeline is expected to set `align="procrustes_per_sequence"`; pipeline is expected to set `align="procrustes_per_sequence"`;
see `TECHNICAL.md` Phase 0. see `TECHNICAL.md` Phase 0.
- **`neuropose.analyzer.dtw.Representation`** and
**`neuropose.analyzer.dtw.NanPolicy`** — two new Literal types
exposing orthogonal DTW preprocessing knobs on every entry point.
`representation` (on `dtw_all` and `dtw_per_joint`) switches the
per-frame feature vector between `"coords"` (the 0.1 default) and
`"angles"`, which runs `extract_joint_angles` on the supplied
`angle_triplets` first — yielding distances that are translation-,
rotation-, and scale-invariant by construction, and directly
interpretable in clinical terms. `nan_policy` (on all three entry
points) selects `"propagate"` (surface fastdtw's ValueError on
NaN — the default), `"interpolate"` (linear fill per feature
column), or `"drop"` (remove NaN frames before DTW); the
policy is applied consistently whether NaN originated from the
angles pipeline or from corrupted upstream coordinates.
`dtw_relation` stays a standalone convenience entry point for
two-joint displacement DTW; users who prefer a unified API can
express the same computation via `dtw_all` with an appropriate
pair of angle triplets or run `dtw_relation` directly.
- **`neuropose.analyzer.segment.segment_gait_cycles`** and - **`neuropose.analyzer.segment.segment_gait_cycles`** and
**`segment_gait_cycles_bilateral`** — clinical convenience **`segment_gait_cycles_bilateral`** — clinical convenience
wrappers over `segment_predictions` that pre-fill a `joint_axis` wrappers over `segment_predictions` that pre-fill a `joint_axis`
@ -313,7 +331,9 @@ be split into per-release sections once tagging begins.
imports for the heavy dependencies: imports for the heavy dependencies:
- `analyzer.dtw` — three DTW entry points (`dtw_all`, - `analyzer.dtw` — three DTW entry points (`dtw_all`,
`dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen `dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
`DTWResult` dataclass. See `RESEARCH.md` for the ongoing `DTWResult` dataclass and three orthogonal preprocessing knobs
(`align`, `representation`, `nan_policy`). See `RESEARCH.md`
for the ongoing
methodology investigation. methodology investigation.
- `analyzer.features``predictions_to_numpy`, - `analyzer.features``predictions_to_numpy`,
`normalize_pose_sequence` (uniform and axis-wise), `normalize_pose_sequence` (uniform and axis-wise),

View File

@ -30,6 +30,8 @@ from __future__ import annotations
from neuropose.analyzer.dtw import ( from neuropose.analyzer.dtw import (
AlignMode, AlignMode,
DTWResult, DTWResult,
NanPolicy,
Representation,
dtw_all, dtw_all,
dtw_per_joint, dtw_per_joint,
dtw_relation, dtw_relation,
@ -71,7 +73,9 @@ __all__ = [
"AxisLetter", "AxisLetter",
"DTWResult", "DTWResult",
"FeatureStatistics", "FeatureStatistics",
"NanPolicy",
"ProcrustesMode", "ProcrustesMode",
"Representation",
"dtw_all", "dtw_all",
"dtw_per_joint", "dtw_per_joint",
"dtw_relation", "dtw_relation",

View File

@ -2,10 +2,12 @@
Three entry points, ordered by increasing precision (and increasing cost): Three entry points, ordered by increasing precision (and increasing cost):
- :func:`dtw_all` DTW on the flattened per-frame joint vector. Fast but - :func:`dtw_all` DTW on the flattened per-frame feature vector. Fast
coarse; collapses every joint axis into a single per-frame vector. but coarse; collapses every joint axis (or every angle triplet) into
- :func:`dtw_per_joint` DTW on each joint independently. Preserves a single per-frame vector.
per-joint temporal alignment at the cost of one DTW call per joint. - :func:`dtw_per_joint` DTW on each joint (or angle triplet)
independently. Preserves per-unit temporal alignment at the cost of
one DTW call per unit.
- :func:`dtw_relation` DTW on the displacement vector between two - :func:`dtw_relation` DTW on the displacement vector between two
specific joints. This is the right tool when the research question is specific joints. This is the right tool when the research question is
about the *relative* motion of a specific pair of joints (e.g. the about the *relative* motion of a specific pair of joints (e.g. the
@ -16,11 +18,23 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)``
numpy arrays the shape :func:`~neuropose.analyzer.features.predictions_to_numpy` numpy arrays the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
produces. produces.
All three also accept an ``align`` argument that routes the inputs Three orthogonal preprocessing knobs are available on the entry points:
through :func:`~neuropose.analyzer.features.procrustes_align` before
DTW runs, yielding translation- and rotation-invariant distances. - **``align``** routes the inputs through
``align="none"`` (the default) preserves the raw-coordinate behaviour :func:`~neuropose.analyzer.features.procrustes_align` before DTW runs,
shipped in 0.1. yielding translation- and rotation-invariant distances.
``align="none"`` (the default) preserves the raw-coordinate behaviour
shipped in 0.1.
- **``representation``** (on :func:`dtw_all` and :func:`dtw_per_joint`)
selects what each frame is reduced to before DTW. ``"coords"`` uses
the raw joint coordinates; ``"angles"`` replaces them with joint
angles computed at caller-supplied triplets via
:func:`~neuropose.analyzer.features.extract_joint_angles`, giving
DTW distances that are directly interpretable as clinical joint-range
comparisons.
- **``nan_policy``** decides how the DTW path handles non-finite values
in its input typically a concern only for the angle representation,
where degenerate (zero-length) vectors produce NaN. See :data:`NanPolicy`.
Dependency note Dependency note
--------------- ---------------
@ -34,13 +48,13 @@ called.
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Literal
import numpy as np import numpy as np
from neuropose.analyzer.features import procrustes_align from neuropose.analyzer.features import extract_joint_angles, procrustes_align
AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"] AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
"""Alignment selector for DTW entry points. """Alignment selector for DTW entry points.
@ -51,6 +65,41 @@ AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
alignment before DTW. alignment before DTW.
""" """
Representation = Literal["coords", "angles"]
"""Per-frame feature representation for :func:`dtw_all` and :func:`dtw_per_joint`.
- ``"coords"`` use the raw joint coordinates (the input's last two
axes). Preserves the 0.1 behaviour.
- ``"angles"`` replace joints with joint angles at caller-supplied
triplets. Translation- and rotation-invariant by construction,
scale-invariant modulo the upstream normalization, and directly
interpretable in clinical terms ("knee flexion during swing phase").
The ``angle_triplets`` keyword becomes mandatory in this mode.
"""
NanPolicy = Literal["propagate", "interpolate", "drop"]
"""Per-feature NaN handling for the DTW input.
NaN typically appears when ``representation="angles"`` encounters a
degenerate (zero-length) vector the angle is undefined and
:func:`extract_joint_angles` propagates NaN rather than quietly returning
a stand-in value.
- ``"propagate"`` (default) pass NaN straight through to the DTW
engine. fastdtw validates its input via
:func:`numpy.asarray_chkfinite` and raises :class:`ValueError`
the moment a NaN appears, which is the safest default because it
makes the problem visible instead of quietly corrupting a
distance.
- ``"interpolate"`` linearly interpolate NaN frames along each
feature column using neighbouring finite values. Reasonable when a
small number of frames are corrupted and the surrounding motion is
smooth; inappropriate when long stretches are missing.
- ``"drop"`` remove any frame where *any* feature is NaN before DTW
runs. Simple, but compresses the time axis, so warping-path indices
refer to the *compacted* sequence rather than the original.
"""
@dataclass(frozen=True) @dataclass(frozen=True)
class DTWResult: class DTWResult:
@ -100,13 +149,18 @@ def dtw_all(
b: np.ndarray, b: np.ndarray,
*, *,
align: AlignMode = "none", align: AlignMode = "none",
representation: Representation = "coords",
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
nan_policy: NanPolicy = "propagate",
) -> DTWResult: ) -> DTWResult:
"""DTW on the flattened per-frame joint vector. """DTW on the flattened per-frame feature vector.
Each frame's joints are collapsed into a single vector before DTW Under the default ``representation="coords"`` each frame's joints
is applied. This is fast one DTW call regardless of the joint are collapsed into a single vector before DTW is applied fast
count but loses per-joint temporal structure, so a small (one DTW call regardless of joint count) but coarse, since a small
timing mismatch on one joint can dominate the distance metric. timing mismatch on one joint can dominate the distance metric.
Switching to ``representation="angles"`` computes joint angles at
the supplied triplets first and flattens those instead.
Parameters Parameters
---------- ----------
@ -119,6 +173,15 @@ def dtw_all(
align align
Procrustes alignment mode applied before DTW. See Procrustes alignment mode applied before DTW. See
:data:`AlignMode`. :data:`AlignMode`.
representation
Per-frame feature representation. See :data:`Representation`.
angle_triplets
Required when ``representation="angles"``. Sequence of
``(a, b, c)`` joint-index triplets passed through to
:func:`~neuropose.analyzer.features.extract_joint_angles`.
Ignored otherwise.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns Returns
------- -------
@ -129,15 +192,20 @@ def dtw_all(
Raises Raises
------ ------
ValueError ValueError
If ``a`` and ``b`` do not have the same joint count, or if If ``a`` and ``b`` do not have the same joint count, if
``align`` requires a matching frame count that is not present. ``align`` requires a matching frame count that is not present,
if ``representation="angles"`` is requested without
``angle_triplets``, or if ``nan_policy="interpolate"``
encounters an all-NaN column.
""" """
_validate_same_joint_count(a, b) _validate_same_joint_count(a, b)
a, b = _maybe_align(a, b, align=align) a, b = _maybe_align(a, b, align=align)
feat_a = _apply_representation(a, representation, angle_triplets=angle_triplets)
feat_b = _apply_representation(b, representation, angle_triplets=angle_triplets)
feat_a = _apply_nan_policy(feat_a, nan_policy)
feat_b = _apply_nan_policy(feat_b, nan_policy)
fastdtw, euclidean = _require_fastdtw() fastdtw, euclidean = _require_fastdtw()
a_flat = a.reshape(a.shape[0], -1) distance, path = fastdtw(feat_a, feat_b, dist=euclidean)
b_flat = b.reshape(b.shape[0], -1)
distance, path = fastdtw(a_flat, b_flat, dist=euclidean)
return DTWResult(distance=float(distance), path=[tuple(p) for p in path]) return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
@ -146,13 +214,21 @@ def dtw_per_joint(
b: np.ndarray, b: np.ndarray,
*, *,
align: AlignMode = "none", align: AlignMode = "none",
representation: Representation = "coords",
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
nan_policy: NanPolicy = "propagate",
) -> list[DTWResult]: ) -> list[DTWResult]:
"""DTW on each joint independently. """DTW on each joint (or angle triplet) independently.
Performs one DTW computation per joint, yielding a list of Performs one DTW computation per unit, yielding a list of
:class:`DTWResult` objects in joint-index order. More precise than :class:`DTWResult` objects in input order. More precise than
:func:`dtw_all` because each joint's temporal alignment is optimised :func:`dtw_all` because each unit's temporal alignment is optimised
separately, at the cost of J times more DTW calls for J joints. separately, at the cost of J times more DTW calls for J units.
Under the default ``representation="coords"`` a "unit" is one of
the input's joints (xyz treated jointly). Under
``representation="angles"`` a "unit" is one scalar angle column
computed from one ``angle_triplets`` entry.
Parameters Parameters
---------- ----------
@ -164,26 +240,54 @@ def dtw_per_joint(
align align
Procrustes alignment mode applied before DTW. See Procrustes alignment mode applied before DTW. See
:data:`AlignMode`. :data:`AlignMode`.
representation
Per-frame feature representation. See :data:`Representation`.
angle_triplets
Required when ``representation="angles"``; see
:func:`dtw_all` for details.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns Returns
------- -------
list[DTWResult] list[DTWResult]
One DTW result per joint, in index order. One DTW result per joint or per angle triplet, in input order.
Raises Raises
------ ------
ValueError ValueError
If ``a`` and ``b`` do not have the same joint count, or if Same conditions as :func:`dtw_all`.
``align`` requires a matching frame count that is not present.
""" """
_validate_same_joint_count(a, b) _validate_same_joint_count(a, b)
a, b = _maybe_align(a, b, align=align) a, b = _maybe_align(a, b, align=align)
if representation == "coords":
feat_a = a
feat_b = b
# (frames, joints, 3) — one DTW per joint over its (frames, 3) slice.
num_units = feat_a.shape[1]
slicers: list[Callable[[np.ndarray], np.ndarray]] = [
(lambda arr, idx=i: arr[:, idx, :]) for i in range(num_units)
]
else: # "angles"
if angle_triplets is None:
raise ValueError("representation='angles' requires angle_triplets")
feat_a = extract_joint_angles(a, angle_triplets) # (frames, num_triplets)
feat_b = extract_joint_angles(b, angle_triplets)
num_units = feat_a.shape[1]
slicers = [
# Scalar columns become 2D for DTW (fastdtw expects a
# sequence of vectors, not a sequence of scalars).
(lambda arr, idx=i: arr[:, idx : idx + 1])
for i in range(num_units)
]
fastdtw, euclidean = _require_fastdtw() fastdtw, euclidean = _require_fastdtw()
results: list[DTWResult] = [] results: list[DTWResult] = []
for joint_idx in range(a.shape[1]): for slicer in slicers:
a_joint = a[:, joint_idx, :] unit_a = _apply_nan_policy(slicer(feat_a), nan_policy)
b_joint = b[:, joint_idx, :] unit_b = _apply_nan_policy(slicer(feat_b), nan_policy)
distance, path = fastdtw(a_joint, b_joint, dist=euclidean) distance, path = fastdtw(unit_a, unit_b, dist=euclidean)
results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path])) results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path]))
return results return results
@ -195,6 +299,7 @@ def dtw_relation(
joint_j: int, joint_j: int,
*, *,
align: AlignMode = "none", align: AlignMode = "none",
nan_policy: NanPolicy = "propagate",
) -> DTWResult: ) -> DTWResult:
"""DTW on the displacement vector between two specific joints. """DTW on the displacement vector between two specific joints.
@ -219,6 +324,8 @@ def dtw_relation(
:data:`AlignMode`. Note that displacement vectors are already :data:`AlignMode`. Note that displacement vectors are already
translation-invariant; alignment is still useful for cancelling translation-invariant; alignment is still useful for cancelling
camera rotation between trials. camera rotation between trials.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns Returns
------- -------
@ -239,9 +346,9 @@ def dtw_relation(
f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}" f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}"
) )
a, b = _maybe_align(a, b, align=align) a, b = _maybe_align(a, b, align=align)
disp_a = _apply_nan_policy(a[:, joint_j, :] - a[:, joint_i, :], nan_policy)
disp_b = _apply_nan_policy(b[:, joint_j, :] - b[:, joint_i, :], nan_policy)
fastdtw, euclidean = _require_fastdtw() fastdtw, euclidean = _require_fastdtw()
disp_a = a[:, joint_j, :] - a[:, joint_i, :]
disp_b = b[:, joint_j, :] - b[:, joint_i, :]
distance, path = fastdtw(disp_a, disp_b, dist=euclidean) distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
return DTWResult(distance=float(distance), path=[tuple(p) for p in path]) return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
@ -283,3 +390,69 @@ def _maybe_align(
mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence" mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence"
aligned_a, _target, _diag = procrustes_align(a, b, mode=mode) aligned_a, _target, _diag = procrustes_align(a, b, mode=mode)
return aligned_a, b return aligned_a, b
def _apply_representation(
sequence: np.ndarray,
representation: Representation,
*,
angle_triplets: Sequence[tuple[int, int, int]] | None,
) -> np.ndarray:
"""Reduce a ``(frames, joints, 3)`` sequence to DTW-ready 2D features.
``"coords"`` reshapes to ``(frames, joints * 3)``; ``"angles"``
runs :func:`extract_joint_angles` to produce
``(frames, len(angle_triplets))``.
"""
if representation == "coords":
return sequence.reshape(sequence.shape[0], -1)
if representation == "angles":
if angle_triplets is None:
raise ValueError("representation='angles' requires angle_triplets")
return extract_joint_angles(sequence, angle_triplets)
raise ValueError(f"unknown representation {representation!r}")
def _apply_nan_policy(features: np.ndarray, policy: NanPolicy) -> np.ndarray:
"""Handle NaN values in a ``(frames, features)`` array per ``policy``.
``"propagate"`` is a no-op. ``"interpolate"`` runs 1D linear
interpolation along the frame axis within each feature column,
leaving finite data untouched. ``"drop"`` removes any frame where
*any* feature is NaN.
Raises
------
ValueError
If ``"interpolate"`` encounters a column that is entirely NaN
(no finite anchors to interpolate between), or if ``"drop"``
leaves an empty sequence.
"""
if policy == "propagate":
return features
if features.ndim == 1:
features = features.reshape(-1, 1)
if policy == "drop":
keep = np.isfinite(features).all(axis=1)
dropped = features[keep]
if dropped.shape[0] == 0:
raise ValueError(
"nan_policy='drop' removed every frame; DTW needs a non-empty sequence"
)
return dropped
if policy == "interpolate":
out = features.astype(float, copy=True)
num_frames = out.shape[0]
indices = np.arange(num_frames, dtype=float)
for col in range(out.shape[1]):
column = out[:, col]
finite = np.isfinite(column)
if finite.all():
continue
if not finite.any():
raise ValueError(
f"nan_policy='interpolate' cannot fill column {col}: all values are NaN"
)
out[:, col] = np.interp(indices, indices[finite], column[finite])
return out
raise ValueError(f"unknown nan_policy {policy!r}")

View File

@ -131,3 +131,250 @@ class TestDtwRelation:
b = np.zeros((3, 2, 3)) b = np.zeros((3, 2, 3))
with pytest.raises(ValueError, match="joint count"): with pytest.raises(ValueError, match="joint count"):
dtw_relation(a, b, joint_i=0, joint_j=1) dtw_relation(a, b, joint_i=0, joint_j=1)
# ---------------------------------------------------------------------------
# representation="angles"
# ---------------------------------------------------------------------------
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
c, s = np.cos(angle_rad), np.sin(angle_rad)
return np.array(
[
[c, -s, 0.0],
[s, c, 0.0],
[0.0, 0.0, 1.0],
]
)
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
"""A three-joint arm opening from a right angle to straight.
Joints laid out as [shoulder, elbow, wrist], forming an angle at
the elbow that linearly opens from pi/2 to pi across ``num_frames``.
"""
sequence = np.zeros((num_frames, 3, 3))
angles = np.linspace(np.pi / 2, np.pi, num_frames)
for i, theta in enumerate(angles):
sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder
sequence[i, 1] = [0.0, 0.0, 0.0] # elbow
sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist
return sequence
class TestDtwAllAngles:
def test_angles_identical_sequences_distance_zero(self) -> None:
seq = _three_joint_arm()
result = dtw_all(
seq,
seq,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_angles_invariant_to_global_rotation(self) -> None:
"""Angle-space DTW must not change under a global rotation."""
seq = _three_joint_arm()
rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T
baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)])
under_rotation = dtw_all(
seq,
rotated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6)
def test_angles_translation_invariant(self) -> None:
seq = _three_joint_arm()
translated = seq + np.array([10.0, -5.0, 2.0])
result = dtw_all(
seq,
translated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_angles_detects_different_motion(self) -> None:
# A sequence whose angle is constant vs. one that opens.
constant = np.zeros((6, 3, 3))
constant[:, 0] = [-1.0, 0.0, 0.0]
constant[:, 1] = [0.0, 0.0, 0.0]
constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout
opening = _three_joint_arm()
result = dtw_all(
constant,
opening,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance > 0.0
def test_angles_without_triplets_rejected(self) -> None:
seq = _three_joint_arm()
with pytest.raises(ValueError, match="angle_triplets"):
dtw_all(seq, seq, representation="angles")
class TestDtwPerJointAngles:
def test_returns_one_result_per_triplet(self) -> None:
seq = _three_joint_arm()
triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose
results = dtw_per_joint(
seq,
seq,
representation="angles",
angle_triplets=triplets,
)
assert len(results) == 2
for result in results:
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_per_triplet_distinct_paths(self) -> None:
# Two triplets covering different angles; with different motion
# per triplet, the per-unit results should differ.
seq_a = np.zeros((5, 4, 3))
seq_b = np.zeros((5, 4, 3))
# joint 0: pivot, joint 1/2/3: arm endpoints
for i in range(5):
seq_a[i, 0] = [0.0, 0.0, 0.0]
seq_a[i, 1] = [1.0, 0.0, 0.0]
seq_a[i, 2] = [0.0, 1.0, 0.0]
seq_a[i, 3] = [0.0, 0.0, 1.0]
seq_b[i, 0] = [0.0, 0.0, 0.0]
seq_b[i, 1] = [1.0, 0.0, 0.0]
seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating
seq_b[i, 3] = [0.0, 0.0, 1.0]
results = dtw_per_joint(
seq_a,
seq_b,
representation="angles",
angle_triplets=[(1, 0, 2), (1, 0, 3)],
)
assert len(results) == 2
# First triplet tracks the rotation, second is stationary.
assert results[0].distance > 0.0
assert results[1].distance == pytest.approx(0.0, abs=1e-9)
# ---------------------------------------------------------------------------
# nan_policy
# ---------------------------------------------------------------------------
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
"""Three collinear joints — the angle at the middle joint is degenerate."""
seq = np.zeros((num_frames, 3, 3))
seq[:, 0] = [-1.0, 0.0, 0.0]
# Middle joint at (0,0,0); but because the outer joints are collinear
# through the origin, we need one joint overlapping with the middle
# to force a zero-length vector. Place joint 2 AT joint 1 to trigger
# the degenerate case in extract_joint_angles.
seq[:, 1] = [0.0, 0.0, 0.0]
seq[:, 2] = [0.0, 0.0, 0.0]
return seq
class TestNanPolicy:
def test_propagate_surfaces_error(self) -> None:
# Degenerate triplet produces NaN angles for every frame.
# With nan_policy="propagate" the NaN reaches fastdtw, which
# validates via numpy.asarray_chkfinite and raises ValueError —
# the intended behaviour ("make the problem visible").
seq = _collinear_sequence(num_frames=4)
other = _three_joint_arm(num_frames=4)
with pytest.raises(ValueError, match="infs or NaNs"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="propagate",
)
def test_interpolate_fills_isolated_nan(self) -> None:
# One bad frame in a 5-frame sequence — the other four are
# finite anchors to interpolate between.
good = _three_joint_arm(num_frames=5)
# Inject a degenerate middle frame.
good[2, 2] = good[2, 1] # force zero-length vector → NaN angle
# Reference is the same arm without injection.
reference = _three_joint_arm(num_frames=5)
result = dtw_all(
good,
reference,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="interpolate",
)
assert not np.isnan(result.distance)
def test_interpolate_all_nan_column_rejected(self) -> None:
seq = _collinear_sequence(num_frames=5)
other = _three_joint_arm(num_frames=5)
with pytest.raises(ValueError, match="all values are NaN"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="interpolate",
)
def test_drop_removes_nan_frames(self) -> None:
good = _three_joint_arm(num_frames=6)
good[2, 2] = good[2, 1] # inject NaN at frame 2
good[4, 2] = good[4, 1] # inject NaN at frame 4
reference = _three_joint_arm(num_frames=6)
result = dtw_all(
good,
reference,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="drop",
)
# The 4 remaining finite frames should align cleanly with
# their counterparts in the reference.
assert not np.isnan(result.distance)
def test_drop_empties_sequence_rejected(self) -> None:
seq = _collinear_sequence(num_frames=5)
other = _three_joint_arm(num_frames=5)
with pytest.raises(ValueError, match="every frame"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="drop",
)
# ---------------------------------------------------------------------------
# align + representation composition
# ---------------------------------------------------------------------------
class TestAlignWithAngles:
def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
"""Procrustes on angle-space DTW should be redundant but safe."""
seq = _three_joint_arm()
rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T
with_align = dtw_all(
seq,
rotated,
align="procrustes_per_sequence",
representation="angles",
angle_triplets=[(0, 1, 2)],
)
without_align = dtw_all(
seq,
rotated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)