neuropose/tests/unit/test_estimator.py

224 lines
7.5 KiB
Python

"""Tests for :class:`neuropose.estimator.Estimator`.
These tests exercise the non-model code paths (video decoding, frame loop,
metadata extraction, result construction, progress reporting, and error
handling) using an injected fake MeTRAbs model. The TensorFlow / real-model
integration smoke test lives in ``tests/integration/`` and lands with
commit 11.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from neuropose.estimator import (
Estimator,
ModelNotLoadedError,
ProcessVideoResult,
VideoDecodeError,
)
from neuropose.io import FramePrediction, VideoPredictions
class TestConstruction:
    """Constructor behaviour of ``Estimator``: model injection, defaults, overrides."""

    def test_no_model_by_default(self) -> None:
        """An estimator built with no arguments reports no loaded model."""
        assert not Estimator().is_model_loaded

    def test_injected_model_is_loaded(self, fake_metrabs_model) -> None:
        """A model passed via ``model=`` counts as already loaded."""
        assert Estimator(model=fake_metrabs_model).is_model_loaded

    def test_defaults(self) -> None:
        """Device, skeleton and field of view fall back to the documented defaults."""
        subject = Estimator()

        assert subject.device == "/CPU:0"
        assert subject.skeleton == "berkeley_mhad_43"
        assert subject.default_fov_degrees == pytest.approx(55.0)

    def test_overrides(self, fake_metrabs_model) -> None:
        """Every constructor keyword is reflected on the matching attribute."""
        overrides = {
            "device": "/GPU:0",
            "skeleton": "smpl_24",
            "default_fov_degrees": 40.0,
            "model": fake_metrabs_model,
        }
        subject = Estimator(**overrides)

        assert subject.device == "/GPU:0"
        assert subject.skeleton == "smpl_24"
        assert subject.default_fov_degrees == pytest.approx(40.0)
class TestModelGuard:
    """Guards around the estimator's model: missing-model errors and loading."""

    def test_model_property_raises_when_missing(self) -> None:
        """Reading ``model`` on an estimator without one raises ``ModelNotLoadedError``."""
        estimator = Estimator()
        with pytest.raises(ModelNotLoadedError):
            _ = estimator.model

    def test_process_video_raises_when_missing(self, synthetic_video: Path) -> None:
        """``process_video`` refuses to run when no model has been loaded."""
        estimator = Estimator()
        with pytest.raises(ModelNotLoadedError):
            estimator.process_video(synthetic_video)

    def test_load_model_delegates_to_loader(
        self,
        monkeypatch: pytest.MonkeyPatch,
        tmp_path: Path,
    ) -> None:
        """``Estimator.load_model`` should delegate to ``load_metrabs_model``.

        We verify the delegation without actually invoking TensorFlow or the
        network: the loader is monkeypatched to return a sentinel, and we
        assert it ends up as the estimator's model.
        """
        sentinel = object()
        called_with: list[Path | None] = []
        # Use pytest's per-test directory instead of a hard-coded /tmp path so
        # the test is portable; the stub never touches the filesystem anyway.
        expected_cache_dir = tmp_path / "fake-cache"

        def fake_loader(cache_dir: Path | None = None) -> object:
            called_with.append(cache_dir)
            return sentinel

        monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)

        estimator = Estimator()
        estimator.load_model(cache_dir=expected_cache_dir)

        assert estimator.model is sentinel
        assert called_with == [expected_cache_dir]

    def test_load_model_is_idempotent_when_already_loaded(
        self,
        monkeypatch: pytest.MonkeyPatch,
        fake_metrabs_model,
    ) -> None:
        """``load_model`` on an estimator that already has a model is a no-op.

        The real loader is replaced with a recording stub so that a regression
        in the early-return path fails here instead of reaching TensorFlow or
        the network from a unit test.
        """
        loader_calls: list[Path | None] = []

        def recording_loader(cache_dir: Path | None = None) -> object:
            loader_calls.append(cache_dir)
            return object()

        monkeypatch.setattr(
            "neuropose.estimator.load_metrabs_model",
            recording_loader,
        )

        estimator = Estimator(model=fake_metrabs_model)
        # Should not raise, and should not clobber the injected model.
        estimator.load_model()

        assert estimator.model is fake_metrabs_model
        # The loader must not be consulted when a model is already present.
        assert loader_calls == []
