def load_metrabs_model(cache_dir: Path | None = None) -> Any:  # noqa: ARG001
    """Load the MeTRAbs model, downloading and caching on first use.

    Parameters
    ----------
    cache_dir
        Directory in which the downloaded model would be cached (typically
        ``Settings.model_cache_dir``). ``None`` lets the implementation
        choose a default location.

    Returns
    -------
    object
        An opaque model handle exposing ``detect_poses`` plus the
        per-skeleton joint metadata consumed by
        :class:`neuropose.estimator.Estimator`.

    Raises
    ------
    NotImplementedError
        Unconditionally at this commit; commit 11 supplies the real loader
        once the upstream MeTRAbs URL is pinned.
    """
    # Stub only: the message tells callers the two available workarounds.
    message = (
        "load_metrabs_model is stubbed pending commit 11. "
        "Inject a model directly via Estimator(model=...) for now, "
        "or wait for the upstream MeTRAbs URL and TensorFlow version pin."
    )
    raise NotImplementedError(message)
+""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import cv2 + +from neuropose._model import load_metrabs_model +from neuropose.io import FramePrediction, VideoMetadata, VideoPredictions + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + + +class EstimatorError(Exception): + """Base class for errors raised by :class:`Estimator`.""" + + +class ModelNotLoadedError(EstimatorError): + """Raised when an inference method is called before the model is loaded.""" + + +class VideoDecodeError(EstimatorError): + """Raised when a video file cannot be opened or decoded.""" + + +# --------------------------------------------------------------------------- +# Result container +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ProcessVideoResult: + """Result of :meth:`Estimator.process_video`. + + Attributes + ---------- + predictions + The validated :class:`VideoPredictions` object, containing both the + per-frame predictions and the ``VideoMetadata`` envelope. + """ + + predictions: VideoPredictions + + @property + def frame_count(self) -> int: + """Convenience accessor for ``predictions.metadata.frame_count``.""" + return self.predictions.metadata.frame_count + + +# --------------------------------------------------------------------------- +# Estimator +# --------------------------------------------------------------------------- + + +ProgressCallback = Callable[[int, int], None] +"""Type alias for progress callbacks: ``(frames_processed, frame_count_hint)``.""" + + +class Estimator: + """3D pose estimator built on MeTRAbs. + + Parameters + ---------- + device + TensorFlow device string (e.g. 
``"/CPU:0"`` or ``"/GPU:0"``). Passed + through to the model at inference time. + skeleton + Skeleton identifier understood by MeTRAbs. Defaults to + ``"berkeley_mhad_43"``, the 43-joint skeleton used by the previous + prototype. + default_fov_degrees + Horizontal field of view assumed when a video does not supply + intrinsics. Overridable per call via + :meth:`process_video`'s ``fov_degrees`` argument. + model + Optional pre-loaded MeTRAbs model. If supplied, the estimator uses + it directly and :meth:`load_model` need not be called. This path is + used by tests to inject a fake model, and by callers that want to + share a single model across many :class:`Estimator` instances. + """ + + def __init__( + self, + *, + device: str = "/CPU:0", + skeleton: str = "berkeley_mhad_43", + default_fov_degrees: float = 55.0, + model: Any | None = None, + ) -> None: + self.device = device + self.skeleton = skeleton + self.default_fov_degrees = default_fov_degrees + self._model: Any | None = model + + # -- model lifecycle ---------------------------------------------------- + + @property + def model(self) -> Any: + """Return the loaded model, or raise :class:`ModelNotLoadedError`.""" + if self._model is None: + raise ModelNotLoadedError( + "Estimator model has not been loaded. " + "Call Estimator.load_model() or pass model=... to the constructor." + ) + return self._model + + @property + def is_model_loaded(self) -> bool: + """Return ``True`` if a model has been supplied or loaded.""" + return self._model is not None + + def load_model(self, cache_dir: Path | None = None) -> None: + """Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`. + + Parameters + ---------- + cache_dir + Directory where the downloaded model should be cached. Typically + ``Settings.model_cache_dir``. + + Notes + ----- + This is idempotent: calling it again after a successful load is a + no-op. Callers that want to reload the model should construct a new + :class:`Estimator` instance. 
+ """ + if self._model is not None: + logger.debug("Model already loaded; skipping reload.") + return + logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir) + self._model = load_metrabs_model(cache_dir=cache_dir) + logger.info("MeTRAbs model loaded.") + + # -- inference ---------------------------------------------------------- + + def process_video( + self, + video_path: Path, + *, + fov_degrees: float | None = None, + progress: ProgressCallback | None = None, + ) -> ProcessVideoResult: + """Run pose estimation on every frame of a video. + + Parameters + ---------- + video_path + Path to the input video. Must exist and be openable by OpenCV. + fov_degrees + Per-call override for the horizontal field of view. If ``None``, + the estimator's ``default_fov_degrees`` is used. + progress + Optional callback invoked after each processed frame as + ``progress(processed, total_hint)``. ``total_hint`` is the + approximate frame count reported by OpenCV's + ``CAP_PROP_FRAME_COUNT``, which is unreliable for variable-rate + videos; the authoritative count is available after the call on + ``result.frame_count``. + + Returns + ------- + ProcessVideoResult + A typed result containing the validated :class:`VideoPredictions`. + + Raises + ------ + FileNotFoundError + If ``video_path`` does not exist. + VideoDecodeError + If OpenCV cannot open the video. + ModelNotLoadedError + If the model has not been loaded or injected. + """ + if not video_path.exists(): + raise FileNotFoundError(f"video file not found: {video_path}") + + # Access the model eagerly so we fail fast before opening the video. 
+ model = self.model + fov = fov_degrees if fov_degrees is not None else self.default_fov_degrees + + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + raise VideoDecodeError(f"OpenCV could not open video: {video_path}") + + try: + fps = float(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + total_hint = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) + + logger.info( + "Processing video %s (%dx%d @ %.2f fps, ~%d frames)", + video_path, + width, + height, + fps, + total_hint, + ) + + frames: dict[str, FramePrediction] = {} + frame_index = 0 + while True: + ok, bgr_frame = cap.read() + if not ok: + break + # MeTRAbs was trained on RGB images; OpenCV gives us BGR. + rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) + prediction = self._infer_frame(model, rgb_frame, fov) + frames[f"frame_{frame_index:06d}"] = prediction + frame_index += 1 + if progress is not None: + progress(frame_index, total_hint) + + metadata = VideoMetadata( + frame_count=frame_index, + fps=fps if fps > 0 else 0.0, + width=width, + height=height, + ) + finally: + cap.release() + + if frame_index == 0: + logger.warning("Video %s contained no decodable frames.", video_path) + + predictions = VideoPredictions(metadata=metadata, frames=frames) + return ProcessVideoResult(predictions=predictions) + + # -- internals ---------------------------------------------------------- + + def _infer_frame( + self, + model: Any, + rgb_frame: Any, + fov_degrees: float, + ) -> FramePrediction: + """Run a single frame through the model and validate the output.""" + pred = model.detect_poses( + rgb_frame, + default_fov_degrees=fov_degrees, + skeleton=self.skeleton, + ) + return FramePrediction( + boxes=_to_nested_list(pred["boxes"]), + poses3d=_to_nested_list(pred["poses3d"]), + poses2d=_to_nested_list(pred["poses2d"]), + ) + + +# --------------------------------------------------------------------------- +# Helpers +# 
def _to_nested_list(value: Any) -> Any:
    """Coerce a TF tensor, numpy array, or iterable into nested Python lists.

    The real MeTRAbs model yields ``tf.Tensor`` objects whose ``.numpy()``
    produces a ``numpy.ndarray``; test fakes hand back plain numpy arrays.
    Funnelling both through this helper keeps the callers agnostic to which
    one is in play.
    """
    materialized = value.numpy() if hasattr(value, "numpy") else value
    if hasattr(materialized, "tolist"):
        return materialized.tolist()
    return list(materialized)
""" from __future__ import annotations @@ -55,26 +56,51 @@ class FramePrediction(BaseModel): ) -class VideoPredictions(RootModel[dict[str, FramePrediction]]): - """Per-frame predictions for a single video, keyed by frame filename. +class VideoMetadata(BaseModel): + """Metadata about the source video for a set of predictions. - Frame names are expected to follow the ``frame_.png`` convention - written by the estimator, but no constraint is enforced at the schema - level so downstream consumers can key by any naming scheme. + Essential for reproducibility: the frame count lets downstream analysis + verify completeness, and the fps lets it convert frame indices to real + time without needing access to the original video file. + + Intentionally does NOT include the source file path or filename, which + may encode subject-identifying information. Callers that need provenance + should store it out-of-band in accordance with the data-handling policy. """ - def frames(self) -> list[str]: - """Return the frame names in insertion order.""" - return list(self.root.keys()) + model_config = ConfigDict(extra="forbid", frozen=True) + + frame_count: int = Field(ge=0, description="Number of frames actually processed.") + fps: float = Field(ge=0.0, description="Source video frame rate (frames per second).") + width: int = Field(ge=0, description="Source video frame width in pixels.") + height: int = Field(ge=0, description="Source video frame height in pixels.") + + +class VideoPredictions(BaseModel): + """Per-frame predictions for a single video, paired with video metadata. + + The ``frames`` mapping is keyed by frame identifier (``frame_`` by + convention, zero-padded to 6 digits). The identifier is a stable string, + not a filesystem path — no PNG file is implied. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + metadata: VideoMetadata + frames: dict[str, FramePrediction] + + def frame_names(self) -> list[str]: + """Return frame identifiers in insertion order.""" + return list(self.frames.keys()) def __len__(self) -> int: - return len(self.root) + return len(self.frames) def __iter__(self) -> Iterator[str]: # type: ignore[override] - return iter(self.root) + return iter(self.frames) def __getitem__(self, key: str) -> FramePrediction: - return self.root[key] + return self.frames[key] class JobResults(RootModel[dict[str, VideoPredictions]]): @@ -85,7 +111,7 @@ class JobResults(RootModel[dict[str, VideoPredictions]]): """ def videos(self) -> list[str]: - """Return the video names in insertion order.""" + """Return video names in insertion order.""" return list(self.root.keys()) def __len__(self) -> int: @@ -143,7 +169,7 @@ def load_video_predictions(path: Path) -> VideoPredictions: def save_video_predictions(path: Path, predictions: VideoPredictions) -> None: - """Serialize per-video predictions to a JSON file.""" + """Serialize per-video predictions to a JSON file atomically.""" path.parent.mkdir(parents=True, exist_ok=True) _write_json_atomic(path, predictions.model_dump(mode="json")) diff --git a/src/neuropose/visualize.py b/src/neuropose/visualize.py new file mode 100644 index 0000000..b18be59 --- /dev/null +++ b/src/neuropose/visualize.py @@ -0,0 +1,210 @@ +"""Matplotlib-based visualization of NeuroPose predictions. + +Separate module so that importing :mod:`neuropose.estimator` does not pull +in matplotlib or incur its global backend side effect. Callers that want +visualization import this module explicitly. + +The old prototype's ``_visualize`` helper had an in-place numpy-view +aliasing bug where ``poses3d[..., 1], poses3d[..., 2] = poses3d[..., 2], +-poses3d[..., 1]`` mutated the caller's prediction array. 
VALID_VIEWS = frozenset({"normal", "depth"})


