381 lines
14 KiB
Python
381 lines
14 KiB
Python
"""Tests for :mod:`neuropose.analyzer.dtw`."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from neuropose.analyzer.dtw import (
|
|
DTWResult,
|
|
dtw_all,
|
|
dtw_per_joint,
|
|
dtw_relation,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
|
|
def simple_sequence() -> np.ndarray:
|
|
"""A 5-frame, 3-joint sequence of linearly-moving joints."""
|
|
rng = np.random.default_rng(seed=42)
|
|
return rng.standard_normal((5, 3, 3))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# dtw_all
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDtwAll:
|
|
def test_identical_sequences_distance_zero(self, simple_sequence: np.ndarray) -> None:
|
|
result = dtw_all(simple_sequence, simple_sequence)
|
|
assert isinstance(result, DTWResult)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
# Identical sequences produce a diagonal warping path.
|
|
assert all(i == j for i, j in result.path)
|
|
|
|
def test_shifted_sequences_distance_zero(self, simple_sequence: np.ndarray) -> None:
|
|
"""DTW should absorb a pure time shift without penalty."""
|
|
# Duplicate the first frame to create a one-frame shift.
|
|
shifted = np.concatenate([simple_sequence[:1], simple_sequence], axis=0)
|
|
result = dtw_all(simple_sequence, shifted)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_different_sequences_positive_distance(self) -> None:
|
|
a = np.zeros((5, 3, 3))
|
|
b = np.ones((5, 3, 3))
|
|
result = dtw_all(a, b)
|
|
assert result.distance > 0.0
|
|
|
|
def test_mismatched_joint_count_rejected(self) -> None:
|
|
a = np.zeros((5, 3, 3))
|
|
b = np.zeros((5, 4, 3))
|
|
with pytest.raises(ValueError, match="joint count"):
|
|
dtw_all(a, b)
|
|
|
|
def test_non_3d_input_rejected(self) -> None:
|
|
a = np.zeros((5, 3)) # missing trailing axis
|
|
b = np.zeros((5, 3))
|
|
with pytest.raises(ValueError, match="expected 3D"):
|
|
dtw_all(a, b)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# dtw_per_joint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDtwPerJoint:
|
|
def test_returns_one_result_per_joint(self, simple_sequence: np.ndarray) -> None:
|
|
results = dtw_per_joint(simple_sequence, simple_sequence)
|
|
assert len(results) == simple_sequence.shape[1]
|
|
for result in results:
|
|
assert isinstance(result, DTWResult)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_independent_joint_distances(self) -> None:
|
|
# Construct two sequences where joint 0 matches exactly but
|
|
# joint 1 is offset by a constant. Per-joint DTW should give
|
|
# distance 0 for joint 0 and distance > 0 for joint 1.
|
|
a = np.array(
|
|
[
|
|
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
|
[[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
|
|
]
|
|
)
|
|
b = a.copy()
|
|
b[:, 1, :] += 10.0
|
|
results = dtw_per_joint(a, b)
|
|
assert results[0].distance == pytest.approx(0.0, abs=1e-9)
|
|
assert results[1].distance > 0.0
|
|
|
|
def test_mismatched_joint_count_rejected(self) -> None:
|
|
a = np.zeros((5, 3, 3))
|
|
b = np.zeros((5, 2, 3))
|
|
with pytest.raises(ValueError, match="joint count"):
|
|
dtw_per_joint(a, b)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# dtw_relation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDtwRelation:
|
|
def test_identical_sequences_distance_zero(self, simple_sequence: np.ndarray) -> None:
|
|
result = dtw_relation(simple_sequence, simple_sequence, joint_i=0, joint_j=1)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_same_relative_position_is_zero_even_under_translation(self) -> None:
|
|
"""Translating the whole body does not change the
|
|
joint-to-joint displacement, so dtw_relation should be 0."""
|
|
a = np.zeros((4, 3, 3))
|
|
a[:, 0, :] = [0.0, 0.0, 0.0]
|
|
a[:, 1, :] = [1.0, 0.0, 0.0]
|
|
a[:, 2, :] = [0.0, 1.0, 0.0]
|
|
b = a + 50.0 # translate the whole body
|
|
result = dtw_relation(a, b, joint_i=0, joint_j=1)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_joint_index_out_of_range_rejected(self) -> None:
|
|
a = np.zeros((3, 2, 3))
|
|
b = np.zeros((3, 2, 3))
|
|
with pytest.raises(ValueError, match="joint indices"):
|
|
dtw_relation(a, b, joint_i=0, joint_j=5)
|
|
|
|
def test_mismatched_joint_count_rejected(self) -> None:
|
|
a = np.zeros((3, 3, 3))
|
|
b = np.zeros((3, 2, 3))
|
|
with pytest.raises(ValueError, match="joint count"):
|
|
dtw_relation(a, b, joint_i=0, joint_j=1)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# representation="angles"
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
|
|
c, s = np.cos(angle_rad), np.sin(angle_rad)
|
|
return np.array(
|
|
[
|
|
[c, -s, 0.0],
|
|
[s, c, 0.0],
|
|
[0.0, 0.0, 1.0],
|
|
]
|
|
)
|
|
|
|
|
|
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
|
|
"""A three-joint arm opening from a right angle to straight.
|
|
|
|
Joints laid out as [shoulder, elbow, wrist], forming an angle at
|
|
the elbow that linearly opens from pi/2 to pi across ``num_frames``.
|
|
"""
|
|
sequence = np.zeros((num_frames, 3, 3))
|
|
angles = np.linspace(np.pi / 2, np.pi, num_frames)
|
|
for i, theta in enumerate(angles):
|
|
sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder
|
|
sequence[i, 1] = [0.0, 0.0, 0.0] # elbow
|
|
sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist
|
|
return sequence
|
|
|
|
|
|
class TestDtwAllAngles:
|
|
def test_angles_identical_sequences_distance_zero(self) -> None:
|
|
seq = _three_joint_arm()
|
|
result = dtw_all(
|
|
seq,
|
|
seq,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_angles_invariant_to_global_rotation(self) -> None:
|
|
"""Angle-space DTW must not change under a global rotation."""
|
|
seq = _three_joint_arm()
|
|
rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T
|
|
baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)])
|
|
under_rotation = dtw_all(
|
|
seq,
|
|
rotated,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
)
|
|
assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6)
|
|
|
|
def test_angles_translation_invariant(self) -> None:
|
|
seq = _three_joint_arm()
|
|
translated = seq + np.array([10.0, -5.0, 2.0])
|
|
result = dtw_all(
|
|
seq,
|
|
translated,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
)
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_angles_detects_different_motion(self) -> None:
|
|
# A sequence whose angle is constant vs. one that opens.
|
|
constant = np.zeros((6, 3, 3))
|
|
constant[:, 0] = [-1.0, 0.0, 0.0]
|
|
constant[:, 1] = [0.0, 0.0, 0.0]
|
|
constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout
|
|
opening = _three_joint_arm()
|
|
result = dtw_all(
|
|
constant,
|
|
opening,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
)
|
|
assert result.distance > 0.0
|
|
|
|
def test_angles_without_triplets_rejected(self) -> None:
|
|
seq = _three_joint_arm()
|
|
with pytest.raises(ValueError, match="angle_triplets"):
|
|
dtw_all(seq, seq, representation="angles")
|
|
|
|
|
|
class TestDtwPerJointAngles:
|
|
def test_returns_one_result_per_triplet(self) -> None:
|
|
seq = _three_joint_arm()
|
|
triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose
|
|
results = dtw_per_joint(
|
|
seq,
|
|
seq,
|
|
representation="angles",
|
|
angle_triplets=triplets,
|
|
)
|
|
assert len(results) == 2
|
|
for result in results:
|
|
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
def test_per_triplet_distinct_paths(self) -> None:
|
|
# Two triplets covering different angles; with different motion
|
|
# per triplet, the per-unit results should differ.
|
|
seq_a = np.zeros((5, 4, 3))
|
|
seq_b = np.zeros((5, 4, 3))
|
|
# joint 0: pivot, joint 1/2/3: arm endpoints
|
|
for i in range(5):
|
|
seq_a[i, 0] = [0.0, 0.0, 0.0]
|
|
seq_a[i, 1] = [1.0, 0.0, 0.0]
|
|
seq_a[i, 2] = [0.0, 1.0, 0.0]
|
|
seq_a[i, 3] = [0.0, 0.0, 1.0]
|
|
seq_b[i, 0] = [0.0, 0.0, 0.0]
|
|
seq_b[i, 1] = [1.0, 0.0, 0.0]
|
|
seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating
|
|
seq_b[i, 3] = [0.0, 0.0, 1.0]
|
|
results = dtw_per_joint(
|
|
seq_a,
|
|
seq_b,
|
|
representation="angles",
|
|
angle_triplets=[(1, 0, 2), (1, 0, 3)],
|
|
)
|
|
assert len(results) == 2
|
|
# First triplet tracks the rotation, second is stationary.
|
|
assert results[0].distance > 0.0
|
|
assert results[1].distance == pytest.approx(0.0, abs=1e-9)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# nan_policy
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
|
|
"""Three collinear joints — the angle at the middle joint is degenerate."""
|
|
seq = np.zeros((num_frames, 3, 3))
|
|
seq[:, 0] = [-1.0, 0.0, 0.0]
|
|
# Middle joint at (0,0,0); but because the outer joints are collinear
|
|
# through the origin, we need one joint overlapping with the middle
|
|
# to force a zero-length vector. Place joint 2 AT joint 1 to trigger
|
|
# the degenerate case in extract_joint_angles.
|
|
seq[:, 1] = [0.0, 0.0, 0.0]
|
|
seq[:, 2] = [0.0, 0.0, 0.0]
|
|
return seq
|
|
|
|
|
|
class TestNanPolicy:
|
|
def test_propagate_surfaces_error(self) -> None:
|
|
# Degenerate triplet produces NaN angles for every frame.
|
|
# With nan_policy="propagate" the NaN reaches fastdtw, which
|
|
# validates via numpy.asarray_chkfinite and raises ValueError —
|
|
# the intended behaviour ("make the problem visible").
|
|
seq = _collinear_sequence(num_frames=4)
|
|
other = _three_joint_arm(num_frames=4)
|
|
with pytest.raises(ValueError, match="infs or NaNs"):
|
|
dtw_all(
|
|
seq,
|
|
other,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
nan_policy="propagate",
|
|
)
|
|
|
|
def test_interpolate_fills_isolated_nan(self) -> None:
|
|
# One bad frame in a 5-frame sequence — the other four are
|
|
# finite anchors to interpolate between.
|
|
good = _three_joint_arm(num_frames=5)
|
|
# Inject a degenerate middle frame.
|
|
good[2, 2] = good[2, 1] # force zero-length vector → NaN angle
|
|
# Reference is the same arm without injection.
|
|
reference = _three_joint_arm(num_frames=5)
|
|
result = dtw_all(
|
|
good,
|
|
reference,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
nan_policy="interpolate",
|
|
)
|
|
assert not np.isnan(result.distance)
|
|
|
|
def test_interpolate_all_nan_column_rejected(self) -> None:
|
|
seq = _collinear_sequence(num_frames=5)
|
|
other = _three_joint_arm(num_frames=5)
|
|
with pytest.raises(ValueError, match="all values are NaN"):
|
|
dtw_all(
|
|
seq,
|
|
other,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
nan_policy="interpolate",
|
|
)
|
|
|
|
def test_drop_removes_nan_frames(self) -> None:
|
|
good = _three_joint_arm(num_frames=6)
|
|
good[2, 2] = good[2, 1] # inject NaN at frame 2
|
|
good[4, 2] = good[4, 1] # inject NaN at frame 4
|
|
reference = _three_joint_arm(num_frames=6)
|
|
result = dtw_all(
|
|
good,
|
|
reference,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
nan_policy="drop",
|
|
)
|
|
# The 4 remaining finite frames should align cleanly with
|
|
# their counterparts in the reference.
|
|
assert not np.isnan(result.distance)
|
|
|
|
def test_drop_empties_sequence_rejected(self) -> None:
|
|
seq = _collinear_sequence(num_frames=5)
|
|
other = _three_joint_arm(num_frames=5)
|
|
with pytest.raises(ValueError, match="every frame"):
|
|
dtw_all(
|
|
seq,
|
|
other,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
nan_policy="drop",
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# align + representation composition
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAlignWithAngles:
|
|
def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
|
|
"""Procrustes on angle-space DTW should be redundant but safe."""
|
|
seq = _three_joint_arm()
|
|
rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T
|
|
with_align = dtw_all(
|
|
seq,
|
|
rotated,
|
|
align="procrustes_per_sequence",
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
)
|
|
without_align = dtw_all(
|
|
seq,
|
|
rotated,
|
|
representation="angles",
|
|
angle_triplets=[(0, 1, 2)],
|
|
)
|
|
assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)
|