estimator

This commit is contained in:
Levi Neuwirth 2026-04-13 12:40:50 -04:00
parent 3d2b2fc68d
commit 9bbbbd0d52
8 changed files with 1141 additions and 37 deletions

51
src/neuropose/_model.py Normal file
View File

@ -0,0 +1,51 @@
"""MeTRAbs model loading — stub pending commit 11.
This module exists so that :mod:`neuropose.estimator` can import a single,
well-typed loader function even before the upstream MeTRAbs URL is pinned
and the TensorFlow version is settled.
Commit 11 will replace :func:`load_metrabs_model` with an implementation
that:
1. Pins the canonical MeTRAbs tfhub / Kaggle Models handle (replacing the
``bit.ly/metrabs_1`` shortener from the previous prototype).
2. Caches the downloaded model under ``Settings.model_cache_dir`` so the
first run downloads it and subsequent runs are offline.
3. Returns a typed handle that the estimator can invoke without hitting the
network on each instantiation.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
def load_metrabs_model(cache_dir: Path | None = None) -> Any:  # noqa: ARG001
    """Load the MeTRAbs model, downloading and caching on first use.

    Parameters
    ----------
    cache_dir
        Where the model should be cached, usually ``Settings.model_cache_dir``.
        ``None`` lets the (future) implementation pick a default location.

    Returns
    -------
    object
        An opaque model handle exposing ``detect_poses`` plus the
        per-skeleton joint metadata attributes consumed by
        :class:`neuropose.estimator.Estimator`.

    Raises
    ------
    NotImplementedError
        Unconditionally at this commit; the real loader arrives with
        commit 11 once the upstream MeTRAbs URL is pinned.
    """
    message = (
        "load_metrabs_model is stubbed pending commit 11. "
        "Inject a model directly via Estimator(model=...) for now, "
        "or wait for the upstream MeTRAbs URL and TensorFlow version pin."
    )
    raise NotImplementedError(message)

308
src/neuropose/estimator.py Normal file
View File

@ -0,0 +1,308 @@
"""3D human pose estimator — MeTRAbs wrapper.
The :class:`Estimator` class is the core of NeuroPose's inference path. It
takes a video file, runs the MeTRAbs 3D pose-estimation model on each frame,
and returns a validated :class:`~neuropose.io.VideoPredictions` object with
the per-frame predictions and video metadata.
Design
------
The estimator is a **library**, not a daemon: it knows nothing about job
directories, status files, or polling. Those concerns live in
:mod:`neuropose.interfacer`. An ``Estimator`` can be constructed directly
from a Python script and called on a single video, which is the path taken
by the documentation's quick-start.
Frames are streamed from the source video in memory — the previous
prototype wrote every frame to disk as a PNG and then re-read each one with
``tf.io.decode_png``. That round-trip is gone; frames are read once and
passed to the model directly.
Model injection
---------------
The MeTRAbs model is supplied either:
1. Through :meth:`Estimator.load_model`, which delegates to
:func:`neuropose._model.load_metrabs_model` (stubbed pending commit 11).
2. Directly via ``Estimator(model=...)``, which is the path used by the
test suite with a fake model to exercise the code without TensorFlow.
Either way, attempting to call :meth:`Estimator.process_video` before a
model is present raises :class:`ModelNotLoadedError`.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import cv2
from neuropose._model import load_metrabs_model
from neuropose.io import FramePrediction, VideoMetadata, VideoPredictions
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------
class EstimatorError(Exception):
    """Base class for errors raised by :class:`Estimator`.

    Catch this to handle every estimator-specific failure mode
    (:class:`ModelNotLoadedError`, :class:`VideoDecodeError`) in one place.
    """
class ModelNotLoadedError(EstimatorError):
    """Raised when an inference method is called before the model is loaded.

    Recover by calling :meth:`Estimator.load_model` or by constructing the
    estimator with ``model=...``.
    """
class VideoDecodeError(EstimatorError):
    """Raised when a video file cannot be opened or decoded.

    Distinct from ``FileNotFoundError``: the path exists, but OpenCV
    reports it cannot open a video stream from it.
    """
# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class ProcessVideoResult:
    """Outcome of :meth:`Estimator.process_video`.

    Attributes
    ----------
    predictions
        The validated :class:`VideoPredictions` object: the per-frame
        predictions together with their ``VideoMetadata`` envelope.
    """

    predictions: VideoPredictions

    @property
    def frame_count(self) -> int:
        """Shortcut for ``predictions.metadata.frame_count``."""
        meta = self.predictions.metadata
        return meta.frame_count