def visualize_predictions(
    video_path: Path,
    predictions: VideoPredictions,
    output_dir: Path,
    *,
    view: str = "normal",
    joint_edges: Sequence[tuple[int, int]] | None = None,
    frame_indices: Sequence[int] | None = None,
) -> list[Path]:
    """Render per-frame 2D + 3D visualizations as PNG files.

    Parameters
    ----------
    video_path
        Source video, used to recover frame pixels for the 2D overlay.
        Must be the same video the predictions were computed from.
    predictions
        Predictions to overlay. The insertion order of
        ``predictions.frames`` determines which frames are drawn unless
        ``frame_indices`` narrows the selection.
    output_dir
        Destination directory for the rendered PNGs; created if absent.
    view
        ``"normal"`` (default) or ``"depth"``; controls the 3D subplot's
        axis limits. Anything else raises ``ValueError``.
    joint_edges
        Optional ``(i, j)`` joint-index pairs drawn as skeleton edges.
        ``None`` draws scatter points only. For ``berkeley_mhad_43`` the
        edges come from ``model.per_skeleton_joint_edges[skeleton]``.
    frame_indices
        Optional subset of frame indices to render; ``None`` renders every
        frame. Out-of-range indices are silently skipped.

    Returns
    -------
    list[Path]
        Written PNG paths, in render order.

    Raises
    ------
    FileNotFoundError
        If ``video_path`` does not exist.
    VideoDecodeError
        If OpenCV cannot open the video.
    ValueError
        If ``view`` is not one of the supported values.
    """
    if view not in VALID_VIEWS:
        raise ValueError(f"view must be one of {sorted(VALID_VIEWS)}; got {view!r}")
    if not video_path.exists():
        raise FileNotFoundError(f"video file not found: {video_path}")

    # Deferred imports: matplotlib's backend selection is a global side
    # effect and pyplot is slow to import, so pay that cost only when a
    # caller actually visualizes.
    import matplotlib

    matplotlib.use("Agg", force=False)
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    output_dir.mkdir(parents=True, exist_ok=True)
    frame_names = predictions.frame_names()
    selected = _select_indices(frame_indices, len(frame_names))

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise VideoDecodeError(f"OpenCV could not open video: {video_path}")

    written: list[Path] = []
    try:
        position = 0  # number of frames consumed from the capture so far
        exhausted = False
        for target in selected:
            frame_bgr = None
            # Advance sequentially until frame #target has been read.
            while position <= target:
                grabbed, frame_bgr = cap.read()
                if not grabbed:
                    exhausted = True
                    break
                position += 1
            if exhausted:
                break
            name = frame_names[target]
            destination = output_dir / f"{name}.png"
            _render_frame(
                cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB),
                predictions[name],
                destination,
                view=view,
                joint_edges=joint_edges,
                plt=plt,
                Rectangle=Rectangle,
            )
            written.append(destination)
    finally:
        cap.release()

    logger.info("Wrote %d visualization frame(s) to %s", len(written), output_dir)
    return written


