"""Tests for :mod:`neuropose.io` schema and helpers."""
from __future__ import annotations
import json
from datetime import UTC, datetime
from pathlib import Path
import pytest
from pydantic import ValidationError
from neuropose.io import (
BenchmarkAggregate,
BenchmarkResult,
CpuComparisonResult,
FramePrediction,
JobResults,
JobStatus,
JobStatusEntry,
JointAngleExtractor,
JointAxisExtractor,
JointPairDistanceExtractor,
JointSpeedExtractor,
PerformanceMetrics,
Segment,
Segmentation,
SegmentationConfig,
StatusFile,
VideoMetadata,
VideoPredictions,
load_benchmark_result,
load_job_results,
load_status,
load_video_predictions,
save_benchmark_result,
save_job_results,
save_status,
save_video_predictions,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
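# Payload layout assumed by these fixtures (inferred from the assertions
# below, e.g. ``boxes[0][4]`` being read as a 0.95 confidence score):
#   boxes   -> per person: [x1, y1, x2, y2, confidence]
#   poses3d -> persons x joints x (x, y, z)
#   poses2d -> persons x joints x (x, y)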
@pytest.fixture
def one_frame() -> dict:
    """A minimal valid FramePrediction payload (one person, two joints)."""
    return {
        "boxes": [[10.0, 20.0, 100.0, 200.0, 0.95]],
        "poses3d": [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]],
        "poses2d": [[[10.0, 20.0], [30.0, 40.0]]],
    }


@pytest.fixture
def video_metadata_payload() -> dict:
    return {"frame_count": 2, "fps": 30.0, "width": 640, "height": 480}


@pytest.fixture
def video_predictions_payload(one_frame: dict, video_metadata_payload: dict) -> dict:
    return {
        "metadata": video_metadata_payload,
        "frames": {
            "frame_000000": one_frame,
            "frame_000001": one_frame,
        },
    }

# ---------------------------------------------------------------------------
# FramePrediction
# ---------------------------------------------------------------------------
class TestFramePrediction:
    def test_roundtrip(self, one_frame: dict) -> None:
        frame = FramePrediction.model_validate(one_frame)
        assert frame.boxes == one_frame["boxes"]
        assert frame.poses3d == one_frame["poses3d"]
        assert frame.poses2d == one_frame["poses2d"]

    def test_rejects_extra_fields(self, one_frame: dict) -> None:
        one_frame["extra"] = "bogus"
        with pytest.raises(ValidationError):
            FramePrediction.model_validate(one_frame)

    def test_is_frozen(self, one_frame: dict) -> None:
        frame = FramePrediction.model_validate(one_frame)
        with pytest.raises(ValidationError):
            frame.boxes = []

# ---------------------------------------------------------------------------
# VideoMetadata
# ---------------------------------------------------------------------------
class TestVideoMetadata:
    def test_valid(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=29.97, width=1920, height=1080)
        assert meta.frame_count == 10
        assert meta.fps == pytest.approx(29.97)

    def test_zero_frame_count_allowed(self) -> None:
        # Broken or empty videos still produce a valid metadata object so
        # the caller can see frame_count == 0 rather than receiving an
        # exception.
        VideoMetadata(frame_count=0, fps=0.0, width=0, height=0)

    def test_rejects_negative(self) -> None:
        with pytest.raises(ValidationError):
            VideoMetadata(frame_count=-1, fps=30.0, width=640, height=480)

    def test_rejects_extra_fields(self) -> None:
        with pytest.raises(ValidationError):
            VideoMetadata(
                frame_count=1,
                fps=30.0,
                width=640,
                height=480,
                source_path="/leak/me",  # type: ignore[call-arg]
            )

    def test_is_frozen(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=30.0, width=640, height=480)
        with pytest.raises(ValidationError):
            meta.fps = 60.0

# ---------------------------------------------------------------------------
# VideoPredictions
# ---------------------------------------------------------------------------
class TestVideoPredictions:
    def test_from_dict(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert len(vp) == 2
        assert vp.frame_names() == ["frame_000000", "frame_000001"]
        assert vp["frame_000000"].boxes[0][4] == pytest.approx(0.95)
        assert vp.metadata.fps == pytest.approx(30.0)

    def test_iteration(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert list(vp) == ["frame_000000", "frame_000001"]

    def test_rejects_missing_metadata(self, video_predictions_payload: dict) -> None:
        del video_predictions_payload["metadata"]
        with pytest.raises(ValidationError):
            VideoPredictions.model_validate(video_predictions_payload)

    def test_save_and_load_roundtrip(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        path = tmp_path / "preds" / "video.json"
        save_video_predictions(path, vp)
        assert path.exists()
        loaded = load_video_predictions(path)
        assert loaded.frame_names() == vp.frame_names()
        assert loaded.metadata == vp.metadata
        assert loaded["frame_000000"].poses3d == vp["frame_000000"].poses3d

    def test_save_is_atomic(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        path = tmp_path / "video.json"
        save_video_predictions(path, vp)
        assert path.exists()
        tmps = list(tmp_path.glob("video.json.tmp"))
        assert tmps == []
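
    def test_save_overwrites_previous_contents(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        # A sketch beyond the original suite, assuming
        # ``save_video_predictions`` replaces the target file wholesale:
        # a second save should leave only the newer payload on disk.
        path = tmp_path / "video.json"
        save_video_predictions(
            path, VideoPredictions.model_validate(video_predictions_payload)
        )
        smaller = {
            "metadata": {**video_predictions_payload["metadata"], "frame_count": 1},
            "frames": {
                "frame_000000": video_predictions_payload["frames"]["frame_000000"]
            },
        }
        save_video_predictions(path, VideoPredictions.model_validate(smaller))
        assert load_video_predictions(path).frame_names() == ["frame_000000"]
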
# ---------------------------------------------------------------------------
# JobResults
# ---------------------------------------------------------------------------
class TestJobResults:
    def test_save_and_load_roundtrip(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        jr = JobResults.model_validate(
            {
                "video_a.mp4": video_predictions_payload,
                "video_b.mp4": video_predictions_payload,
            }
        )
        path = tmp_path / "results.json"
        save_job_results(path, jr)
        loaded = load_job_results(path)
        assert loaded.videos() == ["video_a.mp4", "video_b.mp4"]
        assert len(loaded["video_a.mp4"]) == 2

# ---------------------------------------------------------------------------
# Performance / benchmark schemas
# ---------------------------------------------------------------------------
def _make_metrics(
    *,
    total_seconds: float = 1.0,
    latencies: list[float] | None = None,
    peak_rss_mb: float = 512.0,
    active_device: str = "/CPU:0",
    metal_active: bool = False,
    model_load_seconds: float | None = None,
) -> PerformanceMetrics:
    return PerformanceMetrics(
        model_load_seconds=model_load_seconds,
        total_seconds=total_seconds,
        per_frame_latencies_ms=latencies if latencies is not None else [10.0, 11.0, 9.5],
        peak_rss_mb=peak_rss_mb,
        active_device=active_device,
        tensorflow_metal_active=metal_active,
        tensorflow_version="2.18.0",
    )


def _make_aggregate() -> BenchmarkAggregate:
    return BenchmarkAggregate(
        repeats_measured=4,
        warmup_frames_per_pass=3,
        mean_frame_latency_ms=10.0,
        p50_frame_latency_ms=9.8,
        p95_frame_latency_ms=12.5,
        p99_frame_latency_ms=13.0,
        stddev_frame_latency_ms=0.7,
        mean_throughput_fps=100.0,
        peak_rss_mb_max=512.0,
        active_device="/CPU:0",
        tensorflow_metal_active=False,
        tensorflow_version="2.18.0",
    )
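

# The aggregate fixture's numbers happen to be self-consistent: a 10.0 ms
# mean frame latency corresponds to 1000 / 10.0 = 100.0 fps mean throughput,
# so either field can be sanity-checked against the other.
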
class TestPerformanceMetricsModel:
    def test_roundtrip(self) -> None:
        m = _make_metrics()
        rehydrated = PerformanceMetrics.model_validate(m.model_dump(mode="json"))
        assert rehydrated == m

    def test_rejects_negative_total_seconds(self) -> None:
        with pytest.raises(ValidationError):
            PerformanceMetrics(
                total_seconds=-1.0,
                peak_rss_mb=0.0,
                active_device="/CPU:0",
                tensorflow_version="2.18.0",
            )

    def test_rejects_negative_peak_rss(self) -> None:
        with pytest.raises(ValidationError):
            PerformanceMetrics(
                total_seconds=1.0,
                peak_rss_mb=-5.0,
                active_device="/CPU:0",
                tensorflow_version="2.18.0",
            )

    def test_model_load_seconds_optional(self) -> None:
        m = _make_metrics(model_load_seconds=None)
        assert m.model_load_seconds is None

    def test_is_frozen(self) -> None:
        m = _make_metrics()
        with pytest.raises(ValidationError):
            m.total_seconds = 2.0
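
    def test_minimal_construction_defaults(self) -> None:
        # A sketch beyond the original suite: the rejection tests above
        # construct PerformanceMetrics without the optional fields, which
        # implies defaults exist. Only the ``model_load_seconds`` default
        # (assumed to be None) is asserted here.
        m = PerformanceMetrics(
            total_seconds=1.0,
            peak_rss_mb=0.0,
            active_device="/CPU:0",
            tensorflow_version="2.18.0",
        )
        assert m.model_load_seconds is None
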
class TestBenchmarkResultPersistence:
    def test_roundtrip_to_disk(self, tmp_path: Path) -> None:
        result = BenchmarkResult(
            video_name="trial.mp4",
            repeats=5,
            warmup_frames=3,
            warmup_pass=_make_metrics(total_seconds=20.0),
            measured_passes=[_make_metrics(total_seconds=1.5) for _ in range(4)],
            aggregate=_make_aggregate(),
        )
        path = tmp_path / "bench.json"
        save_benchmark_result(path, result)
        assert path.exists()
        loaded = load_benchmark_result(path)
        assert loaded == result

    def test_rejects_repeats_below_one(self) -> None:
        with pytest.raises(ValidationError):
            BenchmarkResult(
                video_name="x.mp4",
                repeats=0,
                warmup_frames=0,
                warmup_pass=_make_metrics(),
                measured_passes=[],
                aggregate=_make_aggregate(),
            )

    def test_cpu_comparison_nested(self, tmp_path: Path) -> None:
        comparison = CpuComparisonResult(
            primary_aggregate=_make_aggregate(),
            cpu_aggregate=_make_aggregate(),
            speedup=2.5,
            max_poses3d_divergence_mm=0.002,
            frame_count_compared=30,
        )
        result = BenchmarkResult(
            video_name="trial.mp4",
            repeats=5,
            warmup_frames=3,
            warmup_pass=_make_metrics(),
            measured_passes=[_make_metrics() for _ in range(4)],
            aggregate=_make_aggregate(),
            cpu_comparison=comparison,
        )
        path = tmp_path / "bench_with_cmp.json"
        save_benchmark_result(path, result)
        loaded = load_benchmark_result(path)
        assert loaded.cpu_comparison is not None
        assert loaded.cpu_comparison.speedup == pytest.approx(2.5)
        assert loaded.cpu_comparison.max_poses3d_divergence_mm == pytest.approx(0.002)
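
    def test_cpu_comparison_defaults_to_none(self) -> None:
        # A sketch, not in the original suite: ``test_roundtrip_to_disk``
        # above builds a BenchmarkResult without ``cpu_comparison``, so the
        # field is presumably optional with a None default.
        result = BenchmarkResult(
            video_name="trial.mp4",
            repeats=5,
            warmup_frames=3,
            warmup_pass=_make_metrics(),
            measured_passes=[_make_metrics() for _ in range(4)],
            aggregate=_make_aggregate(),
        )
        assert result.cpu_comparison is None
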
# ---------------------------------------------------------------------------
# Segmentation schema
# ---------------------------------------------------------------------------
class TestSegmentModel:
    def test_valid(self) -> None:
        seg = Segment(start=0, end=30, peak=15)
        assert seg.start == 0
        assert seg.peak == 15
        assert seg.end == 30

    def test_rejects_end_not_greater_than_start(self) -> None:
        with pytest.raises(ValidationError, match="end"):
            Segment(start=10, end=10, peak=10)

    def test_peak_must_be_inside_window(self) -> None:
        with pytest.raises(ValidationError, match="peak"):
            Segment(start=0, end=30, peak=30)  # peak == end is out of range

    def test_is_frozen(self) -> None:
        seg = Segment(start=0, end=10, peak=5)
        with pytest.raises(ValidationError):
            seg.start = 1

class TestExtractorSpecs:
    def test_joint_pair_distance_rejects_identical_joints(self) -> None:
        with pytest.raises(ValidationError, match="distinct"):
            JointPairDistanceExtractor(joints=(7, 7))

    def test_joint_pair_distance_rejects_negative(self) -> None:
        with pytest.raises(ValidationError, match="non-negative"):
            JointPairDistanceExtractor(joints=(-1, 5))

    def test_joint_angle_rejects_negative(self) -> None:
        with pytest.raises(ValidationError, match="non-negative"):
            JointAngleExtractor(triplet=(0, -1, 2))

    def test_joint_axis_rejects_bad_axis(self) -> None:
        with pytest.raises(ValidationError):
            JointAxisExtractor(joint=0, axis=3)

    def test_discriminator_dispatches_to_correct_variant(self) -> None:
        # Round-trip each extractor variant through a SegmentationConfig
        # dict to confirm the discriminator selects the right class.
        for payload, cls in [
            ({"kind": "joint_axis", "joint": 1, "axis": 2}, JointAxisExtractor),
            ({"kind": "joint_pair_distance", "joints": [1, 2]}, JointPairDistanceExtractor),
            ({"kind": "joint_speed", "joint": 3}, JointSpeedExtractor),
            ({"kind": "joint_angle", "triplet": [1, 2, 3]}, JointAngleExtractor),
        ]:
            cfg = SegmentationConfig.model_validate({"extractor": payload})
            assert isinstance(cfg.extractor, cls)
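
    def test_discriminator_tag_survives_serialization(self) -> None:
        # A sketch beyond the original suite, assuming the extractors form
        # a pydantic discriminated union keyed on ``kind``: the tag should
        # appear in the serialized payload and drive re-validation.
        cfg = SegmentationConfig.model_validate(
            {"extractor": {"kind": "joint_speed", "joint": 3}}
        )
        payload = cfg.model_dump(mode="json")
        assert payload["extractor"]["kind"] == "joint_speed"
        rehydrated = SegmentationConfig.model_validate(payload)
        assert isinstance(rehydrated.extractor, JointSpeedExtractor)
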
class TestSegmentationPersistence:
    def test_roundtrip_through_video_predictions(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        cfg = SegmentationConfig(
            extractor=JointAxisExtractor(joint=15, axis=1),
            min_prominence=20.0,
            pad_seconds=0.1,
        )
        segmentation = Segmentation(
            config=cfg,
            segments=[Segment(start=0, end=1, peak=0), Segment(start=1, end=2, peak=1)],
        )
        video_predictions_payload["segmentations"] = {
            "cup_lift": segmentation.model_dump(mode="json")
        }
        vp = VideoPredictions.model_validate(video_predictions_payload)
        path = tmp_path / "video.json"
        save_video_predictions(path, vp)
        loaded = load_video_predictions(path)
        assert "cup_lift" in loaded.segmentations
        cup = loaded.segmentations["cup_lift"]
        assert cup.config.extractor.kind == "joint_axis"
        assert len(cup.segments) == 2
        assert cup.config.method == "valley_to_valley_v1"

    def test_default_empty_segmentations_on_new_instance(
        self, video_predictions_payload: dict
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert vp.segmentations == {}

    def test_legacy_file_without_segmentations_loads_clean(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        # Older predictions files never wrote the segmentations field;
        # make sure they still validate and deserialize as if they had
        # an empty mapping.
        assert "segmentations" not in video_predictions_payload
        path = tmp_path / "legacy.json"
        path.write_text(json.dumps(video_predictions_payload))
        vp = load_video_predictions(path)
        assert vp.segmentations == {}
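
    def test_extractor_parameters_survive_roundtrip(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        # A sketch extending the roundtrip test above: not just the
        # ``kind`` tag but the extractor's parameters should come back
        # intact from disk. Assumes JointAxisExtractor exposes ``joint``
        # and ``axis`` as plain attributes, as its constructor suggests.
        segmentation = Segmentation(
            config=SegmentationConfig(
                extractor=JointAxisExtractor(joint=15, axis=1),
                min_prominence=20.0,
                pad_seconds=0.1,
            ),
            segments=[Segment(start=0, end=1, peak=0)],
        )
        video_predictions_payload["segmentations"] = {
            "cup_lift": segmentation.model_dump(mode="json")
        }
        path = tmp_path / "video.json"
        save_video_predictions(
            path, VideoPredictions.model_validate(video_predictions_payload)
        )
        cup = load_video_predictions(path).segmentations["cup_lift"]
        assert cup.config.extractor.joint == 15
        assert cup.config.extractor.axis == 1
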
# ---------------------------------------------------------------------------
# Status file
# ---------------------------------------------------------------------------
class TestStatusFile:
    def test_load_missing_returns_empty(self, tmp_path: Path) -> None:
        status = load_status(tmp_path / "nope.json")
        assert status.is_empty()

    def test_load_corrupt_json_returns_empty(self, tmp_path: Path) -> None:
        path = tmp_path / "bad.json"
        path.write_text("{ not valid json")
        status = load_status(path)
        assert status.is_empty()

    def test_load_non_mapping_returns_empty(self, tmp_path: Path) -> None:
        path = tmp_path / "list.json"
        path.write_text(json.dumps([1, 2, 3]))
        status = load_status(path)
        assert status.is_empty()

    def test_save_and_load_completed_entry(self, tmp_path: Path) -> None:
        started = datetime(2026, 4, 13, 12, 0, 0, tzinfo=UTC)
        completed = datetime(2026, 4, 13, 12, 5, 0, tzinfo=UTC)
        status = StatusFile.model_validate(
            {
                "job_001": {
                    "status": "completed",
                    "started_at": started.isoformat(),
                    "completed_at": completed.isoformat(),
                    "results_path": "/tmp/results.json",
                    "error": None,
                }
            }
        )
        path = tmp_path / "status.json"
        save_status(path, status)
        loaded = load_status(path)
        entry = loaded.root["job_001"]
        assert entry.status == JobStatus.COMPLETED
        assert entry.started_at == started
        assert entry.completed_at == completed
        assert entry.error is None

    def test_save_is_atomic(self, tmp_path: Path) -> None:
        """``save_status`` leaves no orphan ``.tmp`` file on success."""
        started = datetime(2026, 4, 13, tzinfo=UTC)
        status = StatusFile.model_validate(
            {
                "job_001": {
                    "status": "processing",
                    "started_at": started.isoformat(),
                }
            }
        )
        path = tmp_path / "status.json"
        save_status(path, status)
        assert path.exists()
        tmps = list(tmp_path.glob("status.json.tmp"))
        assert tmps == []

    def test_failed_entry_carries_error_message(self, tmp_path: Path) -> None:
        started = datetime(2026, 4, 13, tzinfo=UTC)
        status = StatusFile.model_validate(
            {
                "job_001": {
                    "status": "failed",
                    "started_at": started.isoformat(),
                    "error": "ffmpeg decode failed: codec not supported",
                }
            }
        )
        path = tmp_path / "status.json"
        save_status(path, status)
        loaded = load_status(path)
        entry = loaded.root["job_001"]
        assert entry.status == JobStatus.FAILED
        assert entry.error is not None
        assert "ffmpeg" in entry.error

    def test_rejects_unknown_status(self) -> None:
        # Pure validation check; no filesystem needed, so the unused
        # tmp_path fixture has been dropped.
        with pytest.raises(ValidationError):
            StatusFile.model_validate(
                {
                    "job_001": {
                        "status": "some-unknown-state",
                        "started_at": datetime(2026, 4, 13, tzinfo=UTC).isoformat(),
                    }
                }
            )
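
    def test_load_empty_mapping_is_empty(self, tmp_path: Path) -> None:
        # A sketch, assuming a file holding a valid but empty JSON mapping
        # behaves like a missing file: no jobs recorded means ``is_empty``
        # is true.
        path = tmp_path / "empty.json"
        path.write_text("{}")
        assert load_status(path).is_empty()
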
class TestJobStatusEntryProgressFields:
    def test_progress_fields_default_to_none(self) -> None:
        entry = JobStatusEntry(
            status=JobStatus.PROCESSING,
            started_at=datetime(2026, 4, 13, tzinfo=UTC),
        )
        assert entry.current_video is None
        assert entry.frames_processed is None
        assert entry.frames_total is None
        assert entry.videos_completed is None
        assert entry.videos_total is None
        assert entry.percent_complete is None
        assert entry.last_update is None

    def test_legacy_status_file_without_progress_loads(self, tmp_path: Path) -> None:
        """Files written before the progress fields existed must still load."""
        started = datetime(2026, 4, 13, 12, 0, 0, tzinfo=UTC)
        path = tmp_path / "legacy.json"
        path.write_text(
            json.dumps(
                {
                    "job_001": {
                        "status": "completed",
                        "started_at": started.isoformat(),
                        "completed_at": started.isoformat(),
                        "results_path": "/tmp/results.json",
                        "error": None,
                    }
                }
            )
        )
        loaded = load_status(path)
        entry = loaded.root["job_001"]
        assert entry.percent_complete is None

    def test_progress_roundtrips_through_json(self, tmp_path: Path) -> None:
        now = datetime(2026, 4, 13, 12, 0, 0, tzinfo=UTC)
        status = StatusFile(
            root={
                "job_001": JobStatusEntry(
                    status=JobStatus.PROCESSING,
                    started_at=now,
                    current_video="trial_01.mp4",
                    frames_processed=450,
                    frames_total=1200,
                    videos_completed=0,
                    videos_total=3,
                    percent_complete=12.5,
                    last_update=now,
                ),
            }
        )
        path = tmp_path / "status.json"
        save_status(path, status)
        loaded = load_status(path)
        entry = loaded.root["job_001"]
        assert entry.current_video == "trial_01.mp4"
        assert entry.frames_processed == 450
        assert entry.percent_complete == 12.5

    def test_percent_complete_rejects_out_of_range(self) -> None:
        with pytest.raises(ValidationError):
            JobStatusEntry(
                status=JobStatus.PROCESSING,
                started_at=datetime(2026, 4, 13, tzinfo=UTC),
                percent_complete=150.0,
            )
        with pytest.raises(ValidationError):
            JobStatusEntry(
                status=JobStatus.PROCESSING,
                started_at=datetime(2026, 4, 13, tzinfo=UTC),
                percent_complete=-5.0,
            )

    def test_frames_processed_rejects_negative(self) -> None:
        with pytest.raises(ValidationError):
            JobStatusEntry(
                status=JobStatus.PROCESSING,
                started_at=datetime(2026, 4, 13, tzinfo=UTC),
                frames_processed=-1,
            )
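
    def test_percent_complete_accepts_boundaries(self) -> None:
        # A sketch, assuming the valid range is the inclusive interval
        # [0.0, 100.0] given that -5.0 and 150.0 are rejected above.
        for value in (0.0, 100.0):
            entry = JobStatusEntry(
                status=JobStatus.PROCESSING,
                started_at=datetime(2026, 4, 13, tzinfo=UTC),
                percent_complete=value,
            )
            assert entry.percent_complete == value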