"""Tests for :mod:`neuropose.io` schema and helpers.""" from __future__ import annotations import json from datetime import UTC, datetime from pathlib import Path import pytest from pydantic import ValidationError from neuropose.io import ( BenchmarkAggregate, BenchmarkResult, CpuComparisonResult, FramePrediction, JobResults, JobStatus, JobStatusEntry, JointAngleExtractor, JointAxisExtractor, JointPairDistanceExtractor, JointSpeedExtractor, PerformanceMetrics, Provenance, Segment, Segmentation, SegmentationConfig, StatusFile, VideoMetadata, VideoPredictions, load_benchmark_result, load_job_results, load_status, load_video_predictions, save_benchmark_result, save_job_results, save_status, save_video_predictions, ) # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture def one_frame() -> dict: """A minimal valid FramePrediction payload (one person, two joints).""" return { "boxes": [[10.0, 20.0, 100.0, 200.0, 0.95]], "poses3d": [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], "poses2d": [[[10.0, 20.0], [30.0, 40.0]]], } @pytest.fixture def video_metadata_payload() -> dict: return {"frame_count": 2, "fps": 30.0, "width": 640, "height": 480} @pytest.fixture def video_predictions_payload(one_frame: dict, video_metadata_payload: dict) -> dict: return { "metadata": video_metadata_payload, "frames": { "frame_000000": one_frame, "frame_000001": one_frame, }, } # --------------------------------------------------------------------------- # FramePrediction # --------------------------------------------------------------------------- class TestFramePrediction: def test_roundtrip(self, one_frame: dict) -> None: frame = FramePrediction.model_validate(one_frame) assert frame.boxes == one_frame["boxes"] assert frame.poses3d == one_frame["poses3d"] assert frame.poses2d == one_frame["poses2d"] def test_rejects_extra_fields(self, one_frame: dict) -> None: one_frame["extra"] = "bogus" with pytest.raises(ValidationError): FramePrediction.model_validate(one_frame) def test_is_frozen(self, one_frame: dict) -> None: frame = FramePrediction.model_validate(one_frame) with pytest.raises(ValidationError): frame.boxes = [] # --------------------------------------------------------------------------- # VideoMetadata # --------------------------------------------------------------------------- class TestVideoMetadata: def test_valid(self) -> None: meta = VideoMetadata(frame_count=10, fps=29.97, width=1920, height=1080) assert meta.frame_count == 10 assert meta.fps == pytest.approx(29.97) def test_zero_frame_count_allowed(self) -> None: # Broken or empty videos still produce a valid metadata object so # the caller can see frame_count == 0 rather than receiving an # exception. VideoMetadata(frame_count=0, fps=0.0, width=0, height=0) def test_rejects_negative(self) -> None: with pytest.raises(ValidationError): VideoMetadata(frame_count=-1, fps=30.0, width=640, height=480) def test_rejects_extra_fields(self) -> None: with pytest.raises(ValidationError): VideoMetadata( frame_count=1, fps=30.0, width=640, height=480, source_path="/leak/me", # type: ignore[call-arg] ) def test_is_frozen(self) -> None: meta = VideoMetadata(frame_count=10, fps=30.0, width=640, height=480) with pytest.raises(ValidationError): meta.fps = 60.0 # --------------------------------------------------------------------------- # VideoPredictions # --------------------------------------------------------------------------- class TestVideoPredictions: def test_from_dict(self, video_predictions_payload: dict) -> None: vp = VideoPredictions.model_validate(video_predictions_payload) assert len(vp) == 2 assert vp.frame_names() == ["frame_000000", "frame_000001"] assert vp["frame_000000"].boxes[0][4] == pytest.approx(0.95) assert vp.metadata.fps == pytest.approx(30.0) def test_iteration(self, video_predictions_payload: dict) -> None: vp = VideoPredictions.model_validate(video_predictions_payload) assert list(vp) == ["frame_000000", "frame_000001"] def test_rejects_missing_metadata(self, video_predictions_payload: dict) -> None: del video_predictions_payload["metadata"] with pytest.raises(ValidationError): VideoPredictions.model_validate(video_predictions_payload) def test_save_and_load_roundtrip( self, tmp_path: Path, video_predictions_payload: dict, ) -> None: vp = VideoPredictions.model_validate(video_predictions_payload) path = tmp_path / "preds" / "video.json" save_video_predictions(path, vp) assert path.exists() loaded = load_video_predictions(path) assert loaded.frame_names() == vp.frame_names() assert loaded.metadata == vp.metadata assert loaded["frame_000000"].poses3d == vp["frame_000000"].poses3d def test_save_is_atomic( self, tmp_path: Path, video_predictions_payload: dict, ) -> None: vp = VideoPredictions.model_validate(video_predictions_payload) path = tmp_path / "video.json" save_video_predictions(path, vp) assert path.exists() tmps = list(tmp_path.glob("video.json.tmp")) assert tmps == [] # --------------------------------------------------------------------------- # JobResults # --------------------------------------------------------------------------- class TestJobResults: def test_save_and_load_roundtrip( self, tmp_path: Path, video_predictions_payload: dict, ) -> None: jr = JobResults.model_validate( { "video_a.mp4": video_predictions_payload, "video_b.mp4": video_predictions_payload, } ) path = tmp_path / "results.json" save_job_results(path, jr) loaded = load_job_results(path) assert loaded.videos() == ["video_a.mp4", "video_b.mp4"] assert len(loaded["video_a.mp4"]) == 2 # --------------------------------------------------------------------------- # Performance / benchmark schemas # --------------------------------------------------------------------------- def _make_metrics( *, total_seconds: float = 1.0, latencies: list[float] | None = None, peak_rss_mb: float = 512.0, active_device: str = "/CPU:0", metal_active: bool = False, model_load_seconds: float | None = None, ) -> PerformanceMetrics: return PerformanceMetrics( model_load_seconds=model_load_seconds, total_seconds=total_seconds, per_frame_latencies_ms=latencies if latencies is not None else [10.0, 11.0, 9.5], peak_rss_mb=peak_rss_mb, active_device=active_device, tensorflow_metal_active=metal_active, tensorflow_version="2.18.0", ) def _make_aggregate() -> BenchmarkAggregate: return BenchmarkAggregate( repeats_measured=4, warmup_frames_per_pass=3, mean_frame_latency_ms=10.0, p50_frame_latency_ms=9.8, p95_frame_latency_ms=12.5, p99_frame_latency_ms=13.0, stddev_frame_latency_ms=0.7, mean_throughput_fps=100.0, peak_rss_mb_max=512.0, active_device="/CPU:0", tensorflow_metal_active=False, tensorflow_version="2.18.0", ) class TestPerformanceMetricsModel: def test_roundtrip(self) -> None: m = _make_metrics() rehydrated = PerformanceMetrics.model_validate(m.model_dump(mode="json")) assert rehydrated == m def test_rejects_negative_total_seconds(self) -> None: with pytest.raises(ValidationError): PerformanceMetrics( total_seconds=-1.0, peak_rss_mb=0.0, active_device="/CPU:0", tensorflow_version="2.18.0", ) def test_rejects_negative_peak_rss(self) -> None: with pytest.raises(ValidationError): PerformanceMetrics( total_seconds=1.0, peak_rss_mb=-5.0, active_device="/CPU:0", tensorflow_version="2.18.0", ) def test_model_load_seconds_optional(self) -> None: m = _make_metrics(model_load_seconds=None) assert m.model_load_seconds is None def test_is_frozen(self) -> None: m = _make_metrics() with pytest.raises(ValidationError): m.total_seconds = 2.0 def _minimal_provenance() -> Provenance: return Provenance( model_sha256="a" * 64, model_filename="metrabs_fake.tar.gz", tensorflow_version="2.18.1", numpy_version="2.0.2", neuropose_version="0.1.0.dev0", python_version="3.11.14", ) class TestProvenanceModel: """Schema-level behaviour of :class:`neuropose.io.Provenance`.""" def test_roundtrip_through_json(self) -> None: p = Provenance( model_sha256="a" * 64, model_filename="metrabs_fake.tar.gz", tensorflow_version="2.18.1", tensorflow_metal_version="1.2.0", numpy_version="2.0.2", neuropose_version="0.1.0.dev0", python_version="3.11.14", seed=42, deterministic=True, analysis_config={"step": "dtw", "nan_policy": "propagate"}, ) rehydrated = Provenance.model_validate(p.model_dump(mode="json")) assert rehydrated == p def test_optional_fields_default_to_none_and_false(self) -> None: p = _minimal_provenance() assert p.tensorflow_metal_version is None assert p.seed is None assert p.deterministic is False assert p.analysis_config is None def test_is_frozen(self) -> None: p = _minimal_provenance() with pytest.raises(ValidationError): p.model_sha256 = "different" def test_extra_fields_forbidden(self) -> None: # Construct via model_validate so pyright doesn't have to prove the # keyword doesn't exist on the class at static-type time. with pytest.raises(ValidationError): Provenance.model_validate( { "model_sha256": "x" * 64, "model_filename": "x.tar.gz", "tensorflow_version": "2.18", "numpy_version": "2.0", "neuropose_version": "0.1", "python_version": "3.11.14", "unknown_field": "bogus", } ) class TestVideoPredictionsProvenance: """``provenance`` field on :class:`VideoPredictions` round-trips.""" def test_default_is_none(self) -> None: vp = VideoPredictions( metadata=VideoMetadata(frame_count=0, fps=30.0, width=32, height=32), frames={}, ) assert vp.provenance is None def test_roundtrip_with_provenance(self, tmp_path: Path) -> None: prov = Provenance( model_sha256="f" * 64, model_filename="metrabs.tar.gz", tensorflow_version="2.18.1", numpy_version="2.0.2", neuropose_version="0.1.0.dev0", python_version="3.11.14", ) vp = VideoPredictions( metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32), frames={ "frame_000000": FramePrediction( boxes=[[0.0, 0.0, 32.0, 32.0, 0.9]], poses3d=[[[1.0, 2.0, 3.0]]], poses2d=[[[10.0, 20.0]]], ) }, provenance=prov, ) path = tmp_path / "vp.json" save_video_predictions(path, vp) loaded = load_video_predictions(path) assert loaded == vp assert loaded.provenance == prov class TestBenchmarkResultPersistence: def test_roundtrip_to_disk(self, tmp_path: Path) -> None: result = BenchmarkResult( video_name="trial.mp4", repeats=5, warmup_frames=3, warmup_pass=_make_metrics(total_seconds=20.0), measured_passes=[_make_metrics(total_seconds=1.5) for _ in range(4)], aggregate=_make_aggregate(), ) path = tmp_path / "bench.json" save_benchmark_result(path, result) assert path.exists() loaded = load_benchmark_result(path) assert loaded == result def test_rejects_repeats_below_one(self) -> None: with pytest.raises(ValidationError): BenchmarkResult( video_name="x.mp4", repeats=0, warmup_frames=0, warmup_pass=_make_metrics(), measured_passes=[], aggregate=_make_aggregate(), ) def test_cpu_comparison_nested(self, tmp_path: Path) -> None: comparison = CpuComparisonResult( primary_aggregate=_make_aggregate(), cpu_aggregate=_make_aggregate(), speedup=2.5, max_poses3d_divergence_mm=0.002, frame_count_compared=30, ) result = BenchmarkResult( video_name="trial.mp4", repeats=5, warmup_frames=3, warmup_pass=_make_metrics(), measured_passes=[_make_metrics() for _ in range(4)], aggregate=_make_aggregate(), cpu_comparison=comparison, ) path = tmp_path / "bench_with_cmp.json" save_benchmark_result(path, result) loaded = load_benchmark_result(path) assert loaded.cpu_comparison is not None assert loaded.cpu_comparison.speedup == pytest.approx(2.5) assert loaded.cpu_comparison.max_poses3d_divergence_mm == pytest.approx(0.002) # --------------------------------------------------------------------------- # Segmentation schema # --------------------------------------------------------------------------- class TestSegmentModel: def test_valid(self) -> None: seg = Segment(start=0, end=30, peak=15) assert seg.start == 0 assert seg.peak == 15 assert seg.end == 30 def test_rejects_end_not_greater_than_start(self) -> None: with pytest.raises(ValidationError, match="end"): Segment(start=10, end=10, peak=10) def test_peak_must_be_inside_window(self) -> None: with pytest.raises(ValidationError, match="peak"): Segment(start=0, end=30, peak=30) # peak == end is out of range def test_is_frozen(self) -> None: seg = Segment(start=0, end=10, peak=5) with pytest.raises(ValidationError): seg.start = 1 class TestExtractorSpecs: def test_joint_pair_distance_rejects_identical_joints(self) -> None: with pytest.raises(ValidationError, match="distinct"): JointPairDistanceExtractor(joints=(7, 7)) def test_joint_pair_distance_rejects_negative(self) -> None: with pytest.raises(ValidationError, match="non-negative"): JointPairDistanceExtractor(joints=(-1, 5)) def test_joint_angle_rejects_negative(self) -> None: with pytest.raises(ValidationError, match="non-negative"): JointAngleExtractor(triplet=(0, -1, 2)) def test_joint_axis_rejects_bad_axis(self) -> None: with pytest.raises(ValidationError): JointAxisExtractor(joint=0, axis=3) def test_discriminator_dispatches_to_correct_variant(self) -> None: # Round-trip each extractor variant through a SegmentationConfig # dict to confirm the discriminator selects the right class. for payload, cls in [ ({"kind": "joint_axis", "joint": 1, "axis": 2}, JointAxisExtractor), ({"kind": "joint_pair_distance", "joints": [1, 2]}, JointPairDistanceExtractor), ({"kind": "joint_speed", "joint": 3}, JointSpeedExtractor), ({"kind": "joint_angle", "triplet": [1, 2, 3]}, JointAngleExtractor), ]: cfg = SegmentationConfig.model_validate({"extractor": payload}) assert isinstance(cfg.extractor, cls) class TestSegmentationPersistence: def test_roundtrip_through_video_predictions( self, tmp_path: Path, video_predictions_payload: dict, ) -> None: cfg = SegmentationConfig( extractor=JointAxisExtractor(joint=15, axis=1), min_prominence=20.0, pad_seconds=0.1, ) segmentation = Segmentation( config=cfg, segments=[Segment(start=0, end=1, peak=0), Segment(start=1, end=2, peak=1)], ) video_predictions_payload["segmentations"] = { "cup_lift": segmentation.model_dump(mode="json") } vp = VideoPredictions.model_validate(video_predictions_payload) path = tmp_path / "video.json" save_video_predictions(path, vp) loaded = load_video_predictions(path) assert "cup_lift" in loaded.segmentations cup = loaded.segmentations["cup_lift"] assert cup.config.extractor.kind == "joint_axis" assert len(cup.segments) == 2 assert cup.config.method == "valley_to_valley_v1" def test_default_empty_segmentations_on_new_instance( self, video_predictions_payload: dict ) -> None: vp = VideoPredictions.model_validate(video_predictions_payload) assert vp.segmentations == {} def test_legacy_file_without_segmentations_loads_clean( self, tmp_path: Path, video_predictions_payload: dict, ) -> None: # Older predictions files never wrote the segmentations field; # make sure they still validate and deserialize as if they had # an empty mapping. assert "segmentations" not in video_predictions_payload path = tmp_path / "legacy.json" path.write_text(json.dumps(video_predictions_payload)) vp = load_video_predictions(path) assert vp.segmentations == {} # --------------------------------------------------------------------------- # Status file # --------------------------------------------------------------------------- class TestStatusFile: def test_load_missing_returns_empty(self, tmp_path: Path) -> None: status = load_status(tmp_path / "nope.json") assert status.is_empty() def test_load_corrupt_json_returns_empty(self, tmp_path: Path) -> None: path = tmp_path / "bad.json" path.write_text("{ not valid json") status = load_status(path) assert status.is_empty() def test_load_non_mapping_returns_empty(self, tmp_path: Path) -> None: path = tmp_path / "list.json" path.write_text(json.dumps([1, 2, 3])) status = load_status(path) assert status.is_empty() def test_save_and_load_completed_entry(self, tmp_path: Path) -> None: started = datetime(2026, 4, 13, 12, 0, 0, tzinfo=UTC) completed = datetime(2026, 4, 13, 12, 5, 0, tzinfo=UTC) status = StatusFile.model_validate( { "job_001": { "status": "completed", "started_at": started.isoformat(), "completed_at": completed.isoformat(), "results_path": "/tmp/results.json", "error": None, } } ) path = tmp_path / "status.json" save_status(path, status) loaded = load_status(path) entry = loaded.root["job_001"] assert entry.status == JobStatus.COMPLETED assert entry.started_at == started assert entry.completed_at == completed assert entry.error is None def test_save_is_atomic(self, tmp_path: Path) -> None: """``save_status`` leaves no orphan ``.tmp`` file on success.""" started = datetime(2026, 4, 13, tzinfo=UTC) status = StatusFile.model_validate( { "job_001": { "status": "processing", "started_at": started.isoformat(), } } ) path = tmp_path / "status.json" save_status(path, status) assert path.exists() tmps = list(tmp_path.glob("status.json.tmp")) assert tmps == [] def test_failed_entry_carries_error_message(self, tmp_path: Path) -> None: started = datetime(2026, 4, 13, tzinfo=UTC) status = StatusFile.model_validate( { "job_001": { "status": "failed", "started_at": started.isoformat(), "error": "ffmpeg decode failed: codec not supported", } } ) path = tmp_path / "status.json" save_status(path, status) loaded = load_status(path) entry = loaded.root["job_001"] assert entry.status == JobStatus.FAILED assert entry.error is not None assert "ffmpeg" in entry.error def test_rejects_unknown_status(self, tmp_path: Path) -> None: with pytest.raises(ValidationError): StatusFile.model_validate( { "job_001": { "status": "some-unknown-state", "started_at": datetime(2026, 4, 13, tzinfo=UTC).isoformat(), } } ) class TestJobStatusEntryProgressFields: def test_progress_fields_default_to_none(self) -> None: entry = JobStatusEntry( status=JobStatus.PROCESSING, started_at=datetime(2026, 4, 13, tzinfo=UTC), ) assert entry.current_video is None assert entry.frames_processed is None assert entry.frames_total is None assert entry.videos_completed is None assert entry.videos_total is None assert entry.percent_complete is None assert entry.last_update is None def test_legacy_status_file_without_progress_loads(self, tmp_path: Path) -> None: """Files written before the progress fields existed must still load.""" started = datetime(2026, 4, 13, 12, 0, 0, tzinfo=UTC) path = tmp_path / "legacy.json" path.write_text( json.dumps( { "job_001": { "status": "completed", "started_at": started.isoformat(), "completed_at": started.isoformat(), "results_path": "/tmp/results.json", "error": None, } } ) ) loaded = load_status(path) entry = loaded.root["job_001"] assert entry.percent_complete is None def test_progress_roundtrips_through_json(self, tmp_path: Path) -> None: now = datetime(2026, 4, 13, 12, 0, 0, tzinfo=UTC) status = StatusFile( root={ "job_001": JobStatusEntry( status=JobStatus.PROCESSING, started_at=now, current_video="trial_01.mp4", frames_processed=450, frames_total=1200, videos_completed=0, videos_total=3, percent_complete=12.5, last_update=now, ), } ) path = tmp_path / "status.json" save_status(path, status) loaded = load_status(path) entry = loaded.root["job_001"] assert entry.current_video == "trial_01.mp4" assert entry.frames_processed == 450 assert entry.percent_complete == 12.5 def test_percent_complete_rejects_out_of_range(self) -> None: with pytest.raises(ValidationError): JobStatusEntry( status=JobStatus.PROCESSING, started_at=datetime(2026, 4, 13, tzinfo=UTC), percent_complete=150.0, ) with pytest.raises(ValidationError): JobStatusEntry( status=JobStatus.PROCESSING, started_at=datetime(2026, 4, 13, tzinfo=UTC), percent_complete=-5.0, ) def test_frames_processed_rejects_negative(self) -> None: with pytest.raises(ValidationError): JobStatusEntry( status=JobStatus.PROCESSING, started_at=datetime(2026, 4, 13, tzinfo=UTC), frames_processed=-1, )