neuropose/tests/unit/test_analyzer_segment.py

432 lines
16 KiB
Python

"""Tests for :mod:`neuropose.analyzer.segment`.

Coverage is organized as follows:

- **Joint vocabulary** (:data:`JOINT_NAMES`, :data:`JOINT_INDEX`,
  :func:`joint_index`) — name/index consistency and unknown-name errors.
- **Extractor factories and** :func:`extract_signal` — each
  :class:`ExtractorSpec` variant is built through its factory and then
  evaluated against small hand-built pose arrays. This section also
  includes a handful of targeted schema tests for the validators
  (distinct joint indices etc.).
- **Layer 1** (:func:`segment_by_peaks`) against synthetic 1D signals
  with known peaks and valleys.
- **Layer 2** (:func:`segment_predictions`) against synthetic
  :class:`VideoPredictions` fixtures. It covers the seconds-to-frames
  conversion, the fps requirement, and JSON round-tripping of the
  recorded segmentation config.
- **Slicing** (:func:`slice_predictions`) — per-rep round-trip and the
  metadata rewrite.
"""
from __future__ import annotations
import itertools
import math
import numpy as np
import pytest
from neuropose.analyzer.segment import (
JOINT_INDEX,
JOINT_NAMES,
extract_signal,
joint_angle,
joint_axis,
joint_index,
joint_pair_distance,
joint_speed,
segment_by_peaks,
segment_predictions,
slice_predictions,
)
from neuropose.io import (
JointAngleExtractor,
JointAxisExtractor,
JointPairDistanceExtractor,
JointSpeedExtractor,
Segment,
Segmentation,
VideoPredictions,
)
# Number of joints per pose in every synthetic fixture in this module; it
# equals len(JOINT_NAMES) (see TestJointNames.test_tuple_length_is_43).
NUM_JOINTS = 43
def _triple_hump_signal(num_frames: int = 300) -> np.ndarray:
"""Three non-negative sine humps separated by clear zero-valleys."""
t = np.linspace(0.0, 6.0 * math.pi, num_frames)
return np.maximum(0.0, np.sin(t)) ** 2
def _make_predictions(
    signal: np.ndarray,
    joint: int,
    *,
    axis: int = 1,
    fps: float = 30.0,
) -> VideoPredictions:
    """Wrap a 1D ``signal`` in a single-detection :class:`VideoPredictions`.

    Every frame carries one detection box and an all-zero 3D pose of
    ``NUM_JOINTS`` joints. The only exception is coordinate ``axis`` of
    joint ``joint``, which is set to the corresponding ``signal`` value.

    Frames are keyed ``frame_000000``, ``frame_000001``, ... in signal
    order, and ``metadata.frame_count`` equals ``len(signal)``. ``fps`` is
    copied verbatim into the metadata; some tests pass ``0.0`` to simulate
    a missing frame rate.

    The payload goes through ``VideoPredictions.model_validate``, so the
    fixture is checked against the real schema.
    """

    def _frame_payload(value: float) -> dict:
        # Fresh inner lists per joint: the target coordinate is mutated below.
        pose = [[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]
        pose[joint][axis] = float(value)
        return {
            "boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
            "poses3d": [pose],
            "poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
        }

    payload = {
        "metadata": {
            "frame_count": len(signal),
            "fps": fps,
            "width": 640,
            "height": 480,
        },
        "frames": {
            f"frame_{index:06d}": _frame_payload(value)
            for index, value in enumerate(signal)
        },
    }
    return VideoPredictions.model_validate(payload)
# ---------------------------------------------------------------------------
# JOINT_NAMES / JOINT_INDEX / joint_index()
# ---------------------------------------------------------------------------
class TestJointNames:
    """Consistency of the joint-name vocabulary and its lookup helpers."""

    def test_tuple_length_is_43(self) -> None:
        # The synthetic fixtures in this module allocate NUM_JOINTS joints
        # per pose, so the vocabulary must agree with that constant.
        assert len(JOINT_NAMES) == NUM_JOINTS == 43

    def test_index_matches_position(self) -> None:
        for idx, name in enumerate(JOINT_NAMES):
            assert JOINT_INDEX[name] == idx

    def test_joint_index_by_name(self) -> None:
        assert joint_index("lwri") == JOINT_NAMES.index("lwri")
        assert joint_index("rwri") == JOINT_NAMES.index("rwri")
        # joint_index() must resolve every name in the vocabulary to its
        # position, not just the two wrists spot-checked above.
        for idx, name in enumerate(JOINT_NAMES):
            assert joint_index(name) == idx

    def test_joint_index_unknown_name(self) -> None:
        # "elbow" is a plausible anatomical name, but it is not a member of
        # the abbreviated vocabulary (the left elbow, for example, is
        # "lelb"). The lookup must therefore raise rather than guess.
        with pytest.raises(KeyError, match="unknown joint name"):
            joint_index("elbow")
# ---------------------------------------------------------------------------
# Factory shortcuts
# ---------------------------------------------------------------------------
class TestFactories:
    """Factory shortcuts return correctly typed, validated extractor specs."""

    def test_joint_axis_factory(self) -> None:
        left_wrist = JOINT_INDEX["lwri"]
        spec = joint_axis(left_wrist, 1, invert=True)
        assert isinstance(spec, JointAxisExtractor)
        assert (spec.kind, spec.joint, spec.axis) == ("joint_axis", left_wrist, 1)
        assert spec.invert is True

    def test_joint_pair_distance_factory(self) -> None:
        pair = (JOINT_INDEX["lwri"], JOINT_INDEX["rwri"])
        spec = joint_pair_distance(*pair)
        assert isinstance(spec, JointPairDistanceExtractor)
        assert spec.joints == pair

    def test_joint_pair_distance_rejects_same_joint(self) -> None:
        from pydantic import ValidationError

        # A "distance" between a joint and itself is meaningless, so the
        # schema validator must reject it.
        same_joint = 5
        with pytest.raises(ValidationError, match="distinct"):
            joint_pair_distance(same_joint, same_joint)

    def test_joint_speed_factory(self) -> None:
        right_wrist = JOINT_INDEX["rwri"]
        spec = joint_speed(right_wrist)
        assert isinstance(spec, JointSpeedExtractor)
        assert spec.joint == right_wrist

    def test_joint_angle_factory(self) -> None:
        triplet = tuple(JOINT_INDEX[name] for name in ("larm", "lelb", "lwri"))
        spec = joint_angle(*triplet)
        assert isinstance(spec, JointAngleExtractor)
        assert spec.triplet == triplet
# ---------------------------------------------------------------------------
# extract_signal: one test per extractor variant
# ---------------------------------------------------------------------------
class TestExtractSignal:
    """:func:`extract_signal` evaluated once per extractor variant.

    Most tests build a small all-zero ``(frames, NUM_JOINTS, 3)`` array,
    write known coordinates into a few joints, and compare the resulting
    1D signal against hand-computed values.
    """

    def test_joint_axis_selects_axis(self) -> None:
        seq = np.zeros((4, NUM_JOINTS, 3))
        seq[:, 10, 1] = [1.0, 2.0, 3.0, 4.0]
        signal = extract_signal(seq, joint_axis(10, 1))
        np.testing.assert_array_equal(signal, [1.0, 2.0, 3.0, 4.0])

    def test_joint_axis_invert(self) -> None:
        seq = np.zeros((3, NUM_JOINTS, 3))
        seq[:, 0, 0] = [1.0, 2.0, 3.0]
        signal = extract_signal(seq, joint_axis(0, 0, invert=True))
        np.testing.assert_array_equal(signal, [-1.0, -2.0, -3.0])

    def test_joint_pair_distance(self) -> None:
        seq = np.zeros((3, NUM_JOINTS, 3))
        seq[:, 0, 0] = [0.0, 0.0, 0.0]
        seq[:, 1, 0] = [3.0, 6.0, 9.0]  # distances 3, 6, 9 along x
        signal = extract_signal(seq, joint_pair_distance(0, 1))
        np.testing.assert_allclose(signal, [3.0, 6.0, 9.0])

    def test_joint_pair_distance_is_euclidean(self) -> None:
        # The axis-aligned case above cannot tell a Euclidean norm apart from
        # |dx| or an L1 sum. Off-axis offsets with known norms can.
        seq = np.zeros((2, NUM_JOINTS, 3))  # joint 0 stays at the origin
        seq[0, 1] = [3.0, 4.0, 0.0]  # |(3, 4, 0)| = 5; L1 would give 7
        seq[1, 1] = [1.0, 2.0, 2.0]  # |(1, 2, 2)| = 3; L1 would give 5
        signal = extract_signal(seq, joint_pair_distance(0, 1))
        np.testing.assert_allclose(signal, [5.0, 3.0])

    def test_joint_speed_pads_first_frame_with_zero(self) -> None:
        seq = np.zeros((4, NUM_JOINTS, 3))
        seq[:, 5, 0] = [0.0, 1.0, 3.0, 6.0]  # speeds: 1, 2, 3
        signal = extract_signal(seq, joint_speed(5))
        np.testing.assert_allclose(signal, [0.0, 1.0, 2.0, 3.0])

    def test_joint_speed_single_frame_rejected(self) -> None:
        seq = np.zeros((1, NUM_JOINTS, 3))
        with pytest.raises(ValueError, match="at least two frames"):
            extract_signal(seq, joint_speed(0))

    def test_joint_angle_straight(self) -> None:
        seq = np.zeros((2, NUM_JOINTS, 3))
        # Straight line: a=(-1,0,0), b=(0,0,0), c=(1,0,0). Angle = pi.
        seq[:, 0, 0] = [-1.0, -1.0]
        seq[:, 1, 0] = [0.0, 0.0]
        seq[:, 2, 0] = [1.0, 1.0]
        signal = extract_signal(seq, joint_angle(0, 1, 2))
        np.testing.assert_allclose(signal, [math.pi, math.pi])

    def test_joint_angle_right(self) -> None:
        seq = np.zeros((1, NUM_JOINTS, 3))
        seq[0, 0] = [1.0, 0.0, 0.0]
        seq[0, 1] = [0.0, 0.0, 0.0]
        seq[0, 2] = [0.0, 1.0, 0.0]
        signal = extract_signal(seq, joint_angle(0, 1, 2))
        np.testing.assert_allclose(signal, [math.pi / 2])

    def test_out_of_range_joint_index(self) -> None:
        seq = np.zeros((3, NUM_JOINTS, 3))
        with pytest.raises(ValueError, match="out of range"):
            extract_signal(seq, joint_axis(999, 0))

    def test_bad_sequence_shape(self) -> None:
        with pytest.raises(ValueError, match="frames, joints, 3"):
            extract_signal(np.zeros((5, 10)), joint_axis(0, 0))
# ---------------------------------------------------------------------------
# Layer 1: segment_by_peaks
# ---------------------------------------------------------------------------
class TestSegmentByPeaks:
    """Layer 1: :func:`segment_by_peaks` on synthetic 1D signals."""

    def test_three_humps_three_segments(self) -> None:
        signal = _triple_hump_signal()
        segs = segment_by_peaks(signal, min_prominence=0.1)
        assert len(segs) == 3
        # Segments should not overlap (in this synthetic case). Adjacent
        # segments are allowed to share the valley frame on their boundary,
        # so we use a ``>=`` comparison with one frame of slack.
        for prev, curr in itertools.pairwise(segs):
            assert curr.start >= prev.end - 1

    def test_first_segment_starts_at_zero_without_leading_valley(self) -> None:
        signal = _triple_hump_signal()
        segs = segment_by_peaks(signal, min_prominence=0.1)
        assert segs[0].start == 0

    def test_last_segment_ends_at_signal_length(self) -> None:
        signal = _triple_hump_signal()
        segs = segment_by_peaks(signal, min_prominence=0.1)
        assert segs[-1].end == len(signal)

    def test_peaks_lie_inside_segment(self) -> None:
        signal = _triple_hump_signal()
        segs = segment_by_peaks(signal, min_prominence=0.1)
        for seg in segs:
            assert seg.start <= seg.peak < seg.end

    def test_no_peaks_returns_empty(self) -> None:
        flat = np.zeros(50)
        segs = segment_by_peaks(flat)
        assert segs == []

    def test_min_distance_suppresses_close_peaks(self) -> None:
        # The spikes at 20 and 25 are only 5 frames apart, well inside
        # min_distance=30, so at most one of them can survive. Which one
        # survives is a tie-break between equal heights. The spike at 70 is
        # at least 45 frames from both. We compare against the unconstrained
        # run instead of pinning an exact count, so the test does not depend
        # on the default prominence threshold.
        signal = np.zeros(100)
        signal[20] = 1.0
        signal[25] = 1.0
        signal[70] = 1.0
        segs = segment_by_peaks(signal, min_distance=30)
        segs_unconstrained = segment_by_peaks(signal)
        assert len(segs) < len(segs_unconstrained)

    def test_pad_extends_segment(self) -> None:
        signal = _triple_hump_signal()
        base = segment_by_peaks(signal, min_prominence=0.1)
        padded = segment_by_peaks(signal, min_prominence=0.1, pad=5)
        assert len(base) == len(padded)
        for b, p in zip(base, padded, strict=True):
            assert p.start <= b.start
            assert p.end >= b.end
        # The checks above would also pass if ``pad`` were silently ignored.
        # The outermost boundaries are already at 0 / len(signal), but the
        # interior ones sit inside the zero valleys and must move outward.
        assert any(
            p.start < b.start or p.end > b.end
            for b, p in zip(base, padded, strict=True)
        )

    def test_pad_is_clamped_to_bounds(self) -> None:
        signal = _triple_hump_signal()
        padded = segment_by_peaks(signal, min_prominence=0.1, pad=10_000)
        for seg in padded:
            assert seg.start >= 0
            assert seg.end <= len(signal)

    def test_negative_pad_rejected(self) -> None:
        with pytest.raises(ValueError, match="pad"):
            segment_by_peaks(np.zeros(10), pad=-1)

    def test_rejects_non_1d(self) -> None:
        with pytest.raises(ValueError, match="1D"):
            segment_by_peaks(np.zeros((5, 5)))
# ---------------------------------------------------------------------------
# Layer 2: segment_predictions
# ---------------------------------------------------------------------------
class TestSegmentPredictions:
    """Layer 2: :func:`segment_predictions` on synthetic VideoPredictions."""

    def test_returns_segmentation_with_config(self) -> None:
        signal = _triple_hump_signal() * 1000.0  # mm scale
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        result = segment_predictions(
            preds,
            joint_axis(JOINT_INDEX["lwri"], 1),
            min_prominence=50.0,
        )
        assert isinstance(result, Segmentation)
        assert len(result.segments) == 3
        assert result.config.extractor.kind == "joint_axis"
        assert result.config.min_prominence == 50.0
        assert result.config.method == "valley_to_valley_v1"

    def test_min_distance_seconds_converts_via_fps(self) -> None:
        # 300 frames at 30 fps = 10 seconds; humps are ~3.3 s apart.
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"], fps=30.0)
        # A 5-second minimum distance exceeds the hump spacing, so the three
        # humps must collapse into at most two segments. The lower bound
        # guards against a conversion bug that silently drops everything.
        result = segment_predictions(
            preds,
            joint_axis(JOINT_INDEX["lwri"], 1),
            min_prominence=50.0,
            min_distance_seconds=5.0,
        )
        assert 1 <= len(result.segments) <= 2

    def test_pad_seconds_extends_segments(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"], fps=30.0)
        plain = segment_predictions(
            preds, joint_axis(JOINT_INDEX["lwri"], 1), min_prominence=50.0
        )
        padded = segment_predictions(
            preds,
            joint_axis(JOINT_INDEX["lwri"], 1),
            min_prominence=50.0,
            pad_seconds=0.2,  # 6 frames at 30 fps
        )
        # At least one segment should have moved outward.
        assert any(
            p.start < b.start or p.end > b.end
            for p, b in zip(padded.segments, plain.segments, strict=True)
        )

    def test_requires_fps_when_time_params_used(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"], fps=0.0)
        with pytest.raises(ValueError, match="fps"):
            segment_predictions(
                preds,
                joint_axis(JOINT_INDEX["lwri"], 1),
                min_prominence=50.0,
                min_distance_seconds=1.0,
            )

    def test_no_fps_is_fine_without_time_params(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"], fps=0.0)
        # Without any time-based parameters we never need to multiply by
        # fps, so fps=0 is tolerated.
        result = segment_predictions(
            preds,
            joint_axis(JOINT_INDEX["lwri"], 1),
            min_prominence=50.0,
        )
        assert len(result.segments) == 3

    def test_config_roundtrips_through_json(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        result = segment_predictions(
            preds,
            joint_pair_distance(JOINT_INDEX["lwri"], JOINT_INDEX["rwri"]),
            min_prominence=10.0,
        )
        serialized = result.model_dump(mode="json")
        rehydrated = Segmentation.model_validate(serialized)
        assert rehydrated == result
# ---------------------------------------------------------------------------
# slice_predictions
# ---------------------------------------------------------------------------
class TestSlicePredictions:
    """:func:`slice_predictions`: per-segment copies and metadata rewrite."""

    def test_one_output_per_segment(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        result = segment_predictions(
            preds, joint_axis(JOINT_INDEX["lwri"], 1), min_prominence=50.0
        )
        slices = slice_predictions(preds, result.segments)
        assert len(slices) == len(result.segments)

    def test_metadata_frame_count_matches_segment_length(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        segments = [Segment(start=10, end=30, peak=20), Segment(start=50, end=90, peak=75)]
        slices = slice_predictions(preds, segments)
        assert slices[0].metadata.frame_count == 20
        assert slices[1].metadata.frame_count == 40

    def test_frames_are_rekeyed_from_zero(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        segments = [Segment(start=100, end=110, peak=105)]
        sliced = slice_predictions(preds, segments)[0]
        names = sliced.frame_names()
        # end is exclusive, so [100, 110) yields exactly ten frames.
        assert len(names) == 10
        assert names[0] == "frame_000000"
        assert names[-1] == "frame_000009"

    def test_sliced_segmentations_field_is_empty(self) -> None:
        # Parent has a segmentation attached; sliced copies intentionally
        # drop it because segment indices are only meaningful in the
        # parent's timeline.
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        result = segment_predictions(
            preds, joint_axis(JOINT_INDEX["lwri"], 1), min_prominence=50.0
        )
        parent = preds.model_copy(update={"segmentations": {"cup_lift": result}})
        slices = slice_predictions(parent, result.segments)
        assert all(s.segmentations == {} for s in slices)

    def test_out_of_bounds_segment_rejected(self) -> None:
        signal = _triple_hump_signal(num_frames=50) * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        bad = [Segment(start=0, end=100, peak=25)]  # end > 50
        with pytest.raises(ValueError, match="exceeds frame count"):
            slice_predictions(preds, bad)

    def test_slice_preserves_pose_values(self) -> None:
        signal = _triple_hump_signal() * 1000.0
        preds = _make_predictions(signal, joint=JOINT_INDEX["lwri"])
        segment = Segment(start=50, end=55, peak=52)
        sliced = slice_predictions(preds, [segment])[0]
        # Frame k of the slice must equal frame (start + k) of the source
        # for every k in the slice, not just the first one.
        for offset in range(segment.end - segment.start):
            source = preds[f"frame_{segment.start + offset:06d}"]
            copied = sliced[f"frame_{offset:06d}"]
            assert copied.poses3d == source.poses3d