estimator
This commit is contained in:
parent
3d2b2fc68d
commit
9bbbbd0d52
|
|
@ -0,0 +1,51 @@
|
|||
"""MeTRAbs model loading — stub pending commit 11.
|
||||
|
||||
This module exists so that :mod:`neuropose.estimator` can import a single,
|
||||
well-typed loader function even before the upstream MeTRAbs URL is pinned
|
||||
and the TensorFlow version is settled.
|
||||
|
||||
Commit 11 will replace :func:`load_metrabs_model` with an implementation
|
||||
that:
|
||||
|
||||
1. Pins the canonical MeTRAbs tfhub / Kaggle Models handle (replacing the
|
||||
``bit.ly/metrabs_1`` shortener from the previous prototype).
|
||||
2. Caches the downloaded model under ``Settings.model_cache_dir`` so the
|
||||
first run downloads it and subsequent runs are offline.
|
||||
3. Returns a typed handle that the estimator can invoke without hitting the
|
||||
network on each instantiation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def load_metrabs_model(cache_dir: Path | None = None) -> Any: # noqa: ARG001
|
||||
"""Load the MeTRAbs model, downloading and caching on first use.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_dir
|
||||
Directory where the model should be cached. Typically
|
||||
``Settings.model_cache_dir``. If ``None``, the implementation picks
|
||||
a default location.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
An opaque model handle that exposes ``detect_poses`` and the
|
||||
per-skeleton joint metadata attributes used by
|
||||
:class:`neuropose.estimator.Estimator`.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
Always, at this commit. Commit 11 provides the real implementation
|
||||
once the upstream MeTRAbs URL is pinned.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"load_metrabs_model is stubbed pending commit 11. "
|
||||
"Inject a model directly via Estimator(model=...) for now, "
|
||||
"or wait for the upstream MeTRAbs URL and TensorFlow version pin."
|
||||
)
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
"""3D human pose estimator — MeTRAbs wrapper.
|
||||
|
||||
The :class:`Estimator` class is the core of NeuroPose's inference path. It
|
||||
takes a video file, runs the MeTRAbs 3D pose-estimation model on each frame,
|
||||
and returns a validated :class:`~neuropose.io.VideoPredictions` object with
|
||||
the per-frame predictions and video metadata.
|
||||
|
||||
Design
|
||||
------
|
||||
The estimator is a **library**, not a daemon: it knows nothing about job
|
||||
directories, status files, or polling. Those concerns live in
|
||||
:mod:`neuropose.interfacer`. An ``Estimator`` can be constructed directly
|
||||
from a Python script and called on a single video, which is the path taken
|
||||
by the documentation's quick-start.
|
||||
|
||||
Frames are streamed from the source video in memory — the previous
|
||||
prototype wrote every frame to disk as a PNG and then re-read each one with
|
||||
``tf.io.decode_png``. That round-trip is gone; frames are read once and
|
||||
passed to the model directly.
|
||||
|
||||
Model injection
|
||||
---------------
|
||||
The MeTRAbs model is supplied either:
|
||||
|
||||
1. Through :meth:`Estimator.load_model`, which delegates to
|
||||
:func:`neuropose._model.load_metrabs_model` (stubbed pending commit 11).
|
||||
2. Directly via ``Estimator(model=...)``, which is the path used by the
|
||||
test suite with a fake model to exercise the code without TensorFlow.
|
||||
|
||||
Either way, attempting to call :meth:`Estimator.process_video` before a
|
||||
model is present raises :class:`ModelNotLoadedError`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
|
||||
from neuropose._model import load_metrabs_model
|
||||
from neuropose.io import FramePrediction, VideoMetadata, VideoPredictions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EstimatorError(Exception):
    """Root of the exception hierarchy raised by :class:`Estimator`."""
|
||||
|
||||
|
||||
class ModelNotLoadedError(EstimatorError):
    """Signals that inference was attempted before a model was supplied."""
|
||||
|
||||
|
||||
class VideoDecodeError(EstimatorError):
    """Signals that a video file could not be opened or decoded."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result container
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProcessVideoResult:
    """Immutable wrapper around the output of :meth:`Estimator.process_video`.

    Attributes
    ----------
    predictions
        The validated :class:`VideoPredictions`, carrying both the
        per-frame predictions and the ``VideoMetadata`` envelope.
    """

    predictions: VideoPredictions

    @property
    def frame_count(self) -> int:
        """Shortcut for ``predictions.metadata.frame_count``."""
        return self.predictions.metadata.frame_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Estimator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Signature for per-frame progress reporting; see Estimator.process_video.