# ---------------------------------------------------------------------------
# Estimator
# ---------------------------------------------------------------------------
# Invoked by Estimator.process_video after each decoded frame. The second
# argument is OpenCV's CAP_PROP_FRAME_COUNT hint, which can be inaccurate
# for variable-rate videos (see Estimator.process_video's docstring).
ProgressCallback = Callable[[int, int], None]
"""Type alias for progress callbacks: ``(frames_processed, frame_count_hint)``."""
class Estimator:
    """3D pose estimator wrapping a MeTRAbs model.

    Parameters
    ----------
    device
        TensorFlow device string (e.g. ``"/CPU:0"`` or ``"/GPU:0"``),
        forwarded to the model at inference time.
    skeleton
        MeTRAbs skeleton identifier. Defaults to ``"berkeley_mhad_43"``,
        the 43-joint skeleton the previous prototype used.
    default_fov_degrees
        Horizontal field of view assumed when the video supplies no
        intrinsics; :meth:`process_video` can override it per call.
    model
        Optional pre-loaded MeTRAbs model. When given, :meth:`load_model`
        is unnecessary. The test suite uses this hook to inject a fake
        model, and callers can share one model across many estimators.
    """

    def __init__(
        self,
        *,
        device: str = "/CPU:0",
        skeleton: str = "berkeley_mhad_43",
        default_fov_degrees: float = 55.0,
        model: Any | None = None,
    ) -> None:
        self._model: Any | None = model
        self.device = device
        self.skeleton = skeleton
        self.default_fov_degrees = default_fov_degrees

    # -- model lifecycle ----------------------------------------------------

    @property
    def model(self) -> Any:
        """The loaded model; raises :class:`ModelNotLoadedError` if absent."""
        loaded = self._model
        if loaded is None:
            raise ModelNotLoadedError(
                "Estimator model has not been loaded. "
                "Call Estimator.load_model() or pass model=... to the constructor."
            )
        return loaded

    @property
    def is_model_loaded(self) -> bool:
        """``True`` once a model has been injected or loaded."""
        return self._model is not None

    def load_model(self, cache_dir: Path | None = None) -> None:
        """Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.

        Parameters
        ----------
        cache_dir
            Directory for the downloaded model, typically
            ``Settings.model_cache_dir``.

        Notes
        -----
        Idempotent: a second call after a successful load (or injection)
        does nothing. To force a reload, construct a new :class:`Estimator`.
        """
        if self.is_model_loaded:
            logger.debug("Model already loaded; skipping reload.")
            return
        logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
        self._model = load_metrabs_model(cache_dir=cache_dir)
        logger.info("MeTRAbs model loaded.")

    # -- inference ----------------------------------------------------------

    def process_video(
        self,
        video_path: Path,
        *,
        fov_degrees: float | None = None,
        progress: ProgressCallback | None = None,
    ) -> ProcessVideoResult:
        """Run pose estimation over every frame of ``video_path``.

        Parameters
        ----------
        video_path
            Input video; must exist and be decodable by OpenCV.
        fov_degrees
            Per-call horizontal field-of-view override. Falls back to the
            estimator's ``default_fov_degrees`` when ``None``.
        progress
            Optional ``progress(processed, total_hint)`` callback invoked
            after each frame. ``total_hint`` comes from OpenCV's
            ``CAP_PROP_FRAME_COUNT``, which is unreliable for variable-rate
            videos; ``result.frame_count`` is the authoritative count.

        Returns
        -------
        ProcessVideoResult
            Typed wrapper around the validated :class:`VideoPredictions`.

        Raises
        ------
        FileNotFoundError
            If ``video_path`` does not exist.
        VideoDecodeError
            If OpenCV cannot open the video.
        ModelNotLoadedError
            If no model has been loaded or injected.
        """
        if not video_path.exists():
            raise FileNotFoundError(f"video file not found: {video_path}")
        # Touch the model property first so we fail fast before opening
        # the video at all.
        model = self.model
        fov = self.default_fov_degrees if fov_degrees is None else fov_degrees
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            raise VideoDecodeError(f"OpenCV could not open video: {video_path}")
        try:
            fps = float(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_hint = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
            logger.info(
                "Processing video %s (%dx%d @ %.2f fps, ~%d frames)",
                video_path,
                width,
                height,
                fps,
                total_hint,
            )
            frames: dict[str, FramePrediction] = {}
            processed = 0
            ok, frame_bgr = cap.read()
            while ok:
                # MeTRAbs expects RGB input; OpenCV decodes to BGR.
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                frames[f"frame_{processed:06d}"] = self._infer_frame(
                    model, frame_rgb, fov
                )
                processed += 1
                if progress is not None:
                    progress(processed, total_hint)
                ok, frame_bgr = cap.read()
            metadata = VideoMetadata(
                frame_count=processed,
                fps=fps if fps > 0 else 0.0,
                width=width,
                height=height,
            )
        finally:
            cap.release()
        if processed == 0:
            logger.warning("Video %s contained no decodable frames.", video_path)
        predictions = VideoPredictions(metadata=metadata, frames=frames)
        return ProcessVideoResult(predictions=predictions)

    # -- internals ----------------------------------------------------------

    def _infer_frame(
        self,
        model: Any,
        rgb_frame: Any,
        fov_degrees: float,
    ) -> FramePrediction:
        """Run one frame through ``model`` and validate the raw output."""
        raw = model.detect_poses(
            rgb_frame,
            default_fov_degrees=fov_degrees,
            skeleton=self.skeleton,
        )
        return FramePrediction(
            boxes=_to_nested_list(raw["boxes"]),
            poses3d=_to_nested_list(raw["poses3d"]),
            poses2d=_to_nested_list(raw["poses2d"]),
        )
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _to_nested_list(value: Any) -> Any:
"""Normalise a TF tensor, numpy array, or nested list to Python lists.
The real MeTRAbs model returns ``tf.Tensor`` objects which expose
``.numpy()`` returning a ``numpy.ndarray``. Tests inject fake models
that return plain numpy arrays. Both paths flow through this helper so
the rest of the code is agnostic to which is in use.
"""
if hasattr(value, "numpy"):
value = value.numpy()
if hasattr(value, "tolist"):
return value.tolist()
return list(value)

View File

@ -1,14 +1,15 @@
"""I/O helpers and schema definitions for NeuroPose prediction data.
Defines pydantic models for per-frame predictions, per-video aggregated
predictions, job-level aggregated results, and the persistent status file.
All models are validated on load, so malformed files are caught at the
boundary rather than at some downstream call site.
Defines pydantic models for per-frame predictions, per-video predictions
(with metadata envelope), job-level aggregated results, and the persistent
status file. All models are validated on load, so malformed files are caught
at the boundary rather than at some downstream call site.
Atomicity: :func:`save_status` and :func:`save_job_results` write to a sibling
temp file and then atomically rename, so a crash mid-write will not leave a
partially-written file behind. This matches the crash-resilience guarantee
the interfacer daemon makes to callers.
Atomicity: :func:`save_status`, :func:`save_job_results`, and
:func:`save_video_predictions` write to a sibling temp file and then
atomically rename, so a crash mid-write will not leave a partially-written
file behind. This matches the crash-resilience guarantee the interfacer
daemon makes to callers.
"""
from __future__ import annotations
@ -55,26 +56,51 @@ class FramePrediction(BaseModel):
)
class VideoPredictions(RootModel[dict[str, FramePrediction]]):
"""Per-frame predictions for a single video, keyed by frame filename.
class VideoMetadata(BaseModel):
"""Metadata about the source video for a set of predictions.
Frame names are expected to follow the ``frame_<index>.png`` convention
written by the estimator, but no constraint is enforced at the schema
level so downstream consumers can key by any naming scheme.
Essential for reproducibility: the frame count lets downstream analysis
verify completeness, and the fps lets it convert frame indices to real
time without needing access to the original video file.
Intentionally does NOT include the source file path or filename, which
may encode subject-identifying information. Callers that need provenance
should store it out-of-band in accordance with the data-handling policy.
"""
def frames(self) -> list[str]:
"""Return the frame names in insertion order."""
return list(self.root.keys())
model_config = ConfigDict(extra="forbid", frozen=True)
frame_count: int = Field(ge=0, description="Number of frames actually processed.")
fps: float = Field(ge=0.0, description="Source video frame rate (frames per second).")
width: int = Field(ge=0, description="Source video frame width in pixels.")
height: int = Field(ge=0, description="Source video frame height in pixels.")
class VideoPredictions(BaseModel):
"""Per-frame predictions for a single video, paired with video metadata.
The ``frames`` mapping is keyed by frame identifier (``frame_<index>`` by
convention, zero-padded to 6 digits). The identifier is a stable string,
not a filesystem path — no PNG file is implied.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
metadata: VideoMetadata
frames: dict[str, FramePrediction]
def frame_names(self) -> list[str]:
"""Return frame identifiers in insertion order."""
return list(self.frames.keys())
def __len__(self) -> int:
return len(self.root)
return len(self.frames)
def __iter__(self) -> Iterator[str]: # type: ignore[override]
return iter(self.root)
return iter(self.frames)
def __getitem__(self, key: str) -> FramePrediction:
return self.root[key]
return self.frames[key]
class JobResults(RootModel[dict[str, VideoPredictions]]):
@ -85,7 +111,7 @@ class JobResults(RootModel[dict[str, VideoPredictions]]):
"""
def videos(self) -> list[str]:
"""Return the video names in insertion order."""
"""Return video names in insertion order."""
return list(self.root.keys())
def __len__(self) -> int:
@ -143,7 +169,7 @@ def load_video_predictions(path: Path) -> VideoPredictions:
def save_video_predictions(path: Path, predictions: VideoPredictions) -> None:
"""Serialize per-video predictions to a JSON file."""
"""Serialize per-video predictions to a JSON file atomically."""
path.parent.mkdir(parents=True, exist_ok=True)
_write_json_atomic(path, predictions.model_dump(mode="json"))

210
src/neuropose/visualize.py Normal file
View File

@ -0,0 +1,210 @@
"""Matplotlib-based visualization of NeuroPose predictions.
Separate module so that importing :mod:`neuropose.estimator` does not pull
in matplotlib or incur its global backend side effect. Callers that want
visualization import this module explicitly.
The old prototype's ``_visualize`` helper had an in-place numpy-view
aliasing bug where ``poses3d[..., 1], poses3d[..., 2] = poses3d[..., 2],
-poses3d[..., 1]`` mutated the caller's prediction array. This rewrite
makes an explicit copy before any axis reordering, so the input
:class:`VideoPredictions` is never touched.
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from neuropose.estimator import VideoDecodeError
from neuropose.io import VideoPredictions
logger = logging.getLogger(__name__)
VALID_VIEWS = frozenset({"normal", "depth"})
def visualize_predictions(
    video_path: Path,
    predictions: VideoPredictions,
    output_dir: Path,
    *,
    view: str = "normal",
    joint_edges: Sequence[tuple[int, int]] | None = None,
    frame_indices: Sequence[int] | None = None,
) -> list[Path]:
    """Render per-frame 2D + 3D visualizations as PNG files.

    Parameters
    ----------
    video_path
        Path to the source video, used to recover frame pixels for the 2D
        overlay. Must be the same video the predictions were computed from.
    predictions
        Predictions to overlay. The dict ordering of ``predictions.frames``
        determines which frames are drawn unless ``frame_indices`` is set.
    output_dir
        Directory to write the rendered PNGs into. Created if absent.
    view
        ``"normal"`` (default) or ``"depth"``. Controls the 3D subplot's
        axis limits. Anything else raises ``ValueError``.
    joint_edges
        Optional list of ``(i, j)`` index pairs specifying skeleton edges to
        draw as lines connecting joints. If ``None``, only scatter points
        are drawn. For ``berkeley_mhad_43`` the edges can be obtained from
        ``model.per_skeleton_joint_edges[skeleton]``.
    frame_indices
        Optional subset of frame indices to render. If ``None``, every
        frame in ``predictions`` is rendered. Out-of-range indices are
        silently skipped.

    Returns
    -------
    list[Path]
        Paths of the written PNG files, in render order.

    Raises
    ------
    FileNotFoundError
        If ``video_path`` does not exist.
    VideoDecodeError
        If OpenCV cannot open the video.
    ValueError
        If ``view`` is not one of the supported values.
    """
    if view not in VALID_VIEWS:
        raise ValueError(f"view must be one of {sorted(VALID_VIEWS)}; got {view!r}")
    if not video_path.exists():
        raise FileNotFoundError(f"video file not found: {video_path}")
    # Late imports: matplotlib carries a global backend side effect, which
    # we want to avoid at module load time. pyplot is also slow to import.
    import matplotlib

    # force=False: if a backend was already selected elsewhere, respect it.
    matplotlib.use("Agg", force=False)
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    output_dir.mkdir(parents=True, exist_ok=True)
    frame_names = predictions.frame_names()
    selected = _select_indices(frame_indices, len(frame_names))
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise VideoDecodeError(f"OpenCV could not open video: {video_path}")
    written: list[Path] = []
    try:
        # Single sequential pass over the video: `target` is the next
        # selected frame index still to render (None once exhausted).
        # `selected` is sorted, so decoding never needs to seek backwards.
        next_index_to_render = iter(selected)
        target = next(next_index_to_render, None)
        frame_index = 0
        while target is not None:
            ok, bgr_frame = cap.read()
            if not ok:
                # Source ended before every requested index was reached;
                # any remaining targets are dropped.
                break
            if frame_index == target:
                rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
                frame_name = frame_names[frame_index]
                out_path = output_dir / f"{frame_name}.png"
                _render_frame(
                    rgb_frame,
                    predictions[frame_name],
                    out_path,
                    view=view,
                    joint_edges=joint_edges,
                    plt=plt,
                    Rectangle=Rectangle,
                )
                written.append(out_path)
                target = next(next_index_to_render, None)
            frame_index += 1
    finally:
        cap.release()
    logger.info("Wrote %d visualization frame(s) to %s", len(written), output_dir)
    return written
