estimator
This commit is contained in:
parent
3d2b2fc68d
commit
9bbbbd0d52
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""MeTRAbs model loading — stub pending commit 11.
|
||||||
|
|
||||||
|
This module exists so that :mod:`neuropose.estimator` can import a single,
|
||||||
|
well-typed loader function even before the upstream MeTRAbs URL is pinned
|
||||||
|
and the TensorFlow version is settled.
|
||||||
|
|
||||||
|
Commit 11 will replace :func:`load_metrabs_model` with an implementation
|
||||||
|
that:
|
||||||
|
|
||||||
|
1. Pins the canonical MeTRAbs tfhub / Kaggle Models handle (replacing the
|
||||||
|
``bit.ly/metrabs_1`` shortener from the previous prototype).
|
||||||
|
2. Caches the downloaded model under ``Settings.model_cache_dir`` so the
|
||||||
|
first run downloads it and subsequent runs are offline.
|
||||||
|
3. Returns a typed handle that the estimator can invoke without hitting the
|
||||||
|
network on each instantiation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def load_metrabs_model(cache_dir: Path | None = None) -> Any:  # noqa: ARG001
    """Load the MeTRAbs model, downloading and caching on first use.

    Placeholder: this raises unconditionally until commit 11 pins the
    upstream MeTRAbs URL and the TensorFlow version and supplies the real
    loader.

    Parameters
    ----------
    cache_dir
        Directory where the model should be cached — typically
        ``Settings.model_cache_dir``. ``None`` lets the implementation pick
        a default location.

    Returns
    -------
    object
        An opaque model handle exposing ``detect_poses`` and the
        per-skeleton joint metadata attributes consumed by
        :class:`neuropose.estimator.Estimator`.

    Raises
    ------
    NotImplementedError
        Always, at this commit. Commit 11 provides the real implementation
        once the upstream MeTRAbs URL is pinned.
    """
    message = (
        "load_metrabs_model is stubbed pending commit 11. "
        "Inject a model directly via Estimator(model=...) for now, "
        "or wait for the upstream MeTRAbs URL and TensorFlow version pin."
    )
    raise NotImplementedError(message)