class TestProcessVideo:
    """Happy-path behaviour of ``Estimator.process_video`` with a fake model.

    The assertions below expect the ``synthetic_video`` fixture to decode to
    five frames of 32x32 pixels.
    """

    def test_returns_typed_result(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        """The call returns a ``ProcessVideoResult`` wrapping ``VideoPredictions``."""
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)

        assert isinstance(outcome, ProcessVideoResult)
        assert isinstance(outcome.predictions, VideoPredictions)

    def test_frame_count_matches_source(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        """Five frames are reported and the model is invoked five times."""
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)

        assert outcome.frame_count == 5
        assert fake_metrabs_model.call_count == 5

    def test_frame_naming_is_zero_padded(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        """Frame names are ``frame_`` plus a six-digit, zero-padded index."""
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        expected_names = [
            "frame_000000",
            "frame_000001",
            "frame_000002",
            "frame_000003",
            "frame_000004",
        ]

        assert outcome.predictions.frame_names() == expected_names

    def test_metadata_populated(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        """Frame count, dimensions and a positive fps land in the metadata."""
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        meta = outcome.predictions.metadata

        assert meta.frame_count == 5
        assert meta.width == 32
        assert meta.height == 32
        assert meta.fps > 0.0

    def test_each_frame_validates_as_frame_prediction(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        """Every stored frame is a ``FramePrediction`` with one box and one pose."""
        predictions = Estimator(model=fake_metrabs_model).process_video(
            synthetic_video
        ).predictions

        for frame_name in predictions.frame_names():
            prediction = predictions[frame_name]
            assert isinstance(prediction, FramePrediction)
            assert len(prediction.boxes) == 1
            assert len(prediction.poses3d) == 1
            # Two joints per the fake model.
            assert len(prediction.poses3d[0]) == 2

    def test_progress_callback_invoked_per_frame(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        """The progress callback fires once per frame with a running count."""
        calls: list[tuple[int, int]] = []

        def record(processed, total):
            calls.append((processed, total))

        Estimator(model=fake_metrabs_model).process_video(
            synthetic_video,
            progress=record,
        )

        assert len(calls) == 5
        # Processed counts should be strictly increasing.
        processed_counts = [processed for processed, _ in calls]
        assert processed_counts == [1, 2, 3, 4, 5]

    def test_fov_override_is_passed_through(self, synthetic_video: Path) -> None:
        """A per-call ``fov_degrees`` wins over the constructor default."""
        import numpy as np

        fov_seen: list[float] = []

        class RecordingModel:
            """Minimal model stand-in that records the field of view it is given."""

            def detect_poses(
                self,
                image,
                *,
                default_fov_degrees: float,
                skeleton: str,
            ):
                del image, skeleton
                fov_seen.append(default_fov_degrees)
                boxes = np.array([[0.0, 0.0, 32.0, 32.0, 0.9]])
                poses3d = np.array([[[0.0, 0.0, 0.0]]])
                poses2d = np.array([[[0.0, 0.0]]])
                return {"boxes": boxes, "poses3d": poses3d, "poses2d": poses2d}

        subject = Estimator(model=RecordingModel(), default_fov_degrees=55.0)
        subject.process_video(synthetic_video, fov_degrees=40.0)

        assert all(seen == pytest.approx(40.0) for seen in fov_seen)
        assert len(fov_seen) == 5
class TestErrors:
    """Failure modes of ``Estimator.process_video`` for bad input paths."""

    def test_missing_video(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        """A path that does not exist raises ``FileNotFoundError``."""
        absent = tmp_path / "does_not_exist.mp4"
        subject = Estimator(model=fake_metrabs_model)

        with pytest.raises(FileNotFoundError):
            subject.process_video(absent)

    def test_unreadable_video_raises_decode_error(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        """An existing file with non-video content raises ``VideoDecodeError``."""
        # The file exists but holds plain text rather than video data.
        # cv2.VideoCapture returns isOpened() == False for non-video content.
        bogus = tmp_path / "not_a_video.avi"
        bogus.write_bytes(b"this is definitely not a video file")
        subject = Estimator(model=fake_metrabs_model)

        with pytest.raises(VideoDecodeError):
            subject.process_video(bogus)