# neuropose/tests/unit/test_analyzer_dtw.py
#
# 143 lines
# 5.0 KiB
# Python

"""Tests for :mod:`neuropose.analyzer.dtw`."""
from __future__ import annotations
import numpy as np
import pytest
from neuropose.analyzer.dtw import (
DTWResult,
dtw_all,
dtw_per_joint,
dtw_relation,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def simple_sequence() -> np.ndarray:
    """A seeded random sequence of 5 frames, 3 joints and 3 coordinates.

    Values are independent draws from a standard normal distribution using a
    fixed seed, so the fixture is reproducible across runs. The joints do not
    follow any particular trajectory (they are not linearly moving); the
    fixture only provides a deterministic, non-trivial ``(5, 3, 3)`` array.
    """
    rng = np.random.default_rng(seed=42)
    return rng.standard_normal((5, 3, 3))
# ---------------------------------------------------------------------------
# dtw_all
# ---------------------------------------------------------------------------
class TestDtwAll:
    """Tests for :func:`dtw_all`, a single DTW comparison over all joints."""

    def test_identical_sequences_distance_zero(
        self, simple_sequence: np.ndarray
    ) -> None:
        result = dtw_all(simple_sequence, simple_sequence)
        assert isinstance(result, DTWResult)
        assert result.distance == pytest.approx(0.0, abs=1e-9)
        # Identical sequences produce a diagonal warping path. Materialize the
        # path and require it to be non-empty first: ``all()`` over an empty
        # iterable is True, so an empty path would otherwise pass silently.
        path = list(result.path)
        assert path
        assert all(i == j for i, j in path)

    def test_shifted_sequences_distance_zero(
        self, simple_sequence: np.ndarray
    ) -> None:
        """DTW should absorb a pure time shift without penalty."""
        # Duplicate the first frame to create a one-frame shift.
        shifted = np.concatenate([simple_sequence[:1], simple_sequence], axis=0)
        result = dtw_all(simple_sequence, shifted)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_different_sequences_positive_distance(self) -> None:
        a = np.zeros((5, 3, 3))
        b = np.ones((5, 3, 3))
        result = dtw_all(a, b)
        assert result.distance > 0.0

    def test_mismatched_joint_count_rejected(self) -> None:
        a = np.zeros((5, 3, 3))
        b = np.zeros((5, 4, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_all(a, b)

    def test_non_3d_input_rejected(self) -> None:
        a = np.zeros((5, 3))  # missing trailing axis
        b = np.zeros((5, 3))
        with pytest.raises(ValueError, match="expected 3D"):
            dtw_all(a, b)
# ---------------------------------------------------------------------------
# dtw_per_joint
# ---------------------------------------------------------------------------
class TestDtwPerJoint:
    """Behavioral checks for :func:`dtw_per_joint`."""

    def test_returns_one_result_per_joint(
        self, simple_sequence: np.ndarray
    ) -> None:
        per_joint = dtw_per_joint(simple_sequence, simple_sequence)
        n_joints = simple_sequence.shape[1]
        assert len(per_joint) == n_joints
        for joint_result in per_joint:
            assert isinstance(joint_result, DTWResult)
            assert joint_result.distance == pytest.approx(0.0, abs=1e-9)

    def test_independent_joint_distances(self) -> None:
        # Two 2-frame, 2-joint sequences: joint 0 is identical in both,
        # while joint 1 in ``b`` is shifted by a constant 10 on every axis.
        # Per-joint DTW must report 0 for joint 0 and > 0 for joint 1.
        frame_0 = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
        frame_1 = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
        a = np.array([frame_0, frame_1])
        b = np.copy(a)
        b[:, 1, :] = b[:, 1, :] + 10.0
        per_joint = dtw_per_joint(a, b)
        assert per_joint[0].distance == pytest.approx(0.0, abs=1e-9)
        assert per_joint[1].distance > 0.0

    def test_mismatched_joint_count_rejected(self) -> None:
        three_joints = np.zeros((5, 3, 3))
        two_joints = np.zeros((5, 2, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_per_joint(three_joints, two_joints)
# ---------------------------------------------------------------------------
# dtw_relation
# ---------------------------------------------------------------------------
class TestDtwRelation:
    """Tests for :func:`dtw_relation` on the displacement between two joints."""

    def test_identical_sequences_distance_zero(
        self, simple_sequence: np.ndarray
    ) -> None:
        result = dtw_relation(simple_sequence, simple_sequence, joint_i=0, joint_j=1)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_same_relative_position_is_zero_even_under_translation(self) -> None:
        """Translating the whole body does not change the
        joint-to-joint displacement, so dtw_relation should be 0."""
        # Static pose: joint 0 stays at the origin (np.zeros already puts it
        # there), joint 1 sits on +x and joint 2 on +y.
        a = np.zeros((4, 3, 3))
        a[:, 1, :] = [1.0, 0.0, 0.0]
        a[:, 2, :] = [0.0, 1.0, 0.0]
        b = a + 50.0  # translate the whole body
        result = dtw_relation(a, b, joint_i=0, joint_j=1)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_same_relative_position_is_zero_under_per_frame_translation(
        self,
    ) -> None:
        """A moving pose under a frame-varying whole-body translation keeps
        every joint-to-joint displacement, so dtw_relation should still be 0."""
        rng = np.random.default_rng(seed=7)
        a = rng.standard_normal((4, 3, 3))
        # One offset per frame, broadcast across all joints of that frame, so
        # the body drifts over time but joint-to-joint vectors are unchanged.
        offsets = 10.0 * rng.standard_normal((4, 3))
        b = a + offsets[:, np.newaxis, :]
        result = dtw_relation(a, b, joint_i=0, joint_j=1)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_joint_index_out_of_range_rejected(self) -> None:
        a = np.zeros((3, 2, 3))
        b = np.zeros((3, 2, 3))
        with pytest.raises(ValueError, match="joint indices"):
            dtw_relation(a, b, joint_i=0, joint_j=5)

    def test_mismatched_joint_count_rejected(self) -> None:
        a = np.zeros((3, 3, 3))
        b = np.zeros((3, 2, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_relation(a, b, joint_i=0, joint_j=1)