def _select_indices(frame_indices: Sequence[int] | None, total: int) -> list[int]:
    """Normalise ``frame_indices`` to a sorted, deduplicated, in-range list."""
    if frame_indices is None:
        return list(range(total))
    in_range = {index for index in frame_indices if 0 <= index < total}
    return sorted(in_range)
joint_edges: Sequence[tuple[int, int]] | None, + plt: Any, + Rectangle: Any, +) -> None: + """Render one frame's 2D overlay + 3D scatter to ``out_path``.""" + # Explicit copies. The previous prototype mutated the caller's data via + # numpy-view tuple-assignment; we take a fresh numpy array per person + # so the caller's VideoPredictions is never touched. + boxes = np.asarray(frame_prediction.boxes, dtype=float) + poses3d = np.asarray(frame_prediction.poses3d, dtype=float).copy() + poses2d = np.asarray(frame_prediction.poses2d, dtype=float) + + # Rotate for visualization: swap Y and Z so the ground plane is horizontal. + # Do this on the copy so the original predictions object is untouched. + if poses3d.size > 0: + original_y = poses3d[..., 1].copy() + poses3d[..., 1] = poses3d[..., 2] + poses3d[..., 2] = -original_y + + fig = plt.figure(figsize=(10, 5.2)) + + image_ax = fig.add_subplot(1, 2, 1) + image_ax.imshow(rgb_frame) + for box in boxes: + if len(box) < 4: + continue + x, y, w, h = box[:4] + image_ax.add_patch(Rectangle((x, y), w, h, fill=False)) + + pose_ax = fig.add_subplot(1, 2, 2, projection="3d") + pose_ax.view_init(5, -85) + if view == "depth": + pose_ax.set_xlim3d(200, 17500) + pose_ax.set_zlim3d(-1500, 1500) + pose_ax.set_ylim3d(0, 3000) + else: + pose_ax.set_xlim3d(-1500, 1500) + pose_ax.set_zlim3d(-1500, 1500) + pose_ax.set_ylim3d(0, 3000) + pose_ax.set_box_aspect((1, 1, 1)) + + for pose3d, pose2d in zip(poses3d, poses2d, strict=False): + if joint_edges is not None: + for i_start, i_end in joint_edges: + if 0 <= i_start < len(pose2d) and 0 <= i_end < len(pose2d): + image_ax.plot( + [pose2d[i_start][0], pose2d[i_end][0]], + [pose2d[i_start][1], pose2d[i_end][1]], + marker="o", + markersize=2, + ) + if 0 <= i_start < len(pose3d) and 0 <= i_end < len(pose3d): + pose_ax.plot( + [pose3d[i_start][0], pose3d[i_end][0]], + [pose3d[i_start][1], pose3d[i_end][1]], + [pose3d[i_start][2], pose3d[i_end][2]], + marker="o", + markersize=2, + ) + 
image_ax.scatter(pose2d[:, 0], pose2d[:, 1], s=2) + pose_ax.scatter(pose3d[:, 0], pose3d[:, 1], pose3d[:, 2], s=2) + + fig.tight_layout() + fig.savefig(out_path) + plt.close(fig) diff --git a/tests/conftest.py b/tests/conftest.py index 6023d16..d9897e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,18 @@ from __future__ import annotations import os from collections.abc import Iterator from pathlib import Path +from typing import Any +import cv2 +import numpy as np import pytest +# --------------------------------------------------------------------------- +# Environment isolation +# --------------------------------------------------------------------------- + + @pytest.fixture(autouse=True) def _isolate_environment( monkeypatch: pytest.MonkeyPatch, @@ -36,3 +44,65 @@ def _isolate_environment( def xdg_home() -> Path: """Return the isolated ``$XDG_DATA_HOME`` set up by ``_isolate_environment``.""" return Path(os.environ["XDG_DATA_HOME"]) + + +# --------------------------------------------------------------------------- +# Synthetic video + fake MeTRAbs model +# --------------------------------------------------------------------------- + + +@pytest.fixture +def synthetic_video(tmp_path: Path) -> Path: + """Generate a tiny synthetic video at test time. + + The fixture writes a 5-frame, 32×32 MJPG-encoded ``.avi`` file. MJPG is + chosen over ``mp4v`` because it ships with ``opencv-python-headless`` on + every platform we target, whereas ``mp4v`` occasionally requires an + ffmpeg binary that may not be present on minimal CI runners. + """ + path = tmp_path / "synthetic.avi" + fourcc = cv2.VideoWriter_fourcc(*"MJPG") + writer = cv2.VideoWriter(str(path), fourcc, 30.0, (32, 32)) + assert writer.isOpened(), "cv2.VideoWriter failed to open; MJPG codec missing?" + for i in range(5): + # Distinct brightness per frame so a downstream check could verify + # the test is actually reading frame-by-frame. 
class _FakeMetrabsModel:
    """Deterministic stand-in for the MeTRAbs model in unit tests.

    Every call reports one person with two joints, so tests can assert on
    shapes without importing TensorFlow or downloading the real model.
    Plain numpy arrays are returned deliberately: they exercise the
    estimator's ``_to_nested_list`` branch that does not call ``.numpy()``.
    """

    def __init__(self) -> None:
        # Number of detect_poses invocations, inspected by tests.
        self.call_count = 0

    def detect_poses(
        self,
        image: Any,
        *,
        default_fov_degrees: float,
        skeleton: str,
    ) -> dict[str, np.ndarray]:
        del image, default_fov_degrees, skeleton  # accepted for signature parity only
        self.call_count += 1
        boxes = np.array([[0.0, 0.0, 32.0, 32.0, 0.95]])
        poses3d = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
        poses2d = np.array([[[10.0, 20.0], [30.0, 40.0]]])
        return {"boxes": boxes, "poses3d": poses3d, "poses2d": poses2d}