def _select_indices(frame_indices: Sequence[int] | None, total: int) -> list[int]:
"""Normalise the ``frame_indices`` argument to a sorted, in-range list."""
if frame_indices is None:
return list(range(total))
return sorted({i for i in frame_indices if 0 <= i < total})
def _render_frame(
    rgb_frame: Any,
    frame_prediction: Any,
    out_path: Path,
    *,
    view: str,
    joint_edges: Sequence[tuple[int, int]] | None,
    plt: Any,
    Rectangle: Any,
) -> None:
    """Render one frame's 2D overlay + 3D scatter to ``out_path``.

    ``plt`` and ``Rectangle`` are passed in by the caller so this helper
    does not import matplotlib itself (see the late-import note in
    ``visualize_predictions``).
    """
    # Explicit copies. The previous prototype mutated the caller's data via
    # numpy-view tuple-assignment; we take a fresh numpy array per person
    # so the caller's VideoPredictions is never touched.
    boxes = np.asarray(frame_prediction.boxes, dtype=float)
    poses3d = np.asarray(frame_prediction.poses3d, dtype=float).copy()
    poses2d = np.asarray(frame_prediction.poses2d, dtype=float)
    # Rotate for visualization: swap Y and Z so the ground plane is horizontal.
    # Do this on the copy so the original predictions object is untouched.
    if poses3d.size > 0:
        original_y = poses3d[..., 1].copy()
        poses3d[..., 1] = poses3d[..., 2]
        poses3d[..., 2] = -original_y
    fig = plt.figure(figsize=(10, 5.2))
    image_ax = fig.add_subplot(1, 2, 1)
    image_ax.imshow(rgb_frame)
    for box in boxes:
        # Boxes are drawn as (x, y) origin plus width/height; rows with
        # fewer than four values cannot form a rectangle and are skipped.
        if len(box) < 4:
            continue
        x, y, w, h = box[:4]
        image_ax.add_patch(Rectangle((x, y), w, h, fill=False))
    pose_ax = fig.add_subplot(1, 2, 2, projection="3d")
    pose_ax.view_init(5, -85)
    # Fixed axis limits keep frames visually comparable across the video.
    # NOTE(review): the magnitudes suggest millimetres (MeTRAbs metric
    # output) — confirm against the model's documented units.
    if view == "depth":
        pose_ax.set_xlim3d(200, 17500)
        pose_ax.set_zlim3d(-1500, 1500)
        pose_ax.set_ylim3d(0, 3000)
    else:
        pose_ax.set_xlim3d(-1500, 1500)
        pose_ax.set_zlim3d(-1500, 1500)
        pose_ax.set_ylim3d(0, 3000)
    pose_ax.set_box_aspect((1, 1, 1))
    # One iteration per detected person; strict=False tolerates a 2D/3D
    # person-count mismatch by drawing only the shorter of the two.
    for pose3d, pose2d in zip(poses3d, poses2d, strict=False):
        if joint_edges is not None:
            for i_start, i_end in joint_edges:
                # Guard both endpoints so a foreign edge list cannot IndexError.
                if 0 <= i_start < len(pose2d) and 0 <= i_end < len(pose2d):
                    image_ax.plot(
                        [pose2d[i_start][0], pose2d[i_end][0]],
                        [pose2d[i_start][1], pose2d[i_end][1]],
                        marker="o",
                        markersize=2,
                    )
                if 0 <= i_start < len(pose3d) and 0 <= i_end < len(pose3d):
                    pose_ax.plot(
                        [pose3d[i_start][0], pose3d[i_end][0]],
                        [pose3d[i_start][1], pose3d[i_end][1]],
                        [pose3d[i_start][2], pose3d[i_end][2]],
                        marker="o",
                        markersize=2,
                    )
        image_ax.scatter(pose2d[:, 0], pose2d[:, 1], s=2)
        pose_ax.scatter(pose3d[:, 0], pose3d[:, 1], pose3d[:, 2], s=2)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)

