neuropose/tests/unit/test_analyzer_dtw.py

381 lines
14 KiB
Python

"""Tests for :mod:`neuropose.analyzer.dtw`."""
from __future__ import annotations
import numpy as np
import pytest
from neuropose.analyzer.dtw import (
DTWResult,
dtw_all,
dtw_per_joint,
dtw_relation,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def simple_sequence() -> np.ndarray:
    """A reproducible random pose sequence of 5 frames and 3 joints.

    Coordinates are drawn from a standard normal distribution with a
    fixed seed (42), so the values are identical on every run but carry
    no particular motion structure. Shape is ``(frames, joints, 3)`` =
    ``(5, 3, 3)``.
    """
    rng = np.random.default_rng(seed=42)
    return rng.standard_normal((5, 3, 3))
# ---------------------------------------------------------------------------
# dtw_all
# ---------------------------------------------------------------------------
class TestDtwAll:
    """Tests for :func:`dtw_all` on raw joint coordinates."""

    def test_identical_sequences_distance_zero(self, simple_sequence: np.ndarray) -> None:
        result = dtw_all(simple_sequence, simple_sequence)
        assert isinstance(result, DTWResult)
        assert result.distance == pytest.approx(0.0, abs=1e-9)
        # Identical sequences produce a diagonal warping path: one (k, k)
        # pair per frame. The length check matters because ``all()`` over
        # an empty path would pass vacuously.
        assert len(result.path) == simple_sequence.shape[0]
        assert all(i == j for i, j in result.path)

    def test_shifted_sequences_distance_zero(self, simple_sequence: np.ndarray) -> None:
        """DTW should absorb a pure time shift without penalty."""
        # Duplicate the first frame to create a one-frame shift.
        shifted = np.concatenate([simple_sequence[:1], simple_sequence], axis=0)
        result = dtw_all(simple_sequence, shifted)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_different_sequences_positive_distance(self) -> None:
        """All-zero poses vs. all-one poses cannot align at zero cost."""
        a = np.zeros((5, 3, 3))
        b = np.ones((5, 3, 3))
        result = dtw_all(a, b)
        assert result.distance > 0.0

    def test_mismatched_joint_count_rejected(self) -> None:
        a = np.zeros((5, 3, 3))
        b = np.zeros((5, 4, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_all(a, b)

    def test_non_3d_input_rejected(self) -> None:
        a = np.zeros((5, 3))  # missing trailing axis
        b = np.zeros((5, 3))
        with pytest.raises(ValueError, match="expected 3D"):
            dtw_all(a, b)
# ---------------------------------------------------------------------------
# dtw_per_joint
# ---------------------------------------------------------------------------
class TestDtwPerJoint:
    """Tests for :func:`dtw_per_joint`, which aligns every joint on its own."""

    def test_returns_one_result_per_joint(self, simple_sequence: np.ndarray) -> None:
        per_joint = dtw_per_joint(simple_sequence, simple_sequence)
        num_joints = simple_sequence.shape[1]
        assert len(per_joint) == num_joints
        for joint_result in per_joint:
            assert isinstance(joint_result, DTWResult)
            assert joint_result.distance == pytest.approx(0.0, abs=1e-9)

    def test_independent_joint_distances(self) -> None:
        """Joint 0 is identical in both sequences; joint 1 is shifted by 10.

        Because each joint is warped independently, joint 0 must report a
        zero distance while joint 1 reports a positive one.
        """
        first_frame = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
        second_frame = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
        reference = np.array([first_frame, second_frame])
        offset = np.array(reference, copy=True)
        offset[:, 1, :] = offset[:, 1, :] + 10.0
        per_joint = dtw_per_joint(reference, offset)
        unchanged_joint, shifted_joint = per_joint[0], per_joint[1]
        assert unchanged_joint.distance == pytest.approx(0.0, abs=1e-9)
        assert shifted_joint.distance > 0.0

    def test_mismatched_joint_count_rejected(self) -> None:
        three_joints = np.zeros((5, 3, 3))
        two_joints = np.zeros((5, 2, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_per_joint(three_joints, two_joints)
# ---------------------------------------------------------------------------
# dtw_relation
# ---------------------------------------------------------------------------
class TestDtwRelation:
    """Tests for :func:`dtw_relation`, which compares one joint pair's relation."""

    def test_identical_sequences_distance_zero(self, simple_sequence: np.ndarray) -> None:
        outcome = dtw_relation(simple_sequence, simple_sequence, joint_i=0, joint_j=1)
        assert outcome.distance == pytest.approx(0.0, abs=1e-9)

    def test_same_relative_position_is_zero_even_under_translation(self) -> None:
        """A rigid translation of the whole body leaves every joint-to-joint
        displacement unchanged, so dtw_relation must report zero."""
        static_pose = np.array(
            [
                [0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ]
        )
        original = np.tile(static_pose, (4, 1, 1))
        shifted = original + 50.0  # translate the whole body
        outcome = dtw_relation(original, shifted, joint_i=0, joint_j=1)
        assert outcome.distance == pytest.approx(0.0, abs=1e-9)

    def test_joint_index_out_of_range_rejected(self) -> None:
        left = np.zeros((3, 2, 3))
        right = np.zeros_like(left)
        with pytest.raises(ValueError, match="joint indices"):
            dtw_relation(left, right, joint_i=0, joint_j=5)

    def test_mismatched_joint_count_rejected(self) -> None:
        three_joints = np.zeros((3, 3, 3))
        two_joints = np.zeros((3, 2, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_relation(three_joints, two_joints, joint_i=0, joint_j=1)
# ---------------------------------------------------------------------------
# representation="angles"
# ---------------------------------------------------------------------------
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
    """Return the 3x3 matrix that rotates points by ``angle_rad`` about the z axis."""
    cos_a = np.cos(angle_rad)
    sin_a = np.sin(angle_rad)
    matrix = np.eye(3)
    matrix[0, 0], matrix[0, 1] = cos_a, -sin_a
    matrix[1, 0], matrix[1, 1] = sin_a, cos_a
    return matrix
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
    """Build a ``(num_frames, 3, 3)`` arm whose elbow opens from a right angle to straight.

    Joint order is [shoulder, elbow, wrist]. The shoulder stays at
    ``(-1, 0, 0)`` and the elbow at the origin. In each frame the wrist is
    placed on the unit circle in the xy-plane so that the
    shoulder-elbow-wrist angle equals that frame's entry of
    ``linspace(pi/2, pi, num_frames)``.
    """
    arm = np.zeros((num_frames, 3, 3))
    arm[:, 0] = [-1.0, 0.0, 0.0]  # shoulder
    arm[:, 1] = [0.0, 0.0, 0.0]  # elbow
    elbow_angles = np.linspace(np.pi / 2, np.pi, num_frames)
    for frame, theta in enumerate(elbow_angles):
        # Offsetting by -pi puts the wrist at angle theta from the shoulder ray.
        wrist_heading = theta - np.pi
        arm[frame, 2] = [np.cos(wrist_heading), np.sin(wrist_heading), 0.0]  # wrist
    return arm
class TestDtwAllAngles:
    """:func:`dtw_all` with ``representation="angles"`` on the single elbow triplet."""

    def test_angles_identical_sequences_distance_zero(self) -> None:
        arm = _three_joint_arm()
        outcome = dtw_all(arm, arm, representation="angles", angle_triplets=[(0, 1, 2)])
        assert outcome.distance == pytest.approx(0.0, abs=1e-9)

    def test_angles_invariant_to_global_rotation(self) -> None:
        """Angle-space DTW must not change under a global rotation."""
        arm = _three_joint_arm()
        rotation = _rotation_matrix_z(np.deg2rad(40.0))
        arm_rotated = arm @ rotation.T
        self_distance = dtw_all(
            arm, arm, representation="angles", angle_triplets=[(0, 1, 2)]
        ).distance
        rotated_distance = dtw_all(
            arm, arm_rotated, representation="angles", angle_triplets=[(0, 1, 2)]
        ).distance
        assert self_distance == pytest.approx(rotated_distance, abs=1e-6)

    def test_angles_translation_invariant(self) -> None:
        arm = _three_joint_arm()
        offset = np.array([10.0, -5.0, 2.0])
        outcome = dtw_all(
            arm, arm + offset, representation="angles", angle_triplets=[(0, 1, 2)]
        )
        assert outcome.distance == pytest.approx(0.0, abs=1e-9)

    def test_angles_detects_different_motion(self) -> None:
        """An elbow held at a right angle must differ from one that opens."""
        right_angle_pose = np.array(
            [
                [-1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ]
        )
        held = np.tile(right_angle_pose, (6, 1, 1))  # right angle throughout
        opening = _three_joint_arm()
        outcome = dtw_all(held, opening, representation="angles", angle_triplets=[(0, 1, 2)])
        assert outcome.distance > 0.0

    def test_angles_without_triplets_rejected(self) -> None:
        arm = _three_joint_arm()
        with pytest.raises(ValueError, match="angle_triplets"):
            dtw_all(arm, arm, representation="angles")
class TestDtwPerJointAngles:
    """:func:`dtw_per_joint` with ``representation="angles"``: one result per triplet."""

    def test_returns_one_result_per_triplet(self) -> None:
        arm = _three_joint_arm()
        # The same triplet appears twice on purpose: the number of results
        # must follow the triplet list, not the set of distinct triplets.
        per_triplet = dtw_per_joint(
            arm,
            arm,
            representation="angles",
            angle_triplets=[(0, 1, 2), (0, 1, 2)],
        )
        assert len(per_triplet) == 2
        for triplet_result in per_triplet:
            assert triplet_result.distance == pytest.approx(0.0, abs=1e-9)

    def test_per_triplet_distinct_paths(self) -> None:
        """Two triplets sharing pivot joint 0: one angle moves, the other does not.

        Joint 0 is the pivot and joints 1/2/3 are arm endpoints. In ``moving``
        joint 2 sweeps around the pivot in the xy-plane, so the (1, 0, 2)
        angle changes over time while the (1, 0, 3) angle stays at 90 degrees.
        """
        num_frames = 5
        still = np.zeros((num_frames, 4, 3))
        still[:, 1] = [1.0, 0.0, 0.0]
        still[:, 2] = [0.0, 1.0, 0.0]
        still[:, 3] = [0.0, 0.0, 1.0]
        moving = still.copy()
        for frame in range(num_frames):
            moving[frame, 2] = [np.cos(frame * 0.3), np.sin(frame * 0.3), 0.0]  # rotating
        per_triplet = dtw_per_joint(
            still,
            moving,
            representation="angles",
            angle_triplets=[(1, 0, 2), (1, 0, 3)],
        )
        assert len(per_triplet) == 2
        swept, fixed = per_triplet[0], per_triplet[1]
        assert swept.distance > 0.0
        assert fixed.distance == pytest.approx(0.0, abs=1e-9)
# ---------------------------------------------------------------------------
# nan_policy
# ---------------------------------------------------------------------------
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
    """Build a ``(num_frames, 3, 3)`` sequence whose (0, 1, 2) angle is undefined.

    Joint 0 sits at ``(-1, 0, 0)``; joints 1 and 2 both sit at the origin in
    every frame. Because joint 2 coincides with the middle joint 1, the
    vector from joint 1 to joint 2 has zero length, so no angle can be
    formed at joint 1.

    Despite the function's name, collinearity alone would *not* be
    degenerate (collinear joints give a well-defined angle of 0 or pi). The
    coincident pair is what makes this triplet degenerate. The nan_policy
    tests use it to produce NaN angles in every frame.
    """
    seq = np.zeros((num_frames, 3, 3))
    seq[:, 0] = [-1.0, 0.0, 0.0]
    # Middle joint at the origin (already zero; set explicitly for clarity).
    seq[:, 1] = [0.0, 0.0, 0.0]
    # Joint 2 placed exactly on joint 1 -> zero-length vector -> NaN angle.
    seq[:, 2] = [0.0, 0.0, 0.0]
    return seq
class TestNanPolicy:
    """How ``nan_policy`` handles NaN angles produced by degenerate triplets."""

    def test_propagate_surfaces_error(self) -> None:
        # Degenerate triplet produces NaN angles for every frame.
        # With nan_policy="propagate" the NaN reaches fastdtw, which
        # validates via numpy.asarray_chkfinite and raises ValueError —
        # the intended behaviour ("make the problem visible").
        seq = _collinear_sequence(num_frames=4)
        other = _three_joint_arm(num_frames=4)
        with pytest.raises(ValueError, match="infs or NaNs"):
            dtw_all(
                seq,
                other,
                representation="angles",
                angle_triplets=[(0, 1, 2)],
                nan_policy="propagate",
            )

    def test_interpolate_fills_isolated_nan(self) -> None:
        # One bad frame in a 5-frame sequence — the other four are
        # finite anchors to interpolate between.
        corrupted = _three_joint_arm(num_frames=5)
        # Collapse the wrist onto the elbow in frame 2: zero-length
        # vector -> NaN angle for that frame only.
        corrupted[2, 2] = corrupted[2, 1]
        # Reference is the same arm without the injected frame.
        reference = _three_joint_arm(num_frames=5)
        result = dtw_all(
            corrupted,
            reference,
            representation="angles",
            angle_triplets=[(0, 1, 2)],
            nan_policy="interpolate",
        )
        # isfinite (not merely "not NaN") so an infinite distance also fails.
        assert np.isfinite(result.distance)

    def test_interpolate_all_nan_column_rejected(self) -> None:
        seq = _collinear_sequence(num_frames=5)
        other = _three_joint_arm(num_frames=5)
        with pytest.raises(ValueError, match="all values are NaN"):
            dtw_all(
                seq,
                other,
                representation="angles",
                angle_triplets=[(0, 1, 2)],
                nan_policy="interpolate",
            )

    def test_drop_removes_nan_frames(self) -> None:
        corrupted = _three_joint_arm(num_frames=6)
        corrupted[2, 2] = corrupted[2, 1]  # inject NaN at frame 2
        corrupted[4, 2] = corrupted[4, 1]  # inject NaN at frame 4
        reference = _three_joint_arm(num_frames=6)
        result = dtw_all(
            corrupted,
            reference,
            representation="angles",
            angle_triplets=[(0, 1, 2)],
            nan_policy="drop",
        )
        # Four finite frames remain after dropping. Only finiteness is
        # asserted: whether the distance is exactly zero depends on whether
        # the dropped frames are removed from one sequence or from both,
        # which this test does not pin down.
        assert np.isfinite(result.distance)

    def test_drop_empties_sequence_rejected(self) -> None:
        seq = _collinear_sequence(num_frames=5)
        other = _three_joint_arm(num_frames=5)
        with pytest.raises(ValueError, match="every frame"):
            dtw_all(
                seq,
                other,
                representation="angles",
                angle_triplets=[(0, 1, 2)],
                nan_policy="drop",
            )
# ---------------------------------------------------------------------------
# align + representation composition
# ---------------------------------------------------------------------------
class TestAlignWithAngles:
    """Composition of the ``align`` option with the angle representation."""

    def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
        """Procrustes alignment is redundant for angle-space DTW but must be
        harmless: the distance is the same with or without it."""
        arm = _three_joint_arm()
        rotation = _rotation_matrix_z(np.deg2rad(20.0))
        arm_rotated = arm @ rotation.T

        def angle_dtw(**extra):
            return dtw_all(
                arm,
                arm_rotated,
                representation="angles",
                angle_triplets=[(0, 1, 2)],
                **extra,
            )

        aligned = angle_dtw(align="procrustes_per_sequence")
        unaligned = angle_dtw()
        assert aligned.distance == pytest.approx(unaligned.distance, abs=1e-6)