+""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from neuropose.estimator import ( + Estimator, + ModelNotLoadedError, + ProcessVideoResult, + VideoDecodeError, +) +from neuropose.io import FramePrediction, VideoPredictions + + +class TestConstruction: + def test_no_model_by_default(self) -> None: + estimator = Estimator() + assert not estimator.is_model_loaded + + def test_injected_model_is_loaded(self, fake_metrabs_model) -> None: + estimator = Estimator(model=fake_metrabs_model) + assert estimator.is_model_loaded + + def test_defaults(self) -> None: + estimator = Estimator() + assert estimator.device == "/CPU:0" + assert estimator.skeleton == "berkeley_mhad_43" + assert estimator.default_fov_degrees == pytest.approx(55.0) + + def test_overrides(self, fake_metrabs_model) -> None: + estimator = Estimator( + device="/GPU:0", + skeleton="smpl_24", + default_fov_degrees=40.0, + model=fake_metrabs_model, + ) + assert estimator.device == "/GPU:0" + assert estimator.skeleton == "smpl_24" + assert estimator.default_fov_degrees == pytest.approx(40.0) + + +class TestModelGuard: + def test_model_property_raises_when_missing(self) -> None: + estimator = Estimator() + with pytest.raises(ModelNotLoadedError): + _ = estimator.model + + def test_process_video_raises_when_missing(self, synthetic_video: Path) -> None: + estimator = Estimator() + with pytest.raises(ModelNotLoadedError): + estimator.process_video(synthetic_video) + + def test_load_model_stub_raises_not_implemented(self) -> None: + estimator = Estimator() + with pytest.raises(NotImplementedError, match="commit 11"): + estimator.load_model() + + def test_load_model_is_idempotent_when_already_loaded( + self, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + # Should not raise, and should not clobber the injected model. 
+ estimator.load_model() + assert estimator.model is fake_metrabs_model + + +class TestProcessVideo: + def test_returns_typed_result( + self, + synthetic_video: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + result = estimator.process_video(synthetic_video) + assert isinstance(result, ProcessVideoResult) + assert isinstance(result.predictions, VideoPredictions) + + def test_frame_count_matches_source( + self, + synthetic_video: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + result = estimator.process_video(synthetic_video) + assert result.frame_count == 5 + assert fake_metrabs_model.call_count == 5 + + def test_frame_naming_is_zero_padded( + self, + synthetic_video: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + result = estimator.process_video(synthetic_video) + names = result.predictions.frame_names() + assert names == [ + "frame_000000", + "frame_000001", + "frame_000002", + "frame_000003", + "frame_000004", + ] + + def test_metadata_populated( + self, + synthetic_video: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + result = estimator.process_video(synthetic_video) + metadata = result.predictions.metadata + assert metadata.frame_count == 5 + assert metadata.width == 32 + assert metadata.height == 32 + assert metadata.fps > 0.0 + + def test_each_frame_validates_as_frame_prediction( + self, + synthetic_video: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + result = estimator.process_video(synthetic_video) + for name in result.predictions.frame_names(): + frame = result.predictions[name] + assert isinstance(frame, FramePrediction) + assert len(frame.boxes) == 1 + assert len(frame.poses3d) == 1 + assert len(frame.poses3d[0]) == 2 # Two joints per the fake model. 
+ + def test_progress_callback_invoked_per_frame( + self, + synthetic_video: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + calls: list[tuple[int, int]] = [] + estimator.process_video( + synthetic_video, + progress=lambda processed, total: calls.append((processed, total)), + ) + assert len(calls) == 5 + # Processed counts should be strictly increasing. + assert [c[0] for c in calls] == [1, 2, 3, 4, 5] + + def test_fov_override_is_passed_through(self, synthetic_video: Path) -> None: + fov_seen: list[float] = [] + + class RecordingModel: + def detect_poses( + self, + image, + *, + default_fov_degrees: float, + skeleton: str, + ): + del image, skeleton + fov_seen.append(default_fov_degrees) + import numpy as np + + return { + "boxes": np.array([[0.0, 0.0, 32.0, 32.0, 0.9]]), + "poses3d": np.array([[[0.0, 0.0, 0.0]]]), + "poses2d": np.array([[[0.0, 0.0]]]), + } + + estimator = Estimator(model=RecordingModel(), default_fov_degrees=55.0) + estimator.process_video(synthetic_video, fov_degrees=40.0) + assert all(fov == pytest.approx(40.0) for fov in fov_seen) + assert len(fov_seen) == 5 + + +class TestErrors: + def test_missing_video( + self, + tmp_path: Path, + fake_metrabs_model, + ) -> None: + estimator = Estimator(model=fake_metrabs_model) + with pytest.raises(FileNotFoundError): + estimator.process_video(tmp_path / "does_not_exist.mp4") + + def test_unreadable_video_raises_decode_error( + self, + tmp_path: Path, + fake_metrabs_model, + ) -> None: + # A file that exists but is not a valid video. cv2.VideoCapture + # returns isOpened() == False for non-video content. 
class TestErrors:
    # Error paths: missing file vs. file that exists but is not a video.

    def test_missing_video(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        estimator = Estimator(model=fake_metrabs_model)
        with pytest.raises(FileNotFoundError):
            estimator.process_video(tmp_path / "does_not_exist.mp4")

    def test_unreadable_video_raises_decode_error(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        # A file that exists but is not a valid video. cv2.VideoCapture
        # returns isOpened() == False for non-video content.
        path = tmp_path / "not_a_video.avi"
        path.write_bytes(b"this is definitely not a video file")
        estimator = Estimator(model=fake_metrabs_model)
        with pytest.raises(VideoDecodeError):
            estimator.process_video(path)


# --- tests/unit/test_io.py additions -------------------------------------


@pytest.fixture
def video_metadata_payload() -> dict:
    # Matches the VideoMetadata schema introduced with the estimator.
    return {"frame_count": 2, "fps": 30.0, "width": 640, "height": 480}


@pytest.fixture
def video_predictions_payload(one_frame: dict, video_metadata_payload: dict) -> dict:
    # Two-frame payload in the new {"metadata": ..., "frames": ...} shape.
    return {
        "metadata": video_metadata_payload,
        "frames": {
            "frame_000000": one_frame,
            "frame_000001": one_frame,
        },
    }


class TestVideoMetadata:
    # Schema validation: value bounds, extra-field rejection, immutability.

    def test_valid(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=29.97, width=1920, height=1080)
        assert meta.frame_count == 10
        assert meta.fps == pytest.approx(29.97)

    def test_zero_frame_count_allowed(self) -> None:
        # Broken or empty videos still produce a valid metadata object so
        # the caller can see frame_count == 0 rather than receiving an
        # exception.
        VideoMetadata(frame_count=0, fps=0.0, width=0, height=0)

    def test_rejects_negative(self) -> None:
        with pytest.raises(ValidationError):
            VideoMetadata(frame_count=-1, fps=30.0, width=640, height=480)

    def test_rejects_extra_fields(self) -> None:
        # extra="forbid" guards against provenance leakage via stray keys.
        with pytest.raises(ValidationError):
            VideoMetadata(
                frame_count=1,
                fps=30.0,
                width=640,
                height=480,
                source_path="/leak/me",
            )

    def test_is_frozen(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=30.0, width=640, height=480)
        with pytest.raises(ValidationError):
            meta.fps = 60.0


class TestVideoPredictions:
    # Mapping behavior, metadata requirement, and atomic save/load.

    def test_from_dict(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert len(vp) == 2
        assert vp.frame_names() == ["frame_000000", "frame_000001"]
        assert vp["frame_000000"].boxes[0][4] == pytest.approx(0.95)
        assert vp.metadata.fps == pytest.approx(30.0)

    def test_iteration(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert list(vp) == ["frame_000000", "frame_000001"]

    def test_rejects_missing_metadata(self, video_predictions_payload: dict) -> None:
        del video_predictions_payload["metadata"]
        with pytest.raises(ValidationError):
            VideoPredictions.model_validate(video_predictions_payload)

    def test_save_and_load_roundtrip(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        path = tmp_path / "preds" / "video.json"
        save_video_predictions(path, vp)
        assert path.exists()
        loaded = load_video_predictions(path)
        assert loaded.frame_names() == vp.frame_names()
        assert loaded.metadata == vp.metadata
        assert loaded["frame_000000"].poses3d == vp["frame_000000"].poses3d

    def test_save_is_atomic(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        # The writer must not leave its sibling temp file behind after the
        # atomic rename completes.
        vp = VideoPredictions.model_validate(video_predictions_payload)
        path = tmp_path / "video.json"
        save_video_predictions(path, vp)
        assert path.exists()
        tmps = list(tmp_path.glob("video.json.tmp"))
        assert tmps == []
+ +Smoke tests only: we verify that the visualize function runs end-to-end +against a synthetic video, writes PNG files, honours ``frame_indices``, and +does NOT mutate the caller's :class:`VideoPredictions`. Actual pixel-level +correctness is out of scope for unit tests. +""" + +from __future__ import annotations + +import copy +from pathlib import Path + +import pytest + +from neuropose.estimator import Estimator, VideoDecodeError +from neuropose.io import VideoPredictions +from neuropose.visualize import visualize_predictions + + +@pytest.fixture +def predictions_for_synthetic( + synthetic_video: Path, + fake_metrabs_model, +) -> VideoPredictions: + """Run the fake estimator over the synthetic video and return predictions.""" + estimator = Estimator(model=fake_metrabs_model) + return estimator.process_video(synthetic_video).predictions + + +class TestVisualizePredictions: + def test_writes_one_png_per_frame( + self, + tmp_path: Path, + synthetic_video: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + output_dir = tmp_path / "viz" + written = visualize_predictions( + synthetic_video, predictions_for_synthetic, output_dir + ) + assert len(written) == 5 + for path in written: + assert path.exists() + assert path.suffix == ".png" + assert path.stat().st_size > 0 + + def test_frame_indices_limits_output( + self, + tmp_path: Path, + synthetic_video: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + output_dir = tmp_path / "viz" + written = visualize_predictions( + synthetic_video, + predictions_for_synthetic, + output_dir, + frame_indices=[0, 2, 4], + ) + assert len(written) == 3 + names = sorted(p.stem for p in written) + assert names == ["frame_000000", "frame_000002", "frame_000004"] + + def test_out_of_range_indices_silently_skipped( + self, + tmp_path: Path, + synthetic_video: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + output_dir = tmp_path / "viz" + written = visualize_predictions( + synthetic_video, 
+ predictions_for_synthetic, + output_dir, + frame_indices=[0, 999, -1], + ) + assert len(written) == 1 + assert written[0].stem == "frame_000000" + + def test_does_not_mutate_input_predictions( + self, + tmp_path: Path, + synthetic_video: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + """The aliasing bug from the previous prototype mutated poses3d in place. + + This test guards against its regression by deep-copying poses3d + before visualization and comparing against the live value after. + Without the deepcopy, any in-place mutation would be invisible + because ``before`` and ``after`` would refer to the same list. + """ + frame_name = predictions_for_synthetic.frame_names()[0] + before = copy.deepcopy(predictions_for_synthetic[frame_name].poses3d) + visualize_predictions( + synthetic_video, + predictions_for_synthetic, + tmp_path / "viz", + frame_indices=[0], + ) + after = predictions_for_synthetic[frame_name].poses3d + assert before == after + + def test_rejects_invalid_view( + self, + tmp_path: Path, + synthetic_video: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + with pytest.raises(ValueError, match="view must be one of"): + visualize_predictions( + synthetic_video, + predictions_for_synthetic, + tmp_path / "viz", + view="orthographic", + ) + + def test_depth_view_runs( + self, + tmp_path: Path, + synthetic_video: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + written = visualize_predictions( + synthetic_video, + predictions_for_synthetic, + tmp_path / "viz", + view="depth", + frame_indices=[0], + ) + assert len(written) == 1 + + def test_missing_video_raises( + self, + tmp_path: Path, + predictions_for_synthetic: VideoPredictions, + ) -> None: + with pytest.raises(FileNotFoundError): + visualize_predictions( + tmp_path / "nope.avi", + predictions_for_synthetic, + tmp_path / "viz", + ) + + def test_unreadable_video_raises( + self, + tmp_path: Path, + predictions_for_synthetic: VideoPredictions, 
+ ) -> None: + fake_video = tmp_path / "fake.avi" + fake_video.write_bytes(b"not a video") + with pytest.raises(VideoDecodeError): + visualize_predictions( + fake_video, + predictions_for_synthetic, + tmp_path / "viz", + )