View File

@ -5,10 +5,18 @@ from __future__ import annotations
import os
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Environment isolation
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _isolate_environment(
monkeypatch: pytest.MonkeyPatch,
@ -36,3 +44,65 @@ def _isolate_environment(
def xdg_home() -> Path:
"""Return the isolated ``$XDG_DATA_HOME`` set up by ``_isolate_environment``."""
return Path(os.environ["XDG_DATA_HOME"])
# ---------------------------------------------------------------------------
# Synthetic video + fake MeTRAbs model
# ---------------------------------------------------------------------------
@pytest.fixture
def synthetic_video(tmp_path: Path) -> Path:
    """Write a tiny synthetic video for the duration of one test.

    Produces a 5-frame, 32x32 MJPG-encoded ``.avi``. MJPG is used instead
    of ``mp4v`` because it ships with ``opencv-python-headless`` on every
    platform we target, while ``mp4v`` occasionally needs an ffmpeg binary
    that minimal CI runners lack.
    """
    path = tmp_path / "synthetic.avi"
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"MJPG"), 30.0, (32, 32)
    )
    assert writer.isOpened(), "cv2.VideoWriter failed to open; MJPG codec missing?"
    # One brightness level per frame so a downstream check could verify the
    # test is actually reading frame-by-frame.
    for brightness in (0, 40, 80, 120, 160):
        writer.write(np.full((32, 32, 3), brightness, dtype=np.uint8))
    writer.release()
    assert path.exists() and path.stat().st_size > 0, "Synthetic video is empty."
    return path