ProgressCallback = Callable[[int, int], None]
"""Type alias for progress callbacks: ``(frames_processed, frame_count_hint)``."""
|
||||
|
||||
|
||||
class Estimator:
    """3D pose estimator built on MeTRAbs.

    Parameters
    ----------
    device
        TensorFlow device string (e.g. ``"/CPU:0"`` or ``"/GPU:0"``). Passed
        through to the model at inference time.
    skeleton
        Skeleton identifier understood by MeTRAbs. Defaults to
        ``"berkeley_mhad_43"``, the 43-joint skeleton used by the previous
        prototype.
    default_fov_degrees
        Horizontal field of view assumed when a video does not supply
        intrinsics. Overridable per call via
        :meth:`process_video`'s ``fov_degrees`` argument.
    model
        Optional pre-loaded MeTRAbs model. If supplied, the estimator uses
        it directly and :meth:`load_model` need not be called. This path is
        used by tests to inject a fake model, and by callers that want to
        share a single model across many :class:`Estimator` instances.
    """

    def __init__(
        self,
        *,
        device: str = "/CPU:0",
        skeleton: str = "berkeley_mhad_43",
        default_fov_degrees: float = 55.0,
        model: Any | None = None,
    ) -> None:
        self.device = device
        self.skeleton = skeleton
        self.default_fov_degrees = default_fov_degrees
        # None until a model is injected here or load_model() succeeds; the
        # `model` property below converts that None into ModelNotLoadedError.
        self._model: Any | None = model

    # -- model lifecycle ----------------------------------------------------

    @property
    def model(self) -> Any:
        """Return the loaded model, or raise :class:`ModelNotLoadedError`."""
        if self._model is None:
            raise ModelNotLoadedError(
                "Estimator model has not been loaded. "
                "Call Estimator.load_model() or pass model=... to the constructor."
            )
        return self._model

    @property
    def is_model_loaded(self) -> bool:
        """Return ``True`` if a model has been supplied or loaded."""
        return self._model is not None

    def load_model(self, cache_dir: Path | None = None) -> None:
        """Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.

        Parameters
        ----------
        cache_dir
            Directory where the downloaded model should be cached. Typically
            ``Settings.model_cache_dir``.

        Notes
        -----
        This is idempotent: calling it again after a successful load is a
        no-op. Callers that want to reload the model should construct a new
        :class:`Estimator` instance.
        """
        # The idempotence guard also protects an injected model from being
        # clobbered by the (currently stubbed) loader.
        if self._model is not None:
            logger.debug("Model already loaded; skipping reload.")
            return
        logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
        self._model = load_metrabs_model(cache_dir=cache_dir)
        logger.info("MeTRAbs model loaded.")

    # -- inference ----------------------------------------------------------

    def process_video(
        self,
        video_path: Path,
        *,
        fov_degrees: float | None = None,
        progress: ProgressCallback | None = None,
    ) -> ProcessVideoResult:
        """Run pose estimation on every frame of a video.

        Parameters
        ----------
        video_path
            Path to the input video. Must exist and be openable by OpenCV.
        fov_degrees
            Per-call override for the horizontal field of view. If ``None``,
            the estimator's ``default_fov_degrees`` is used.
        progress
            Optional callback invoked after each processed frame as
            ``progress(processed, total_hint)``. ``total_hint`` is the
            approximate frame count reported by OpenCV's
            ``CAP_PROP_FRAME_COUNT``, which is unreliable for variable-rate
            videos; the authoritative count is available after the call on
            ``result.frame_count``.

        Returns
        -------
        ProcessVideoResult
            A typed result containing the validated :class:`VideoPredictions`.

        Raises
        ------
        FileNotFoundError
            If ``video_path`` does not exist.
        VideoDecodeError
            If OpenCV cannot open the video.
        ModelNotLoadedError
            If the model has not been loaded or injected.
        """
        if not video_path.exists():
            raise FileNotFoundError(f"video file not found: {video_path}")

        # Access the model eagerly so we fail fast before opening the video.
        model = self.model
        fov = fov_degrees if fov_degrees is not None else self.default_fov_degrees

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            raise VideoDecodeError(f"OpenCV could not open video: {video_path}")

        try:
            fps = float(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # CAP_PROP_FRAME_COUNT can report zero or a negative value for
            # streams and some containers; clamp so the hint is never < 0.
            total_hint = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)

            logger.info(
                "Processing video %s (%dx%d @ %.2f fps, ~%d frames)",
                video_path,
                width,
                height,
                fps,
                total_hint,
            )

            frames: dict[str, FramePrediction] = {}
            frame_index = 0
            while True:
                ok, bgr_frame = cap.read()
                if not ok:
                    break
                # MeTRAbs was trained on RGB images; OpenCV gives us BGR.
                rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
                prediction = self._infer_frame(model, rgb_frame, fov)
                # Zero-padded key keeps lexicographic order equal to numeric
                # frame order for up to a million frames.
                frames[f"frame_{frame_index:06d}"] = prediction
                frame_index += 1
                if progress is not None:
                    progress(frame_index, total_hint)

            metadata = VideoMetadata(
                frame_count=frame_index,
                # Some containers report a non-positive fps; store 0.0 to
                # mean "unknown" rather than a misleading value.
                fps=fps if fps > 0 else 0.0,
                width=width,
                height=height,
            )
        finally:
            # Release the capture even if the model raised mid-loop.
            cap.release()

        if frame_index == 0:
            logger.warning("Video %s contained no decodable frames.", video_path)

        predictions = VideoPredictions(metadata=metadata, frames=frames)
        return ProcessVideoResult(predictions=predictions)

    # -- internals ----------------------------------------------------------

    def _infer_frame(
        self,
        model: Any,
        rgb_frame: Any,
        fov_degrees: float,
    ) -> FramePrediction:
        """Run a single frame through the model and validate the output."""
        # NOTE(review): assumes the model's detect_poses mapping always has
        # "boxes", "poses3d", and "poses2d" keys — confirm against the real
        # MeTRAbs API when commit 11 lands.
        pred = model.detect_poses(
            rgb_frame,
            default_fov_degrees=fov_degrees,
            skeleton=self.skeleton,
        )
        return FramePrediction(
            boxes=_to_nested_list(pred["boxes"]),
            poses3d=_to_nested_list(pred["poses3d"]),
            poses2d=_to_nested_list(pred["poses2d"]),
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _to_nested_list(value: Any) -> Any:
|
||||
"""Normalise a TF tensor, numpy array, or nested list to Python lists.
|
||||
|
||||
The real MeTRAbs model returns ``tf.Tensor`` objects which expose
|
||||
``.numpy()`` returning a ``numpy.ndarray``. Tests inject fake models
|
||||
that return plain numpy arrays. Both paths flow through this helper so
|
||||
the rest of the code is agnostic to which is in use.
|
||||
"""
|
||||
if hasattr(value, "numpy"):
|
||||
value = value.numpy()
|
||||
if hasattr(value, "tolist"):
|
||||
return value.tolist()
|
||||
return list(value)
|
||||
|
|
@ -1,14 +1,15 @@
|
|||
"""I/O helpers and schema definitions for NeuroPose prediction data.
|
||||
|
||||
Defines pydantic models for per-frame predictions, per-video aggregated
|
||||
predictions, job-level aggregated results, and the persistent status file.
|
||||
All models are validated on load, so malformed files are caught at the
|
||||
boundary rather than at some downstream call site.
|
||||
Defines pydantic models for per-frame predictions, per-video predictions
|
||||
(with metadata envelope), job-level aggregated results, and the persistent
|
||||
status file. All models are validated on load, so malformed files are caught
|
||||
at the boundary rather than at some downstream call site.
|
||||
|
||||
Atomicity: :func:`save_status` and :func:`save_job_results` write to a sibling
|
||||
temp file and then atomically rename, so a crash mid-write will not leave a
|
||||
partially-written file behind. This matches the crash-resilience guarantee
|
||||
the interfacer daemon makes to callers.
|
||||
Atomicity: :func:`save_status`, :func:`save_job_results`, and
|
||||
:func:`save_video_predictions` write to a sibling temp file and then
|
||||
atomically rename, so a crash mid-write will not leave a partially-written
|
||||
file behind. This matches the crash-resilience guarantee the interfacer
|
||||
daemon makes to callers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -55,26 +56,51 @@ class FramePrediction(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class VideoPredictions(RootModel[dict[str, FramePrediction]]):
|
||||
"""Per-frame predictions for a single video, keyed by frame filename.
|
||||
class VideoMetadata(BaseModel):
|
||||
"""Metadata about the source video for a set of predictions.
|
||||
|
||||
Frame names are expected to follow the ``frame_<index>.png`` convention
|
||||
written by the estimator, but no constraint is enforced at the schema
|
||||
level so downstream consumers can key by any naming scheme.
|
||||
Essential for reproducibility: the frame count lets downstream analysis
|
||||
verify completeness, and the fps lets it convert frame indices to real
|
||||
time without needing access to the original video file.
|
||||
|
||||
Intentionally does NOT include the source file path or filename, which
|
||||
may encode subject-identifying information. Callers that need provenance
|
||||
should store it out-of-band in accordance with the data-handling policy.
|
||||
"""
|
||||
|
||||
def frames(self) -> list[str]:
|
||||
"""Return the frame names in insertion order."""
|
||||
return list(self.root.keys())
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
frame_count: int = Field(ge=0, description="Number of frames actually processed.")
|
||||
fps: float = Field(ge=0.0, description="Source video frame rate (frames per second).")
|
||||
width: int = Field(ge=0, description="Source video frame width in pixels.")
|
||||
height: int = Field(ge=0, description="Source video frame height in pixels.")
|
||||
|
||||
|
||||
class VideoPredictions(BaseModel):
|
||||
"""Per-frame predictions for a single video, paired with video metadata.
|
||||
|
||||
The ``frames`` mapping is keyed by frame identifier (``frame_<index>`` by
|
||||
convention, zero-padded to 6 digits). The identifier is a stable string,
|
||||
not a filesystem path — no PNG file is implied.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
metadata: VideoMetadata
|
||||
frames: dict[str, FramePrediction]
|
||||
|
||||
def frame_names(self) -> list[str]:
|
||||
"""Return frame identifiers in insertion order."""
|
||||
return list(self.frames.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.root)
|
||||
return len(self.frames)
|
||||
|
||||
def __iter__(self) -> Iterator[str]: # type: ignore[override]
|
||||
return iter(self.root)
|
||||
return iter(self.frames)
|
||||
|
||||
def __getitem__(self, key: str) -> FramePrediction:
|
||||
return self.root[key]
|
||||
return self.frames[key]
|
||||
|
||||
|
||||
class JobResults(RootModel[dict[str, VideoPredictions]]):
|
||||
|
|
@ -85,7 +111,7 @@ class JobResults(RootModel[dict[str, VideoPredictions]]):
|
|||
"""
|
||||
|
||||
def videos(self) -> list[str]:
|
||||
"""Return the video names in insertion order."""
|
||||
"""Return video names in insertion order."""
|
||||
return list(self.root.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
@ -143,7 +169,7 @@ def load_video_predictions(path: Path) -> VideoPredictions:
|
|||
|
||||
|
||||
def save_video_predictions(path: Path, predictions: VideoPredictions) -> None:
|
||||
"""Serialize per-video predictions to a JSON file."""
|
||||
"""Serialize per-video predictions to a JSON file atomically."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_write_json_atomic(path, predictions.model_dump(mode="json"))
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,210 @@
|
|||
"""Matplotlib-based visualization of NeuroPose predictions.
|
||||
|
||||
Separate module so that importing :mod:`neuropose.estimator` does not pull
|
||||
in matplotlib or incur its global backend side effect. Callers that want
|
||||
visualization import this module explicitly.
|
||||
|
||||
The old prototype's ``_visualize`` helper had an in-place numpy-view
|
||||
aliasing bug where ``poses3d[..., 1], poses3d[..., 2] = poses3d[..., 2],
|
||||
-poses3d[..., 1]`` mutated the caller's prediction array. This rewrite
|
||||
makes an explicit copy before any axis reordering, so the input
|
||||
:class:`VideoPredictions` is never touched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from neuropose.estimator import VideoDecodeError
|
||||
from neuropose.io import VideoPredictions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_VIEWS = frozenset({"normal", "depth"})
|
||||
|
||||
|
||||
def visualize_predictions(
    video_path: Path,
    predictions: VideoPredictions,
    output_dir: Path,
    *,
    view: str = "normal",
    joint_edges: Sequence[tuple[int, int]] | None = None,
    frame_indices: Sequence[int] | None = None,
) -> list[Path]:
    """Render per-frame 2D + 3D visualizations as PNG files.

    Parameters
    ----------
    video_path
        Path to the source video, used to recover frame pixels for the 2D
        overlay. Must be the same video the predictions were computed from.
    predictions
        Predictions to overlay. The dict ordering of ``predictions.frames``
        determines which frames are drawn unless ``frame_indices`` is set.
    output_dir
        Directory to write the rendered PNGs into. Created if absent.
    view
        ``"normal"`` (default) or ``"depth"``. Controls the 3D subplot's
        axis limits. Anything else raises ``ValueError``.
    joint_edges
        Optional list of ``(i, j)`` index pairs specifying skeleton edges to
        draw as lines connecting joints. If ``None``, only scatter points
        are drawn. For ``berkeley_mhad_43`` the edges can be obtained from
        ``model.per_skeleton_joint_edges[skeleton]``.
    frame_indices
        Optional subset of frame indices to render. If ``None``, every
        frame in ``predictions`` is rendered. Out-of-range indices are
        silently skipped.

    Returns
    -------
    list[Path]
        Paths of the written PNG files, in render order.

    Raises
    ------
    FileNotFoundError
        If ``video_path`` does not exist.
    VideoDecodeError
        If OpenCV cannot open the video.
    ValueError
        If ``view`` is not one of the supported values.
    """
    # Validate cheap preconditions before touching matplotlib or the video.
    if view not in VALID_VIEWS:
        raise ValueError(f"view must be one of {sorted(VALID_VIEWS)}; got {view!r}")
    if not video_path.exists():
        raise FileNotFoundError(f"video file not found: {video_path}")

    # Late imports: matplotlib carries a global backend side effect, which
    # we want to avoid at module load time. pyplot is also slow to import.
    import matplotlib

    # force=False so a backend the host application already selected is
    # respected; Agg is only a headless-safe fallback.
    matplotlib.use("Agg", force=False)
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    output_dir.mkdir(parents=True, exist_ok=True)
    frame_names = predictions.frame_names()
    # NOTE(review): frame_names[i] is assumed to correspond to video frame i
    # (insertion order == frame index). This holds for estimator-produced
    # predictions; confirm for externally constructed VideoPredictions.
    selected = _select_indices(frame_indices, len(frame_names))

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise VideoDecodeError(f"OpenCV could not open video: {video_path}")

    written: list[Path] = []
    try:
        # Walk the sorted target indices in lockstep with the decode loop,
        # so each video frame is decoded at most once and decoding stops as
        # soon as the last requested index has been rendered.
        next_index_to_render = iter(selected)
        target = next(next_index_to_render, None)
        frame_index = 0
        while target is not None:
            ok, bgr_frame = cap.read()
            if not ok:
                # Video ended before all requested indices were reached.
                break
            if frame_index == target:
                # Convert BGR (OpenCV) to RGB for display.
                rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
                frame_name = frame_names[frame_index]
                out_path = output_dir / f"{frame_name}.png"
                _render_frame(
                    rgb_frame,
                    predictions[frame_name],
                    out_path,
                    view=view,
                    joint_edges=joint_edges,
                    plt=plt,
                    Rectangle=Rectangle,
                )
                written.append(out_path)
                target = next(next_index_to_render, None)
            frame_index += 1
    finally:
        cap.release()

    logger.info("Wrote %d visualization frame(s) to %s", len(written), output_dir)
    return written
|
||||
|
||||
|
||||
def _select_indices(frame_indices: Sequence[int] | None, total: int) -> list[int]:
|
||||
"""Normalise the ``frame_indices`` argument to a sorted, in-range list."""
|
||||
if frame_indices is None:
|
||||
return list(range(total))
|
||||
return sorted({i for i in frame_indices if 0 <= i < total})
|
||||
|
||||
|
||||
def _render_frame(
    rgb_frame: Any,
    frame_prediction: Any,
    out_path: Path,
    *,
    view: str,
    joint_edges: Sequence[tuple[int, int]] | None,
    plt: Any,
    Rectangle: Any,
) -> None:
    """Render one frame's 2D overlay + 3D scatter to ``out_path``.

    ``plt`` and ``Rectangle`` are passed in by the caller because matplotlib
    is imported lazily in :func:`visualize_predictions`.
    """
    # Explicit copies. The previous prototype mutated the caller's data via
    # numpy-view tuple-assignment; we take a fresh numpy array per person
    # so the caller's VideoPredictions is never touched.
    boxes = np.asarray(frame_prediction.boxes, dtype=float)
    poses3d = np.asarray(frame_prediction.poses3d, dtype=float).copy()
    poses2d = np.asarray(frame_prediction.poses2d, dtype=float)

    # Rotate for visualization: swap Y and Z so the ground plane is horizontal.
    # Do this on the copy so the original predictions object is untouched.
    if poses3d.size > 0:
        # Save Y before overwriting it; new Z is the negated original Y.
        original_y = poses3d[..., 1].copy()
        poses3d[..., 1] = poses3d[..., 2]
        poses3d[..., 2] = -original_y

    fig = plt.figure(figsize=(10, 5.2))

    # Left subplot: source frame with 2D boxes/joints overlaid.
    image_ax = fig.add_subplot(1, 2, 1)
    image_ax.imshow(rgb_frame)
    for box in boxes:
        # Skip malformed boxes; elements beyond the first four (e.g. a
        # confidence score) are ignored.
        if len(box) < 4:
            continue
        x, y, w, h = box[:4]
        image_ax.add_patch(Rectangle((x, y), w, h, fill=False))

    # Right subplot: 3D joint scatter with fixed, view-dependent limits so
    # successive frames render on a stable coordinate box.
    pose_ax = fig.add_subplot(1, 2, 2, projection="3d")
    pose_ax.view_init(5, -85)
    if view == "depth":
        pose_ax.set_xlim3d(200, 17500)
        pose_ax.set_zlim3d(-1500, 1500)
        pose_ax.set_ylim3d(0, 3000)
    else:
        pose_ax.set_xlim3d(-1500, 1500)
        pose_ax.set_zlim3d(-1500, 1500)
        pose_ax.set_ylim3d(0, 3000)
    pose_ax.set_box_aspect((1, 1, 1))

    # One pass per detected person; strict=False tolerates a 2D/3D person
    # count mismatch by stopping at the shorter of the two.
    for pose3d, pose2d in zip(poses3d, poses2d, strict=False):
        if joint_edges is not None:
            for i_start, i_end in joint_edges:
                # Bounds-check each edge endpoint; out-of-range edges are
                # silently skipped rather than raising.
                if 0 <= i_start < len(pose2d) and 0 <= i_end < len(pose2d):
                    image_ax.plot(
                        [pose2d[i_start][0], pose2d[i_end][0]],
                        [pose2d[i_start][1], pose2d[i_end][1]],
                        marker="o",
                        markersize=2,
                    )
                if 0 <= i_start < len(pose3d) and 0 <= i_end < len(pose3d):
                    pose_ax.plot(
                        [pose3d[i_start][0], pose3d[i_end][0]],
                        [pose3d[i_start][1], pose3d[i_end][1]],
                        [pose3d[i_start][2], pose3d[i_end][2]],
                        marker="o",
                        markersize=2,
                    )
        image_ax.scatter(pose2d[:, 0], pose2d[:, 1], s=2)
        pose_ax.scatter(pose3d[:, 0], pose3d[:, 1], pose3d[:, 2], s=2)

    fig.tight_layout()
    fig.savefig(out_path)
    # Close explicitly so long videos don't accumulate open figures.
    plt.close(fig)
|
||||
|
|
@ -5,10 +5,18 @@ from __future__ import annotations
|
|||
import os
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_environment(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
|
@ -36,3 +44,65 @@ def _isolate_environment(
|
|||
def xdg_home() -> Path:
|
||||
"""Return the isolated ``$XDG_DATA_HOME`` set up by ``_isolate_environment``."""
|
||||
return Path(os.environ["XDG_DATA_HOME"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthetic video + fake MeTRAbs model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def synthetic_video(tmp_path: Path) -> Path:
    """Write and return a tiny synthetic test video.

    Produces a 5-frame, 32×32 ``.avi`` encoded with MJPG. MJPG is preferred
    over ``mp4v`` because it ships with ``opencv-python-headless`` on every
    platform we target, while ``mp4v`` occasionally needs an ffmpeg binary
    that minimal CI runners may lack.
    """
    out_path = tmp_path / "synthetic.avi"
    writer = cv2.VideoWriter(
        str(out_path),
        cv2.VideoWriter_fourcc(*"MJPG"),
        30.0,
        (32, 32),
    )
    assert writer.isOpened(), "cv2.VideoWriter failed to open; MJPG codec missing?"
    for step in range(5):
        # Give every frame a distinct brightness so downstream checks can
        # verify the reader really advances frame by frame.
        writer.write(np.full((32, 32, 3), step * 40, dtype=np.uint8))
    writer.release()
    assert out_path.exists() and out_path.stat().st_size > 0, "Synthetic video is empty."
    return out_path
|
||||
|
||||
|
||||
class _FakeMetrabsModel:
|
||||
"""Minimal stand-in for the MeTRAbs model used in unit tests.
|
||||
|
||||
Returns deterministic pose data (one person, two joints) per call so
|
||||
tests can assert on shapes without importing TensorFlow or downloading
|
||||
the real model. The returned arrays are plain numpy so the estimator's
|
||||
``_to_nested_list`` helper exercises its non-``numpy()`` branch.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.call_count = 0
|
||||
|
||||
def detect_poses(
|
||||
self,
|
||||
image: Any,
|
||||
*,
|
||||
default_fov_degrees: float,
|
||||
skeleton: str,
|
||||
) -> dict[str, np.ndarray]:
|
||||
del image, default_fov_degrees, skeleton # signature-compatible with MeTRAbs
|
||||
self.call_count += 1
|
||||
return {
|
||||
"boxes": np.array([[0.0, 0.0, 32.0, 32.0, 0.95]]),
|
||||
"poses3d": np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]),
|
||||
"poses2d": np.array([[[10.0, 20.0], [30.0, 40.0]]]),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def fake_metrabs_model() -> _FakeMetrabsModel:
    """Return a fresh fake MeTRAbs model instance for a single test."""
    # A new instance per test keeps call_count from leaking between tests.
    return _FakeMetrabsModel()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,205 @@
|
|||
"""Tests for :class:`neuropose.estimator.Estimator`.
|
||||
|
||||
These tests exercise the non-model code paths (video decoding, frame loop,
|
||||
metadata extraction, result construction, progress reporting, and error
|
||||
handling) using an injected fake MeTRAbs model. The TensorFlow / real-model
|
||||
integration smoke test lives in ``tests/integration/`` and lands with
|
||||
commit 11.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from neuropose.estimator import (
|
||||
Estimator,
|
||||
ModelNotLoadedError,
|
||||
ProcessVideoResult,
|
||||
VideoDecodeError,
|
||||
)
|
||||
from neuropose.io import FramePrediction, VideoPredictions
|
||||
|
||||
|
||||
class TestConstruction:
    """Constructor behaviour: defaults, overrides, and model injection."""

    def test_no_model_by_default(self) -> None:
        assert not Estimator().is_model_loaded

    def test_injected_model_is_loaded(self, fake_metrabs_model) -> None:
        assert Estimator(model=fake_metrabs_model).is_model_loaded

    def test_defaults(self) -> None:
        est = Estimator()
        assert est.device == "/CPU:0"
        assert est.skeleton == "berkeley_mhad_43"
        assert est.default_fov_degrees == pytest.approx(55.0)

    def test_overrides(self, fake_metrabs_model) -> None:
        est = Estimator(
            model=fake_metrabs_model,
            device="/GPU:0",
            skeleton="smpl_24",
            default_fov_degrees=40.0,
        )
        assert est.device == "/GPU:0"
        assert est.skeleton == "smpl_24"
        assert est.default_fov_degrees == pytest.approx(40.0)
|
||||
|
||||
|
||||
class TestModelGuard:
    """Guard rails around a missing or stubbed model."""

    def test_model_property_raises_when_missing(self) -> None:
        est = Estimator()
        with pytest.raises(ModelNotLoadedError):
            _ = est.model

    def test_process_video_raises_when_missing(self, synthetic_video: Path) -> None:
        est = Estimator()
        with pytest.raises(ModelNotLoadedError):
            est.process_video(synthetic_video)

    def test_load_model_stub_raises_not_implemented(self) -> None:
        with pytest.raises(NotImplementedError, match="commit 11"):
            Estimator().load_model()

    def test_load_model_is_idempotent_when_already_loaded(
        self,
        fake_metrabs_model,
    ) -> None:
        est = Estimator(model=fake_metrabs_model)
        # A second load must be a no-op that keeps the injected model intact.
        est.load_model()
        assert est.model is fake_metrabs_model
|
||||
|
||||
|
||||
class TestProcessVideo:
|
||||
def test_returns_typed_result(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
result = estimator.process_video(synthetic_video)
|
||||
assert isinstance(result, ProcessVideoResult)
|
||||
assert isinstance(result.predictions, VideoPredictions)
|
||||
|
||||
def test_frame_count_matches_source(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
result = estimator.process_video(synthetic_video)
|
||||
assert result.frame_count == 5
|
||||
assert fake_metrabs_model.call_count == 5
|
||||
|
||||
def test_frame_naming_is_zero_padded(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
result = estimator.process_video(synthetic_video)
|
||||
names = result.predictions.frame_names()
|
||||
assert names == [
|
||||
"frame_000000",
|
||||
"frame_000001",
|
||||
"frame_000002",
|
||||
"frame_000003",
|
||||
"frame_000004",
|
||||
]
|
||||
|
||||
def test_metadata_populated(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
result = estimator.process_video(synthetic_video)
|
||||
metadata = result.predictions.metadata
|
||||
assert metadata.frame_count == 5
|
||||
assert metadata.width == 32
|
||||
assert metadata.height == 32
|
||||
assert metadata.fps > 0.0
|
||||
|
||||
def test_each_frame_validates_as_frame_prediction(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
result = estimator.process_video(synthetic_video)
|
||||
for name in result.predictions.frame_names():
|
||||
frame = result.predictions[name]
|
||||
assert isinstance(frame, FramePrediction)
|
||||
assert len(frame.boxes) == 1
|
||||
assert len(frame.poses3d) == 1
|
||||
assert len(frame.poses3d[0]) == 2 # Two joints per the fake model.
|
||||
|
||||
def test_progress_callback_invoked_per_frame(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
calls: list[tuple[int, int]] = []
|
||||
estimator.process_video(
|
||||
synthetic_video,
|
||||
progress=lambda processed, total: calls.append((processed, total)),
|
||||
)
|
||||
assert len(calls) == 5
|
||||
# Processed counts should be strictly increasing.
|
||||
assert [c[0] for c in calls] == [1, 2, 3, 4, 5]
|
||||
|
||||
def test_fov_override_is_passed_through(self, synthetic_video: Path) -> None:
|
||||
fov_seen: list[float] = []
|
||||
|
||||
class RecordingModel:
|
||||
def detect_poses(
|
||||
self,
|
||||
image,
|
||||
*,
|
||||
default_fov_degrees: float,
|
||||
skeleton: str,
|
||||
):
|
||||
del image, skeleton
|
||||
fov_seen.append(default_fov_degrees)
|
||||
import numpy as np
|
||||
|
||||
return {
|
||||
"boxes": np.array([[0.0, 0.0, 32.0, 32.0, 0.9]]),
|
||||
"poses3d": np.array([[[0.0, 0.0, 0.0]]]),
|
||||
"poses2d": np.array([[[0.0, 0.0]]]),
|
||||
}
|
||||
|
||||
estimator = Estimator(model=RecordingModel(), default_fov_degrees=55.0)
|
||||
estimator.process_video(synthetic_video, fov_degrees=40.0)
|
||||
assert all(fov == pytest.approx(40.0) for fov in fov_seen)
|
||||
assert len(fov_seen) == 5
|
||||
|
||||
|
||||
class TestErrors:
|
||||
def test_missing_video(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
estimator.process_video(tmp_path / "does_not_exist.mp4")
|
||||
|
||||
def test_unreadable_video_raises_decode_error(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
# A file that exists but is not a valid video. cv2.VideoCapture
|
||||
# returns isOpened() == False for non-video content.
|
||||
path = tmp_path / "not_a_video.avi"
|
||||
path.write_bytes(b"this is definitely not a video file")
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
with pytest.raises(VideoDecodeError):
|
||||
estimator.process_video(path)
|
||||
|
|
@ -14,6 +14,7 @@ from neuropose.io import (
|
|||
JobResults,
|
||||
JobStatus,
|
||||
StatusFile,
|
||||
VideoMetadata,
|
||||
VideoPredictions,
|
||||
load_job_results,
|
||||
load_status,
|
||||
|
|
@ -40,10 +41,18 @@ def one_frame() -> dict:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def video_payload(one_frame: dict) -> dict:
|
||||
def video_metadata_payload() -> dict:
|
||||
return {"frame_count": 2, "fps": 30.0, "width": 640, "height": 480}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_predictions_payload(one_frame: dict, video_metadata_payload: dict) -> dict:
|
||||
return {
|
||||
"frame_0000.png": one_frame,
|
||||
"frame_0001.png": one_frame,
|
||||
"metadata": video_metadata_payload,
|
||||
"frames": {
|
||||
"frame_000000": one_frame,
|
||||
"frame_000001": one_frame,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -70,30 +79,90 @@ class TestFramePrediction:
|
|||
frame.boxes = []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VideoMetadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVideoMetadata:
|
||||
def test_valid(self) -> None:
|
||||
meta = VideoMetadata(frame_count=10, fps=29.97, width=1920, height=1080)
|
||||
assert meta.frame_count == 10
|
||||
assert meta.fps == pytest.approx(29.97)
|
||||
|
||||
def test_zero_frame_count_allowed(self) -> None:
|
||||
# Broken or empty videos still produce a valid metadata object so
|
||||
# the caller can see frame_count == 0 rather than receiving an
|
||||
# exception.
|
||||
VideoMetadata(frame_count=0, fps=0.0, width=0, height=0)
|
||||
|
||||
def test_rejects_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VideoMetadata(frame_count=-1, fps=30.0, width=640, height=480)
|
||||
|
||||
def test_rejects_extra_fields(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
VideoMetadata(
|
||||
frame_count=1,
|
||||
fps=30.0,
|
||||
width=640,
|
||||
height=480,
|
||||
source_path="/leak/me",
|
||||
)
|
||||
|
||||
def test_is_frozen(self) -> None:
|
||||
meta = VideoMetadata(frame_count=10, fps=30.0, width=640, height=480)
|
||||
with pytest.raises(ValidationError):
|
||||
meta.fps = 60.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VideoPredictions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVideoPredictions:
|
||||
def test_from_dict(self, video_payload: dict) -> None:
|
||||
vp = VideoPredictions.model_validate(video_payload)
|
||||
def test_from_dict(self, video_predictions_payload: dict) -> None:
|
||||
vp = VideoPredictions.model_validate(video_predictions_payload)
|
||||
assert len(vp) == 2
|
||||
assert vp.frames() == ["frame_0000.png", "frame_0001.png"]
|
||||
assert vp["frame_0000.png"].boxes[0][4] == pytest.approx(0.95)
|
||||
assert vp.frame_names() == ["frame_000000", "frame_000001"]
|
||||
assert vp["frame_000000"].boxes[0][4] == pytest.approx(0.95)
|
||||
assert vp.metadata.fps == pytest.approx(30.0)
|
||||
|
||||
def test_iteration(self, video_payload: dict) -> None:
|
||||
vp = VideoPredictions.model_validate(video_payload)
|
||||
assert list(vp) == ["frame_0000.png", "frame_0001.png"]
|
||||
def test_iteration(self, video_predictions_payload: dict) -> None:
|
||||
vp = VideoPredictions.model_validate(video_predictions_payload)
|
||||
assert list(vp) == ["frame_000000", "frame_000001"]
|
||||
|
||||
def test_save_and_load_roundtrip(self, tmp_path: Path, video_payload: dict) -> None:
|
||||
vp = VideoPredictions.model_validate(video_payload)
|
||||
def test_rejects_missing_metadata(self, video_predictions_payload: dict) -> None:
|
||||
del video_predictions_payload["metadata"]
|
||||
with pytest.raises(ValidationError):
|
||||
VideoPredictions.model_validate(video_predictions_payload)
|
||||
|
||||
def test_save_and_load_roundtrip(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
video_predictions_payload: dict,
|
||||
) -> None:
|
||||
vp = VideoPredictions.model_validate(video_predictions_payload)
|
||||
path = tmp_path / "preds" / "video.json"
|
||||
save_video_predictions(path, vp)
|
||||
assert path.exists()
|
||||
loaded = load_video_predictions(path)
|
||||
assert loaded.frames() == vp.frames()
|
||||
assert loaded["frame_0000.png"].poses3d == vp["frame_0000.png"].poses3d
|
||||
assert loaded.frame_names() == vp.frame_names()
|
||||
assert loaded.metadata == vp.metadata
|
||||
assert loaded["frame_000000"].poses3d == vp["frame_000000"].poses3d
|
||||
|
||||
def test_save_is_atomic(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
video_predictions_payload: dict,
|
||||
) -> None:
|
||||
vp = VideoPredictions.model_validate(video_predictions_payload)
|
||||
path = tmp_path / "video.json"
|
||||
save_video_predictions(path, vp)
|
||||
assert path.exists()
|
||||
tmps = list(tmp_path.glob("video.json.tmp"))
|
||||
assert tmps == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -102,9 +171,16 @@ class TestVideoPredictions:
|
|||
|
||||
|
||||
class TestJobResults:
|
||||
def test_save_and_load_roundtrip(self, tmp_path: Path, video_payload: dict) -> None:
|
||||
def test_save_and_load_roundtrip(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
video_predictions_payload: dict,
|
||||
) -> None:
|
||||
jr = JobResults.model_validate(
|
||||
{"video_a.mp4": video_payload, "video_b.mp4": video_payload}
|
||||
{
|
||||
"video_a.mp4": video_predictions_payload,
|
||||
"video_b.mp4": video_predictions_payload,
|
||||
}
|
||||
)
|
||||
path = tmp_path / "results.json"
|
||||
save_job_results(path, jr)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,158 @@
|
|||
"""Tests for :mod:`neuropose.visualize`.
|
||||
|
||||
Smoke tests only: we verify that the visualize function runs end-to-end
|
||||
against a synthetic video, writes PNG files, honours ``frame_indices``, and
|
||||
does NOT mutate the caller's :class:`VideoPredictions`. Actual pixel-level
|
||||
correctness is out of scope for unit tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from neuropose.estimator import Estimator, VideoDecodeError
|
||||
from neuropose.io import VideoPredictions
|
||||
from neuropose.visualize import visualize_predictions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def predictions_for_synthetic(
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> VideoPredictions:
|
||||
"""Run the fake estimator over the synthetic video and return predictions."""
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
return estimator.process_video(synthetic_video).predictions
|
||||
|
||||
|
||||
class TestVisualizePredictions:
|
||||
def test_writes_one_png_per_frame(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
synthetic_video: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
output_dir = tmp_path / "viz"
|
||||
written = visualize_predictions(
|
||||
synthetic_video, predictions_for_synthetic, output_dir
|
||||
)
|
||||
assert len(written) == 5
|
||||
for path in written:
|
||||
assert path.exists()
|
||||
assert path.suffix == ".png"
|
||||
assert path.stat().st_size > 0
|
||||
|
||||
def test_frame_indices_limits_output(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
synthetic_video: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
output_dir = tmp_path / "viz"
|
||||
written = visualize_predictions(
|
||||
synthetic_video,
|
||||
predictions_for_synthetic,
|
||||
output_dir,
|
||||
frame_indices=[0, 2, 4],
|
||||
)
|
||||
assert len(written) == 3
|
||||
names = sorted(p.stem for p in written)
|
||||
assert names == ["frame_000000", "frame_000002", "frame_000004"]
|
||||
|
||||
def test_out_of_range_indices_silently_skipped(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
synthetic_video: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
output_dir = tmp_path / "viz"
|
||||
written = visualize_predictions(
|
||||
synthetic_video,
|
||||
predictions_for_synthetic,
|
||||
output_dir,
|
||||
frame_indices=[0, 999, -1],
|
||||
)
|
||||
assert len(written) == 1
|
||||
assert written[0].stem == "frame_000000"
|
||||
|
||||
def test_does_not_mutate_input_predictions(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
synthetic_video: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
"""The aliasing bug from the previous prototype mutated poses3d in place.
|
||||
|
||||
This test guards against its regression by deep-copying poses3d
|
||||
before visualization and comparing against the live value after.
|
||||
Without the deepcopy, any in-place mutation would be invisible
|
||||
because ``before`` and ``after`` would refer to the same list.
|
||||
"""
|
||||
frame_name = predictions_for_synthetic.frame_names()[0]
|
||||
before = copy.deepcopy(predictions_for_synthetic[frame_name].poses3d)
|
||||
visualize_predictions(
|
||||
synthetic_video,
|
||||
predictions_for_synthetic,
|
||||
tmp_path / "viz",
|
||||
frame_indices=[0],
|
||||
)
|
||||
after = predictions_for_synthetic[frame_name].poses3d
|
||||
assert before == after
|
||||
|
||||
def test_rejects_invalid_view(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
synthetic_video: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
with pytest.raises(ValueError, match="view must be one of"):
|
||||
visualize_predictions(
|
||||
synthetic_video,
|
||||
predictions_for_synthetic,
|
||||
tmp_path / "viz",
|
||||
view="orthographic",
|
||||
)
|
||||
|
||||
def test_depth_view_runs(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
synthetic_video: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
written = visualize_predictions(
|
||||
synthetic_video,
|
||||
predictions_for_synthetic,
|
||||
tmp_path / "viz",
|
||||
view="depth",
|
||||
frame_indices=[0],
|
||||
)
|
||||
assert len(written) == 1
|
||||
|
||||
def test_missing_video_raises(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
visualize_predictions(
|
||||
tmp_path / "nope.avi",
|
||||
predictions_for_synthetic,
|
||||
tmp_path / "viz",
|
||||
)
|
||||
|
||||
def test_unreadable_video_raises(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
predictions_for_synthetic: VideoPredictions,
|
||||
) -> None:
|
||||
fake_video = tmp_path / "fake.avi"
|
||||
fake_video.write_bytes(b"not a video")
|
||||
with pytest.raises(VideoDecodeError):
|
||||
visualize_predictions(
|
||||
fake_video,
|
||||
predictions_for_synthetic,
|
||||
tmp_path / "viz",
|
||||
)
|
||||
Loading…
Reference in New Issue