|
||||||
|
|
@ -0,0 +1,308 @@
|
||||||
|
"""3D human pose estimator — MeTRAbs wrapper.
|
||||||
|
|
||||||
|
The :class:`Estimator` class is the core of NeuroPose's inference path. It
|
||||||
|
takes a video file, runs the MeTRAbs 3D pose-estimation model on each frame,
|
||||||
|
and returns a validated :class:`~neuropose.io.VideoPredictions` object with
|
||||||
|
the per-frame predictions and video metadata.
|
||||||
|
|
||||||
|
Design
|
||||||
|
------
|
||||||
|
The estimator is a **library**, not a daemon: it knows nothing about job
|
||||||
|
directories, status files, or polling. Those concerns live in
|
||||||
|
:mod:`neuropose.interfacer`. An ``Estimator`` can be constructed directly
|
||||||
|
from a Python script and called on a single video, which is the path taken
|
||||||
|
by the documentation's quick-start.
|
||||||
|
|
||||||
|
Frames are streamed from the source video in memory — the previous
|
||||||
|
prototype wrote every frame to disk as a PNG and then re-read each one with
|
||||||
|
``tf.io.decode_png``. That round-trip is gone; frames are read once and
|
||||||
|
passed to the model directly.
|
||||||
|
|
||||||
|
Model injection
|
||||||
|
---------------
|
||||||
|
The MeTRAbs model is supplied either:
|
||||||
|
|
||||||
|
1. Through :meth:`Estimator.load_model`, which delegates to
|
||||||
|
:func:`neuropose._model.load_metrabs_model` (stubbed pending commit 11).
|
||||||
|
2. Directly via ``Estimator(model=...)``, which is the path used by the
|
||||||
|
test suite with a fake model to exercise the code without TensorFlow.
|
||||||
|
|
||||||
|
Either way, attempting to call :meth:`Estimator.process_video` before a
|
||||||
|
model is present raises :class:`ModelNotLoadedError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from neuropose._model import load_metrabs_model
|
||||||
|
from neuropose.io import FramePrediction, VideoMetadata, VideoPredictions
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Exceptions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorError(Exception):
    """Common base for every exception raised by :class:`Estimator`."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotLoadedError(EstimatorError):
    """An inference method was invoked before any model was supplied or loaded."""
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDecodeError(EstimatorError):
    """A video file could not be opened or decoded."""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Result container
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ProcessVideoResult:
    """Immutable result returned by :meth:`Estimator.process_video`.

    Attributes
    ----------
    predictions
        Validated :class:`VideoPredictions` carrying both the per-frame
        predictions and the ``VideoMetadata`` envelope.
    """

    predictions: VideoPredictions

    @property
    def frame_count(self) -> int:
        """Shortcut for ``predictions.metadata.frame_count``."""
        return self.predictions.metadata.frame_count
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Estimator
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Called after every processed frame. The second argument is OpenCV's
# CAP_PROP_FRAME_COUNT estimate, which may be inaccurate for variable-rate
# videos; the authoritative count is only known once processing finishes.
ProgressCallback = Callable[[int, int], None]
"""Type alias for progress callbacks: ``(frames_processed, frame_count_hint)``."""
|
||||||
|
|
||||||
|
|
||||||
|
class Estimator:
    """3D pose estimator wrapping a MeTRAbs model.

    Parameters
    ----------
    device
        TensorFlow device string (e.g. ``"/CPU:0"`` or ``"/GPU:0"``),
        forwarded to the model at inference time.
    skeleton
        Skeleton identifier understood by MeTRAbs. The default,
        ``"berkeley_mhad_43"``, is the 43-joint skeleton used by the
        previous prototype.
    default_fov_degrees
        Horizontal field of view assumed when a video supplies no
        intrinsics; overridable per call through :meth:`process_video`'s
        ``fov_degrees`` argument.
    model
        Optional pre-loaded MeTRAbs model. When given, the estimator uses
        it directly and :meth:`load_model` need not be called. Tests use
        this hook to inject a fake model, and callers may use it to share
        one model across many :class:`Estimator` instances.
    """

    def __init__(
        self,
        *,
        device: str = "/CPU:0",
        skeleton: str = "berkeley_mhad_43",
        default_fov_degrees: float = 55.0,
        model: Any | None = None,
    ) -> None:
        self.device = device
        self.skeleton = skeleton
        self.default_fov_degrees = default_fov_degrees
        self._model: Any | None = model

    # -- model lifecycle ----------------------------------------------------

    @property
    def model(self) -> Any:
        """The loaded model; raises :class:`ModelNotLoadedError` when absent."""
        if self._model is not None:
            return self._model
        raise ModelNotLoadedError(
            "Estimator model has not been loaded. "
            "Call Estimator.load_model() or pass model=... to the constructor."
        )

    @property
    def is_model_loaded(self) -> bool:
        """Whether a model has been supplied or loaded."""
        return self._model is not None

    def load_model(self, cache_dir: Path | None = None) -> None:
        """Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.

        Parameters
        ----------
        cache_dir
            Directory where the downloaded model should be cached.
            Typically ``Settings.model_cache_dir``.

        Notes
        -----
        Idempotent: a second call after a successful load is a no-op.
        Construct a new :class:`Estimator` to force a reload.
        """
        if self._model is not None:
            logger.debug("Model already loaded; skipping reload.")
            return
        logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
        self._model = load_metrabs_model(cache_dir=cache_dir)
        logger.info("MeTRAbs model loaded.")

    # -- inference ----------------------------------------------------------

    def process_video(
        self,
        video_path: Path,
        *,
        fov_degrees: float | None = None,
        progress: ProgressCallback | None = None,
    ) -> ProcessVideoResult:
        """Run pose estimation on every frame of ``video_path``.

        Parameters
        ----------
        video_path
            Path to the input video. Must exist and be openable by OpenCV.
        fov_degrees
            Per-call override of the horizontal field of view; ``None``
            falls back to the estimator's ``default_fov_degrees``.
        progress
            Optional callback invoked after each processed frame as
            ``progress(processed, total_hint)``. ``total_hint`` comes from
            OpenCV's ``CAP_PROP_FRAME_COUNT`` and is unreliable for
            variable-rate videos; the authoritative count is
            ``result.frame_count`` after the call.

        Returns
        -------
        ProcessVideoResult
            Typed result wrapping the validated :class:`VideoPredictions`.

        Raises
        ------
        FileNotFoundError
            If ``video_path`` does not exist.
        VideoDecodeError
            If OpenCV cannot open the video.
        ModelNotLoadedError
            If no model has been loaded or injected.
        """
        if not video_path.exists():
            raise FileNotFoundError(f"video file not found: {video_path}")

        # Touch the model up front so a missing model fails before any
        # video resource is opened.
        model = self.model
        fov = self.default_fov_degrees if fov_degrees is None else fov_degrees

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            raise VideoDecodeError(f"OpenCV could not open video: {video_path}")

        try:
            fps = float(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_hint = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)

            logger.info(
                "Processing video %s (%dx%d @ %.2f fps, ~%d frames)",
                video_path,
                width,
                height,
                fps,
                total_hint,
            )

            frames: dict[str, FramePrediction] = {}
            frame_index = 0
            while True:
                ok, bgr = cap.read()
                if not ok:
                    break
                # MeTRAbs was trained on RGB images; OpenCV decodes BGR.
                rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
                frames[f"frame_{frame_index:06d}"] = self._infer_frame(model, rgb, fov)
                frame_index += 1
                if progress is not None:
                    progress(frame_index, total_hint)

            metadata = VideoMetadata(
                frame_count=frame_index,
                fps=fps if fps > 0 else 0.0,
                width=width,
                height=height,
            )
        finally:
            cap.release()

        if frame_index == 0:
            logger.warning("Video %s contained no decodable frames.", video_path)

        predictions = VideoPredictions(metadata=metadata, frames=frames)
        return ProcessVideoResult(predictions=predictions)

    # -- internals ----------------------------------------------------------

    def _infer_frame(
        self,
        model: Any,
        rgb_frame: Any,
        fov_degrees: float,
    ) -> FramePrediction:
        """Run one RGB frame through ``model`` and validate the raw output."""
        raw = model.detect_poses(
            rgb_frame,
            default_fov_degrees=fov_degrees,
            skeleton=self.skeleton,
        )
        return FramePrediction(
            boxes=_to_nested_list(raw["boxes"]),
            poses3d=_to_nested_list(raw["poses3d"]),
            poses2d=_to_nested_list(raw["poses2d"]),
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _to_nested_list(value: Any) -> Any:
|
||||||
|
"""Normalise a TF tensor, numpy array, or nested list to Python lists.
|
||||||
|
|
||||||
|
The real MeTRAbs model returns ``tf.Tensor`` objects which expose
|
||||||
|
``.numpy()`` returning a ``numpy.ndarray``. Tests inject fake models
|
||||||
|
that return plain numpy arrays. Both paths flow through this helper so
|
||||||
|
the rest of the code is agnostic to which is in use.
|
||||||
|
"""
|
||||||
|
if hasattr(value, "numpy"):
|
||||||
|
value = value.numpy()
|
||||||
|
if hasattr(value, "tolist"):
|
||||||
|
return value.tolist()
|
||||||
|
return list(value)
|
||||||
|
|
@ -1,14 +1,15 @@
|
||||||
"""I/O helpers and schema definitions for NeuroPose prediction data.
|
"""I/O helpers and schema definitions for NeuroPose prediction data.
|
||||||
|
|
||||||
Defines pydantic models for per-frame predictions, per-video aggregated
|
Defines pydantic models for per-frame predictions, per-video predictions
|
||||||
predictions, job-level aggregated results, and the persistent status file.
|
(with metadata envelope), job-level aggregated results, and the persistent
|
||||||
All models are validated on load, so malformed files are caught at the
|
status file. All models are validated on load, so malformed files are caught
|
||||||
boundary rather than at some downstream call site.
|
at the boundary rather than at some downstream call site.
|
||||||
|
|
||||||
Atomicity: :func:`save_status` and :func:`save_job_results` write to a sibling
|
Atomicity: :func:`save_status`, :func:`save_job_results`, and
|
||||||
temp file and then atomically rename, so a crash mid-write will not leave a
|
:func:`save_video_predictions` write to a sibling temp file and then
|
||||||
partially-written file behind. This matches the crash-resilience guarantee
|
atomically rename, so a crash mid-write will not leave a partially-written
|
||||||
the interfacer daemon makes to callers.
|
file behind. This matches the crash-resilience guarantee the interfacer
|
||||||
|
daemon makes to callers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -55,26 +56,51 @@ class FramePrediction(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VideoPredictions(RootModel[dict[str, FramePrediction]]):
|
class VideoMetadata(BaseModel):
    """Metadata about the source video for a set of predictions.

    Essential for reproducibility: the frame count lets downstream analysis
    verify completeness, and the fps lets it convert frame indices to real
    time without needing access to the original video file.

    Deliberately excludes the source file path or filename, which may
    encode subject-identifying information. Callers needing provenance
    should store it out-of-band per the data-handling policy.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    frame_count: int = Field(ge=0, description="Number of frames actually processed.")
    fps: float = Field(ge=0.0, description="Source video frame rate (frames per second).")
    width: int = Field(ge=0, description="Source video frame width in pixels.")
    height: int = Field(ge=0, description="Source video frame height in pixels.")
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPredictions(BaseModel):
    """Per-frame predictions for one video, paired with its metadata.

    Keys of ``frames`` are frame identifiers — ``frame_<index>`` by
    convention, zero-padded to 6 digits. Each identifier is a stable
    string, not a filesystem path; no PNG file is implied.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    metadata: VideoMetadata
    frames: dict[str, FramePrediction]

    def frame_names(self) -> list[str]:
        """Return frame identifiers in insertion order."""
        return list(self.frames.keys())

    def __len__(self) -> int:
        return len(self.frames)

    def __iter__(self) -> Iterator[str]:  # type: ignore[override]
        return iter(self.frames)

    def __getitem__(self, key: str) -> FramePrediction:
        return self.frames[key]
|
||||||
|
|
||||||
|
|
||||||
class JobResults(RootModel[dict[str, VideoPredictions]]):
|
class JobResults(RootModel[dict[str, VideoPredictions]]):
|
||||||
|
|
@ -85,7 +111,7 @@ class JobResults(RootModel[dict[str, VideoPredictions]]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def videos(self) -> list[str]:
|
def videos(self) -> list[str]:
|
||||||
"""Return the video names in insertion order."""
|
"""Return video names in insertion order."""
|
||||||
return list(self.root.keys())
|
return list(self.root.keys())
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -143,7 +169,7 @@ def load_video_predictions(path: Path) -> VideoPredictions:
|
||||||
|
|
||||||
|
|
||||||
def save_video_predictions(path: Path, predictions: VideoPredictions) -> None:
    """Serialize per-video predictions to a JSON file atomically.

    Parent directories are created as needed; the write goes through
    ``_write_json_atomic``, so a crash mid-write cannot leave a partial
    file behind.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    _write_json_atomic(path, predictions.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""Matplotlib-based visualization of NeuroPose predictions.
|
||||||
|
|
||||||
|
Separate module so that importing :mod:`neuropose.estimator` does not pull
|
||||||
|
in matplotlib or incur its global backend side effect. Callers that want
|
||||||
|
visualization import this module explicitly.
|
||||||
|
|
||||||
|
The old prototype's ``_visualize`` helper had an in-place numpy-view
|
||||||
|
aliasing bug where ``poses3d[..., 1], poses3d[..., 2] = poses3d[..., 2],
|
||||||
|
-poses3d[..., 1]`` mutated the caller's prediction array. This rewrite
|
||||||
|
makes an explicit copy before any axis reordering, so the input
|
||||||
|
:class:`VideoPredictions` is never touched.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from neuropose.estimator import VideoDecodeError
|
||||||
|
from neuropose.io import VideoPredictions
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VALID_VIEWS = frozenset({"normal", "depth"})
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_predictions(
    video_path: Path,
    predictions: VideoPredictions,
    output_dir: Path,
    *,
    view: str = "normal",
    joint_edges: Sequence[tuple[int, int]] | None = None,
    frame_indices: Sequence[int] | None = None,
) -> list[Path]:
    """Render per-frame 2D + 3D visualizations as PNG files.

    Parameters
    ----------
    video_path
        Source video the predictions were computed from; frame pixels are
        re-read from it for the 2D overlay.
    predictions
        Predictions to overlay. The dict ordering of ``predictions.frames``
        decides which frames are drawn unless ``frame_indices`` is given.
    output_dir
        Destination directory for the rendered PNGs; created if absent.
    view
        ``"normal"`` (default) or ``"depth"``; selects the 3D subplot's
        axis limits. Any other value raises ``ValueError``.
    joint_edges
        Optional ``(i, j)`` joint-index pairs drawn as skeleton edges
        connecting joints; ``None`` draws scatter points only. For
        ``berkeley_mhad_43`` the edges can be obtained from
        ``model.per_skeleton_joint_edges[skeleton]``.
    frame_indices
        Optional subset of frame indices to render; ``None`` renders every
        frame. Out-of-range indices are silently skipped.

    Returns
    -------
    list[Path]
        Paths of the written PNG files, in render order.

    Raises
    ------
    FileNotFoundError
        If ``video_path`` does not exist.
    VideoDecodeError
        If OpenCV cannot open the video.
    ValueError
        If ``view`` is not one of the supported values.
    """
    if view not in VALID_VIEWS:
        raise ValueError(f"view must be one of {sorted(VALID_VIEWS)}; got {view!r}")
    if not video_path.exists():
        raise FileNotFoundError(f"video file not found: {video_path}")

    # Late imports: matplotlib carries a global backend side effect, which
    # we want to avoid at module load time. pyplot is also slow to import.
    import matplotlib

    matplotlib.use("Agg", force=False)
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    output_dir.mkdir(parents=True, exist_ok=True)
    frame_names = predictions.frame_names()
    selected = _select_indices(frame_indices, len(frame_names))

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise VideoDecodeError(f"OpenCV could not open video: {video_path}")

    written: list[Path] = []
    try:
        # One sequential pass over the video: decode every frame, but only
        # render the ones whose index matches the next requested target.
        pending = iter(selected)
        target = next(pending, None)
        position = 0
        while target is not None:
            ok, bgr = cap.read()
            if not ok:
                break
            if position == target:
                frame_name = frame_names[position]
                out_path = output_dir / f"{frame_name}.png"
                _render_frame(
                    cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
                    predictions[frame_name],
                    out_path,
                    view=view,
                    joint_edges=joint_edges,
                    plt=plt,
                    Rectangle=Rectangle,
                )
                written.append(out_path)
                target = next(pending, None)
            position += 1
    finally:
        cap.release()

    logger.info("Wrote %d visualization frame(s) to %s", len(written), output_dir)
    return written
|
||||||
|
|
||||||
|
|
||||||
|
def _select_indices(frame_indices: Sequence[int] | None, total: int) -> list[int]:
|
||||||
|
"""Normalise the ``frame_indices`` argument to a sorted, in-range list."""
|
||||||
|
if frame_indices is None:
|
||||||
|
return list(range(total))
|
||||||
|
return sorted({i for i in frame_indices if 0 <= i < total})
|
||||||
|
|
||||||
|
|
||||||
|
def _render_frame(
    rgb_frame: Any,
    frame_prediction: Any,
    out_path: Path,
    *,
    view: str,
    joint_edges: Sequence[tuple[int, int]] | None,
    plt: Any,
    Rectangle: Any,
) -> None:
    """Render one frame's 2D overlay + 3D scatter to ``out_path``.

    Parameters
    ----------
    rgb_frame
        RGB image array shown in the left (2D overlay) subplot.
    frame_prediction
        Object exposing ``boxes``, ``poses3d`` and ``poses2d`` attributes
        convertible by ``np.asarray`` (a ``FramePrediction``).
    out_path
        Destination PNG path passed to ``fig.savefig``.
    view
        ``"depth"`` selects the wide-depth x-axis limits; any other value
        gets the symmetric "normal" limits.
    joint_edges
        Optional ``(i, j)`` joint-index pairs drawn as connecting lines;
        ``None`` draws scatter points only.
    plt, Rectangle
        Matplotlib's pyplot module and Rectangle patch class, supplied by
        the caller so this helper performs no matplotlib imports itself.
    """
    # Explicit copies. The previous prototype mutated the caller's data via
    # numpy-view tuple-assignment; we take a fresh numpy array per person
    # so the caller's VideoPredictions is never touched.
    boxes = np.asarray(frame_prediction.boxes, dtype=float)
    poses3d = np.asarray(frame_prediction.poses3d, dtype=float).copy()
    poses2d = np.asarray(frame_prediction.poses2d, dtype=float)

    # Rotate for visualization: swap Y and Z so the ground plane is horizontal.
    # Do this on the copy so the original predictions object is untouched.
    if poses3d.size > 0:
        original_y = poses3d[..., 1].copy()
        poses3d[..., 1] = poses3d[..., 2]
        poses3d[..., 2] = -original_y

    fig = plt.figure(figsize=(10, 5.2))

    # Left subplot: source frame with detection boxes overlaid.
    image_ax = fig.add_subplot(1, 2, 1)
    image_ax.imshow(rgb_frame)
    for box in boxes:
        # Boxes shorter than (x, y, w, h) cannot be drawn; skip them.
        if len(box) < 4:
            continue
        x, y, w, h = box[:4]
        image_ax.add_patch(Rectangle((x, y), w, h, fill=False))

    # Right subplot: 3D pose. Axis limits depend on the requested view.
    pose_ax = fig.add_subplot(1, 2, 2, projection="3d")
    pose_ax.view_init(5, -85)
    if view == "depth":
        pose_ax.set_xlim3d(200, 17500)
        pose_ax.set_zlim3d(-1500, 1500)
        pose_ax.set_ylim3d(0, 3000)
    else:
        pose_ax.set_xlim3d(-1500, 1500)
        pose_ax.set_zlim3d(-1500, 1500)
        pose_ax.set_ylim3d(0, 3000)
    pose_ax.set_box_aspect((1, 1, 1))

    for pose3d, pose2d in zip(poses3d, poses2d, strict=False):
        if joint_edges is not None:
            # Out-of-range edge indices are skipped rather than raising, so
            # a skeleton mismatch degrades to missing lines, not a crash.
            for i_start, i_end in joint_edges:
                if 0 <= i_start < len(pose2d) and 0 <= i_end < len(pose2d):
                    image_ax.plot(
                        [pose2d[i_start][0], pose2d[i_end][0]],
                        [pose2d[i_start][1], pose2d[i_end][1]],
                        marker="o",
                        markersize=2,
                    )
                if 0 <= i_start < len(pose3d) and 0 <= i_end < len(pose3d):
                    pose_ax.plot(
                        [pose3d[i_start][0], pose3d[i_end][0]],
                        [pose3d[i_start][1], pose3d[i_end][1]],
                        [pose3d[i_start][2], pose3d[i_end][2]],
                        marker="o",
                        markersize=2,
                    )
        image_ax.scatter(pose2d[:, 0], pose2d[:, 1], s=2)
        pose_ax.scatter(pose3d[:, 0], pose3d[:, 1], pose3d[:, 2], s=2)

    fig.tight_layout()
    fig.savefig(out_path)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
|
||||||
|
|
@ -5,10 +5,18 @@ from __future__ import annotations
|
||||||
import os
|
import os
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Environment isolation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _isolate_environment(
|
def _isolate_environment(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|
@ -36,3 +44,65 @@ def _isolate_environment(
|
||||||
def xdg_home() -> Path:
|
def xdg_home() -> Path:
|
||||||
"""Return the isolated ``$XDG_DATA_HOME`` set up by ``_isolate_environment``."""
|
"""Return the isolated ``$XDG_DATA_HOME`` set up by ``_isolate_environment``."""
|
||||||
return Path(os.environ["XDG_DATA_HOME"])
|
return Path(os.environ["XDG_DATA_HOME"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Synthetic video + fake MeTRAbs model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def synthetic_video(tmp_path: Path) -> Path:
    """Write and return a tiny synthetic video at test time.

    Produces a 5-frame, 32x32 MJPG-encoded ``.avi``. MJPG is chosen over
    ``mp4v`` because it ships with ``opencv-python-headless`` on every
    platform we target, whereas ``mp4v`` occasionally needs an ffmpeg
    binary that may be absent on minimal CI runners.
    """
    path = tmp_path / "synthetic.avi"
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"MJPG"), 30.0, (32, 32)
    )
    assert writer.isOpened(), "cv2.VideoWriter failed to open; MJPG codec missing?"
    for step in range(5):
        # Distinct brightness per frame so a downstream check could verify
        # the test is actually reading frame-by-frame.
        writer.write(np.full((32, 32, 3), step * 40, dtype=np.uint8))
    writer.release()
    assert path.exists() and path.stat().st_size > 0, "Synthetic video is empty."
    return path
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeMetrabsModel:
    """Minimal stand-in for the MeTRAbs model used in unit tests.

    Every call yields the same deterministic pose data (a single person with
    two joints), which lets tests assert on shapes without importing
    TensorFlow or downloading the real model.  The returned arrays are plain
    numpy so the estimator's ``_to_nested_list`` helper exercises its
    non-``numpy()`` branch.
    """

    def __init__(self) -> None:
        # Number of detect_poses() invocations; tests assert one call per frame.
        self.call_count = 0

    def detect_poses(
        self,
        image: Any,
        *,
        default_fov_degrees: float,
        skeleton: str,
    ) -> dict[str, np.ndarray]:
        # Arguments are accepted purely for signature compatibility with the
        # real MeTRAbs model and are otherwise ignored.
        del image, default_fov_degrees, skeleton
        self.call_count += 1
        boxes = np.array([[0.0, 0.0, 32.0, 32.0, 0.95]])
        poses3d = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
        poses2d = np.array([[[10.0, 20.0], [30.0, 40.0]]])
        return {"boxes": boxes, "poses3d": poses3d, "poses2d": poses2d}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_metrabs_model() -> _FakeMetrabsModel:
    """Provide a brand-new fake MeTRAbs model per test (no shared call counts)."""
    model = _FakeMetrabsModel()
    return model
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,205 @@
|
||||||
|
"""Tests for :class:`neuropose.estimator.Estimator`.
|
||||||
|
|
||||||
|
These tests exercise the non-model code paths (video decoding, frame loop,
|
||||||
|
metadata extraction, result construction, progress reporting, and error
|
||||||
|
handling) using an injected fake MeTRAbs model. The TensorFlow / real-model
|
||||||
|
integration smoke test lives in ``tests/integration/`` and lands with
|
||||||
|
commit 11.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from neuropose.estimator import (
|
||||||
|
Estimator,
|
||||||
|
ModelNotLoadedError,
|
||||||
|
ProcessVideoResult,
|
||||||
|
VideoDecodeError,
|
||||||
|
)
|
||||||
|
from neuropose.io import FramePrediction, VideoPredictions
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstruction:
    """Construction-time behaviour: defaults, overrides, model injection."""

    def test_no_model_by_default(self) -> None:
        est = Estimator()
        assert not est.is_model_loaded

    def test_injected_model_is_loaded(self, fake_metrabs_model) -> None:
        est = Estimator(model=fake_metrabs_model)
        assert est.is_model_loaded

    def test_defaults(self) -> None:
        est = Estimator()
        assert est.skeleton == "berkeley_mhad_43"
        assert est.device == "/CPU:0"
        assert est.default_fov_degrees == pytest.approx(55.0)

    def test_overrides(self, fake_metrabs_model) -> None:
        est = Estimator(
            model=fake_metrabs_model,
            device="/GPU:0",
            skeleton="smpl_24",
            default_fov_degrees=40.0,
        )
        assert est.default_fov_degrees == pytest.approx(40.0)
        assert est.device == "/GPU:0"
        assert est.skeleton == "smpl_24"
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelGuard:
    """Behaviour when no model is attached to the estimator."""

    def test_model_property_raises_when_missing(self) -> None:
        est = Estimator()
        with pytest.raises(ModelNotLoadedError):
            _ = est.model

    def test_process_video_raises_when_missing(self, synthetic_video: Path) -> None:
        est = Estimator()
        with pytest.raises(ModelNotLoadedError):
            est.process_video(synthetic_video)

    def test_load_model_stub_raises_not_implemented(self) -> None:
        est = Estimator()
        with pytest.raises(NotImplementedError, match="commit 11"):
            est.load_model()

    def test_load_model_is_idempotent_when_already_loaded(
        self,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        # Must be a no-op: it may neither raise nor replace the injected model.
        est.load_model()
        assert est.model is fake_metrabs_model
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessVideo:
    """End-to-end frame-loop behaviour with an injected fake model."""

    def test_returns_typed_result(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        outcome = est.process_video(synthetic_video)
        assert isinstance(outcome, ProcessVideoResult)
        assert isinstance(outcome.predictions, VideoPredictions)

    def test_frame_count_matches_source(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        outcome = est.process_video(synthetic_video)
        # One model invocation per decoded frame.
        assert outcome.frame_count == 5
        assert fake_metrabs_model.call_count == 5

    def test_frame_naming_is_zero_padded(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        outcome = est.process_video(synthetic_video)
        expected = [f"frame_{index:06d}" for index in range(5)]
        assert outcome.predictions.frame_names() == expected

    def test_metadata_populated(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        meta = est.process_video(synthetic_video).predictions.metadata
        assert meta.frame_count == 5
        assert (meta.width, meta.height) == (32, 32)
        assert meta.fps > 0.0

    def test_each_frame_validates_as_frame_prediction(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        preds = est.process_video(synthetic_video).predictions
        for name in preds.frame_names():
            frame = preds[name]
            assert isinstance(frame, FramePrediction)
            assert len(frame.boxes) == 1
            assert len(frame.poses3d) == 1
            # The fake model always reports exactly two joints per person.
            assert len(frame.poses3d[0]) == 2

    def test_progress_callback_invoked_per_frame(
        self,
        synthetic_video: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        calls: list[tuple[int, int]] = []

        def record(processed: int, total: int) -> None:
            calls.append((processed, total))

        est.process_video(synthetic_video, progress=record)
        assert len(calls) == 5
        # Processed counts should be strictly increasing.
        assert [processed for processed, _ in calls] == [1, 2, 3, 4, 5]

    def test_fov_override_is_passed_through(self, synthetic_video: Path) -> None:
        fov_seen: list[float] = []

        class RecordingModel:
            def detect_poses(
                self,
                image,
                *,
                default_fov_degrees: float,
                skeleton: str,
            ):
                del image, skeleton
                fov_seen.append(default_fov_degrees)
                import numpy as np

                return {
                    "boxes": np.array([[0.0, 0.0, 32.0, 32.0, 0.9]]),
                    "poses3d": np.array([[[0.0, 0.0, 0.0]]]),
                    "poses2d": np.array([[[0.0, 0.0]]]),
                }

        est = Estimator(model=RecordingModel(), default_fov_degrees=55.0)
        est.process_video(synthetic_video, fov_degrees=40.0)
        assert len(fov_seen) == 5
        assert all(value == pytest.approx(40.0) for value in fov_seen)
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrors:
    """Filesystem and decode failure modes of process_video."""

    def test_missing_video(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        missing = tmp_path / "does_not_exist.mp4"
        with pytest.raises(FileNotFoundError):
            est.process_video(missing)

    def test_unreadable_video_raises_decode_error(
        self,
        tmp_path: Path,
        fake_metrabs_model,
    ) -> None:
        # An existing file holding non-video content: cv2.VideoCapture
        # reports isOpened() == False for it.
        bogus = tmp_path / "not_a_video.avi"
        bogus.write_bytes(b"this is definitely not a video file")
        est = Estimator(model=fake_metrabs_model)
        with pytest.raises(VideoDecodeError):
            est.process_video(bogus)
|
||||||
|
|
@ -14,6 +14,7 @@ from neuropose.io import (
|
||||||
JobResults,
|
JobResults,
|
||||||
JobStatus,
|
JobStatus,
|
||||||
StatusFile,
|
StatusFile,
|
||||||
|
VideoMetadata,
|
||||||
VideoPredictions,
|
VideoPredictions,
|
||||||
load_job_results,
|
load_job_results,
|
||||||
load_status,
|
load_status,
|
||||||
|
|
@ -40,10 +41,18 @@ def one_frame() -> dict:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
def video_metadata_payload() -> dict:
    """Minimal valid VideoMetadata payload: two frames at 30 fps, 640x480."""
    return {"width": 640, "height": 480, "frame_count": 2, "fps": 30.0}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def video_predictions_payload(one_frame: dict, video_metadata_payload: dict) -> dict:
    """Full VideoPredictions payload: metadata plus two identical frames."""
    frames = {f"frame_{i:06d}": one_frame for i in range(2)}
    return {"metadata": video_metadata_payload, "frames": frames}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,30 +79,90 @@ class TestFramePrediction:
|
||||||
frame.boxes = []
|
frame.boxes = []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# VideoMetadata
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestVideoMetadata:
    """Validation rules for the VideoMetadata model."""

    def test_valid(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=29.97, width=1920, height=1080)
        assert meta.fps == pytest.approx(29.97)
        assert meta.frame_count == 10

    def test_zero_frame_count_allowed(self) -> None:
        # Broken or empty videos still yield a valid metadata object, so a
        # caller observes frame_count == 0 rather than an exception.
        VideoMetadata(frame_count=0, fps=0.0, width=0, height=0)

    def test_rejects_negative(self) -> None:
        with pytest.raises(ValidationError):
            VideoMetadata(frame_count=-1, fps=30.0, width=640, height=480)

    def test_rejects_extra_fields(self) -> None:
        payload = {
            "frame_count": 1,
            "fps": 30.0,
            "width": 640,
            "height": 480,
            "source_path": "/leak/me",
        }
        with pytest.raises(ValidationError):
            VideoMetadata(**payload)

    def test_is_frozen(self) -> None:
        meta = VideoMetadata(frame_count=10, fps=30.0, width=640, height=480)
        with pytest.raises(ValidationError):
            meta.fps = 60.0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# VideoPredictions
|
# VideoPredictions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestVideoPredictions:
    """Mapping behaviour and (de)serialisation of VideoPredictions."""

    def test_from_dict(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert len(vp) == 2
        assert vp.frame_names() == ["frame_000000", "frame_000001"]
        assert vp["frame_000000"].boxes[0][4] == pytest.approx(0.95)
        assert vp.metadata.fps == pytest.approx(30.0)

    def test_iteration(self, video_predictions_payload: dict) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        assert ["frame_000000", "frame_000001"] == list(vp)

    def test_rejects_missing_metadata(self, video_predictions_payload: dict) -> None:
        video_predictions_payload.pop("metadata")
        with pytest.raises(ValidationError):
            VideoPredictions.model_validate(video_predictions_payload)

    def test_save_and_load_roundtrip(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        target = tmp_path / "preds" / "video.json"
        save_video_predictions(target, vp)
        assert target.exists()
        reloaded = load_video_predictions(target)
        assert reloaded.frame_names() == vp.frame_names()
        assert reloaded.metadata == vp.metadata
        assert reloaded["frame_000000"].poses3d == vp["frame_000000"].poses3d

    def test_save_is_atomic(
        self,
        tmp_path: Path,
        video_predictions_payload: dict,
    ) -> None:
        vp = VideoPredictions.model_validate(video_predictions_payload)
        target = tmp_path / "video.json"
        save_video_predictions(target, vp)
        assert target.exists()
        # No leftover temp file means the write-then-rename completed.
        assert list(tmp_path.glob("video.json.tmp")) == []
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -102,9 +171,16 @@ class TestVideoPredictions:
|
||||||
|
|
||||||
|
|
||||||
class TestJobResults:
|
class TestJobResults:
|
||||||
def test_save_and_load_roundtrip(self, tmp_path: Path, video_payload: dict) -> None:
|
def test_save_and_load_roundtrip(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
video_predictions_payload: dict,
|
||||||
|
) -> None:
|
||||||
jr = JobResults.model_validate(
|
jr = JobResults.model_validate(
|
||||||
{"video_a.mp4": video_payload, "video_b.mp4": video_payload}
|
{
|
||||||
|
"video_a.mp4": video_predictions_payload,
|
||||||
|
"video_b.mp4": video_predictions_payload,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
path = tmp_path / "results.json"
|
path = tmp_path / "results.json"
|
||||||
save_job_results(path, jr)
|
save_job_results(path, jr)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
"""Tests for :mod:`neuropose.visualize`.
|
||||||
|
|
||||||
|
Smoke tests only: we verify that the visualize function runs end-to-end
|
||||||
|
against a synthetic video, writes PNG files, honours ``frame_indices``, and
|
||||||
|
does NOT mutate the caller's :class:`VideoPredictions`. Actual pixel-level
|
||||||
|
correctness is out of scope for unit tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from neuropose.estimator import Estimator, VideoDecodeError
|
||||||
|
from neuropose.io import VideoPredictions
|
||||||
|
from neuropose.visualize import visualize_predictions
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def predictions_for_synthetic(
    synthetic_video: Path,
    fake_metrabs_model,
) -> VideoPredictions:
    """Predictions from running the fake estimator over the synthetic video."""
    result = Estimator(model=fake_metrabs_model).process_video(synthetic_video)
    return result.predictions
|
||||||
|
|
||||||
|
|
||||||
|
class TestVisualizePredictions:
    """Smoke tests for visualize_predictions against the synthetic video."""

    def test_writes_one_png_per_frame(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        out_dir = tmp_path / "viz"
        written = visualize_predictions(
            synthetic_video, predictions_for_synthetic, out_dir
        )
        assert len(written) == 5
        for png in written:
            assert png.suffix == ".png"
            assert png.exists()
            assert png.stat().st_size > 0

    def test_frame_indices_limits_output(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        written = visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            tmp_path / "viz",
            frame_indices=[0, 2, 4],
        )
        assert len(written) == 3
        stems = sorted(png.stem for png in written)
        assert stems == ["frame_000000", "frame_000002", "frame_000004"]

    def test_out_of_range_indices_silently_skipped(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        written = visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            tmp_path / "viz",
            frame_indices=[0, 999, -1],
        )
        # Only index 0 exists; 999 and -1 fall outside the 5-frame range.
        assert len(written) == 1
        assert written[0].stem == "frame_000000"

    def test_does_not_mutate_input_predictions(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        """Guard against the aliasing regression that mutated poses3d in place.

        poses3d is deep-copied before visualization and compared with the
        live value afterwards; without the deepcopy, an in-place mutation
        would be invisible because both names would alias the same list.
        """
        frame_name = predictions_for_synthetic.frame_names()[0]
        snapshot = copy.deepcopy(predictions_for_synthetic[frame_name].poses3d)
        visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            tmp_path / "viz",
            frame_indices=[0],
        )
        assert snapshot == predictions_for_synthetic[frame_name].poses3d

    def test_rejects_invalid_view(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        with pytest.raises(ValueError, match="view must be one of"):
            visualize_predictions(
                synthetic_video,
                predictions_for_synthetic,
                tmp_path / "viz",
                view="orthographic",
            )

    def test_depth_view_runs(
        self,
        tmp_path: Path,
        synthetic_video: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        written = visualize_predictions(
            synthetic_video,
            predictions_for_synthetic,
            tmp_path / "viz",
            view="depth",
            frame_indices=[0],
        )
        assert len(written) == 1

    def test_missing_video_raises(
        self,
        tmp_path: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        with pytest.raises(FileNotFoundError):
            visualize_predictions(
                tmp_path / "nope.avi",
                predictions_for_synthetic,
                tmp_path / "viz",
            )

    def test_unreadable_video_raises(
        self,
        tmp_path: Path,
        predictions_for_synthetic: VideoPredictions,
    ) -> None:
        fake_video = tmp_path / "fake.avi"
        fake_video.write_bytes(b"not a video")
        with pytest.raises(VideoDecodeError):
            visualize_predictions(
                fake_video,
                predictions_for_synthetic,
                tmp_path / "viz",
            )
|
||||||
Loading…
Reference in New Issue