class _FakeMetrabsModel:
"""Minimal stand-in for the MeTRAbs model used in unit tests.
Returns deterministic pose data (one person, two joints) per call so
tests can assert on shapes without importing TensorFlow or downloading
the real model. The returned arrays are plain numpy so the estimator's
``_to_nested_list`` helper exercises its non-``numpy()`` branch.
"""
def __init__(self) -> None:
self.call_count = 0
def detect_poses(
self,
image: Any,
*,
default_fov_degrees: float,
skeleton: str,
) -> dict[str, np.ndarray]:
del image, default_fov_degrees, skeleton # signature-compatible with MeTRAbs
self.call_count += 1
return {
"boxes": np.array([[0.0, 0.0, 32.0, 32.0, 0.95]]),
"poses3d": np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]),
"poses2d": np.array([[[10.0, 20.0], [30.0, 40.0]]]),
}
@pytest.fixture
def fake_metrabs_model() -> _FakeMetrabsModel:
    """Return a fresh fake MeTRAbs model instance for a single test.

    A new instance per test keeps ``call_count`` isolated between tests.
    """
    return _FakeMetrabsModel()

View File

@ -0,0 +1,205 @@
"""Tests for :class:`neuropose.estimator.Estimator`.
These tests exercise the non-model code paths (video decoding, frame loop,
metadata extraction, result construction, progress reporting, and error
handling) using an injected fake MeTRAbs model. The TensorFlow / real-model
integration smoke test lives in ``tests/integration/`` and lands with
commit 11.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from neuropose.estimator import (
Estimator,
ModelNotLoadedError,
ProcessVideoResult,
VideoDecodeError,
)
from neuropose.io import FramePrediction, VideoPredictions
class TestConstruction:
    """Constructor defaults, overrides, and model injection."""

    def test_no_model_by_default(self) -> None:
        assert Estimator().is_model_loaded is False

    def test_injected_model_is_loaded(self, fake_metrabs_model) -> None:
        assert Estimator(model=fake_metrabs_model).is_model_loaded

    def test_defaults(self) -> None:
        est = Estimator()
        assert est.device == "/CPU:0"
        assert est.skeleton == "berkeley_mhad_43"
        assert est.default_fov_degrees == pytest.approx(55.0)

    def test_overrides(self, fake_metrabs_model) -> None:
        est = Estimator(
            model=fake_metrabs_model,
            device="/GPU:0",
            skeleton="smpl_24",
            default_fov_degrees=40.0,
        )
        assert est.device == "/GPU:0"
        assert est.skeleton == "smpl_24"
        assert est.default_fov_degrees == pytest.approx(40.0)
class TestModelGuard:
    """Every inference entry point must refuse to run without a model."""

    def test_model_property_raises_when_missing(self) -> None:
        bare = Estimator()
        with pytest.raises(ModelNotLoadedError):
            _ = bare.model

    def test_process_video_raises_when_missing(self, synthetic_video: Path) -> None:
        with pytest.raises(ModelNotLoadedError):
            Estimator().process_video(synthetic_video)

    def test_load_model_stub_raises_not_implemented(self) -> None:
        with pytest.raises(NotImplementedError, match="commit 11"):
            Estimator().load_model()

    def test_load_model_is_idempotent_when_already_loaded(
        self,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        # Must neither raise nor clobber the injected model.
        est.load_model()
        assert est.model is fake_metrabs_model
class TestProcessVideo:
    """Happy-path behaviour of process_video with an injected fake model."""

    def test_returns_typed_result(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        assert isinstance(outcome, ProcessVideoResult)
        assert isinstance(outcome.predictions, VideoPredictions)

    def test_frame_count_matches_source(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        # One model invocation per decoded frame.
        assert outcome.frame_count == 5
        assert fake_metrabs_model.call_count == 5

    def test_frame_naming_is_zero_padded(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        assert outcome.predictions.frame_names() == [
            f"frame_{index:06d}" for index in range(5)
        ]

    def test_metadata_populated(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        meta = outcome.predictions.metadata
        assert meta.frame_count == 5
        assert (meta.width, meta.height) == (32, 32)
        assert meta.fps > 0.0

    def test_each_frame_validates_as_frame_prediction(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        outcome = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
        for frame_name in outcome.predictions.frame_names():
            pred = outcome.predictions[frame_name]
            assert isinstance(pred, FramePrediction)
            assert len(pred.boxes) == 1
            assert len(pred.poses3d) == 1
            # The fake model emits exactly two joints per person.
            assert len(pred.poses3d[0]) == 2

    def test_progress_callback_invoked_per_frame(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        calls: list[tuple[int, int]] = []

        def record(processed: int, total: int) -> None:
            calls.append((processed, total))

        Estimator(model=fake_metrabs_model).process_video(
            synthetic_video, progress=record
        )
        assert len(calls) == 5
        # Processed counts should be strictly increasing.
        assert [processed for processed, _ in calls] == [1, 2, 3, 4, 5]

    def test_fov_override_is_passed_through(self, synthetic_video: Path) -> None:
        fov_seen: list[float] = []

        class RecordingModel:
            def detect_poses(
                self,
                image,
                *,
                default_fov_degrees: float,
                skeleton: str,
            ):
                del image, skeleton
                fov_seen.append(default_fov_degrees)
                import numpy as np

                return {
                    "boxes": np.array([[0.0, 0.0, 32.0, 32.0, 0.9]]),
                    "poses3d": np.array([[[0.0, 0.0, 0.0]]]),
                    "poses2d": np.array([[[0.0, 0.0]]]),
                }

        est = Estimator(model=RecordingModel(), default_fov_degrees=55.0)
        est.process_video(synthetic_video, fov_degrees=40.0)
        assert len(fov_seen) == 5
        assert all(fov == pytest.approx(40.0) for fov in fov_seen)
class TestErrors:
    """Failure modes for missing and undecodable input files."""

    def test_missing_video(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        missing = tmp_path / "does_not_exist.mp4"
        with pytest.raises(FileNotFoundError):
            Estimator(model=fake_metrabs_model).process_video(missing)

    def test_unreadable_video_raises_decode_error(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        # Exists on disk but holds no decodable video stream, so
        # cv2.VideoCapture reports isOpened() == False.
        bogus = tmp_path / "not_a_video.avi"
        bogus.write_bytes(b"this is definitely not a video file")
        with pytest.raises(VideoDecodeError):
            Estimator(model=fake_metrabs_model).process_video(bogus)

View File

@ -14,6 +14,7 @@ from neuropose.io import (
JobResults,
JobStatus,
StatusFile,
VideoMetadata,
VideoPredictions,
load_job_results,
load_status,
@ -40,10 +41,18 @@ def one_frame() -> dict:
@pytest.fixture
def video_payload(one_frame: dict) -> dict:
def video_metadata_payload() -> dict:
return {"frame_count": 2, "fps": 30.0, "width": 640, "height": 480}
@pytest.fixture
def video_predictions_payload(one_frame: dict, video_metadata_payload: dict) -> dict:
return {
"frame_0000.png": one_frame,
"frame_0001.png": one_frame,
"metadata": video_metadata_payload,
"frames": {
"frame_000000": one_frame,
"frame_000001": one_frame,
},
}
@ -70,30 +79,90 @@ class TestFramePrediction:
frame.boxes = []
# ---------------------------------------------------------------------------
# VideoMetadata
# ---------------------------------------------------------------------------
class TestVideoMetadata:
    """Validation, immutability, and field constraints of VideoMetadata."""

    def test_valid(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=29.97, width=1920, height=1080)
        assert meta.frame_count == 10
        assert meta.fps == pytest.approx(29.97)

    def test_zero_frame_count_allowed(self) -> None:
        # An empty or broken video still yields a valid metadata object, so
        # callers observe frame_count == 0 rather than an exception.
        VideoMetadata(frame_count=0, fps=0.0, width=0, height=0)

    def test_rejects_negative(self) -> None:
        with pytest.raises(ValidationError):
            VideoMetadata(frame_count=-1, fps=30.0, width=640, height=480)

    def test_rejects_extra_fields(self) -> None:
        payload = {
            "frame_count": 1,
            "fps": 30.0,
            "width": 640,
            "height": 480,
            "source_path": "/leak/me",
        }
        with pytest.raises(ValidationError):
            VideoMetadata(**payload)

    def test_is_frozen(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=30.0, width=640, height=480)
        with pytest.raises(ValidationError):
            meta.fps = 60.0
# ---------------------------------------------------------------------------
# VideoPredictions
# ---------------------------------------------------------------------------
class TestVideoPredictions:
    """Access, validation, and persistence behaviour of ``VideoPredictions``."""

    def test_from_dict(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert len(vp) == 2
        assert vp.frame_names() == ["frame_000000", "frame_000001"]
        assert vp["frame_000000"].boxes[0][4] == pytest.approx(0.95)
        assert vp.metadata.fps == pytest.approx(30.0)

    def test_iteration(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert list(vp) == ["frame_000000", "frame_000001"]

    def test_rejects_missing_metadata(self, video_predictions_payload: dict) -> None:
        del video_predictions_payload["metadata"]
        with pytest.raises(ValidationError):
            VideoPredictions.model_validate(video_predictions_payload)

    def test_save_and_load_roundtrip(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        # Parent directory does not exist yet; save must create it.
        path = tmp_path / "preds" / "video.json"
        save_video_predictions(path, vp)
        assert path.exists()
        loaded = load_video_predictions(path)
        assert loaded.frame_names() == vp.frame_names()
        assert loaded.metadata == vp.metadata
        assert loaded["frame_000000"].poses3d == vp["frame_000000"].poses3d

    def test_save_is_atomic(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        path = tmp_path / "video.json"
        save_video_predictions(path, vp)
        assert path.exists()
        # The write-to-temp-then-rename scheme must not leave the temp
        # file behind after a successful save.
        tmps = list(tmp_path.glob("video.json.tmp"))
        assert tmps == []
# ---------------------------------------------------------------------------
@ -102,9 +171,16 @@ class TestVideoPredictions:
class TestJobResults:
def test_save_and_load_roundtrip(self, tmp_path: Path, video_payload: dict) -> None:
def test_save_and_load_roundtrip(
self,
tmp_path: Path,
video_predictions_payload: dict,
) -> None:
jr = JobResults.model_validate(
{"video_a.mp4": video_payload, "video_b.mp4": video_payload}
{
"video_a.mp4": video_predictions_payload,
"video_b.mp4": video_predictions_payload,
}
)
path = tmp_path / "results.json"
save_job_results(path, jr)

View File

@ -0,0 +1,158 @@
"""Tests for :mod:`neuropose.visualize`.
Smoke tests only: we verify that the visualize function runs end-to-end
against a synthetic video, writes PNG files, honours ``frame_indices``, and
does NOT mutate the caller's :class:`VideoPredictions`. Actual pixel-level
correctness is out of scope for unit tests.
"""
from __future__ import annotations
import copy
from pathlib import Path
import pytest
from neuropose.estimator import Estimator, VideoDecodeError
from neuropose.io import VideoPredictions
from neuropose.visualize import visualize_predictions
@pytest.fixture
def predictions_for_synthetic(
    synthetic_video: Path,
    fake_metrabs_model,
) -> VideoPredictions:
    """Predictions from running the fake-model estimator on the synthetic video."""
    result = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
    return result.predictions
class TestVisualizePredictions:
    """End-to-end smoke tests for :func:`visualize_predictions`."""

    def test_writes_one_png_per_frame(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        out_dir = tmp_path / "viz"
        paths = visualize_predictions(
            synthetic_video, predictions_for_synthetic, out_dir
        )
        # Synthetic fixture video has 5 frames; expect one non-empty PNG each.
        assert len(paths) == 5
        for png in paths:
            assert png.exists()
            assert png.suffix == ".png"
            assert png.stat().st_size > 0

    def test_frame_indices_limits_output(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        out_dir = tmp_path / "viz"
        paths = visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            out_dir,
            frame_indices=[0, 2, 4],
        )
        assert len(paths) == 3
        stems = sorted(p.stem for p in paths)
        assert stems == ["frame_000000", "frame_000002", "frame_000004"]

    def test_out_of_range_indices_silently_skipped(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        out_dir = tmp_path / "viz"
        # 999 and -1 are out of range for a 5-frame video; only index 0 remains.
        paths = visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            out_dir,
            frame_indices=[0, 999, -1],
        )
        assert len(paths) == 1
        assert paths[0].stem == "frame_000000"

    def test_does_not_mutate_input_predictions(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        """The aliasing bug from the previous prototype mutated poses3d in place.

        Deep-copying poses3d first is essential: without it, ``snapshot``
        and the live value would alias the same list, and any in-place
        mutation by the visualizer would be invisible to the comparison.
        """
        name = predictions_for_synthetic.frame_names()[0]
        snapshot = copy.deepcopy(predictions_for_synthetic[name].poses3d)
        visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            tmp_path / "viz",
            frame_indices=[0],
        )
        assert predictions_for_synthetic[name].poses3d == snapshot

    def test_rejects_invalid_view(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        with pytest.raises(ValueError, match="view must be one of"):
            visualize_predictions(
                synthetic_video,
                predictions_for_synthetic,
                tmp_path / "viz",
                view="orthographic",
            )

    def test_depth_view_runs(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        paths = visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            tmp_path / "viz",
            view="depth",
            frame_indices=[0],
        )
        assert len(paths) == 1

    def test_missing_video_raises(
        self,
        tmp_path: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        with pytest.raises(FileNotFoundError):
            visualize_predictions(
                tmp_path / "nope.avi",
                predictions_for_synthetic,
                tmp_path / "viz",
            )

    def test_unreadable_video_raises(
        self,
        tmp_path: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        # A file that exists but is not decodable video content.
        fake_video = tmp_path / "fake.avi"
        fake_video.write_bytes(b"not a video")
        with pytest.raises(VideoDecodeError):
            visualize_predictions(
                fake_video,
                predictions_for_synthetic,
                tmp_path / "viz",
            )