Compare commits
No commits in common. "4f3a6241fbb77bb0d4b8ae0175c6c06d15a9b598" and "01b9ed9475652de70eeb81a0210bc1caf05905a9" have entirely different histories.
4f3a6241fb
...
01b9ed9475
197
CHANGELOG.md
197
CHANGELOG.md
|
|
@ -202,170 +202,6 @@ be split into per-release sections once tagging begins.
|
|||
(with a `.collisions` list of offending names). The running
|
||||
daemon needs no changes — ingested job dirs are picked up on the
|
||||
next poll.
|
||||
- **`neuropose.migrations`** — schema-migration infrastructure for
|
||||
the three top-level serialised payloads (`VideoPredictions`,
|
||||
`JobResults`, `BenchmarkResult`). Every payload carries a
|
||||
`schema_version` field defaulting to `CURRENT_VERSION`; on load,
|
||||
the raw JSON dict is passed through `migrate_video_predictions` /
|
||||
`migrate_job_results` / `migrate_benchmark_result` *before*
|
||||
pydantic validation so files written by older NeuroPose versions
|
||||
upgrade transparently. One shared `CURRENT_VERSION` counter;
|
||||
per-schema migration registries populated via
|
||||
`register_video_predictions_migration(from_version)` and
|
||||
`register_benchmark_result_migration(from_version)` decorators.
|
||||
`JobResults` is a `RootModel` with no envelope of its own, so its
|
||||
migration runs per-entry across the root mapping. The driver raises
|
||||
`FutureSchemaError` for payloads newer than the current build
|
||||
(clear upgrade-NeuroPose message), `MigrationNotFoundError` for
|
||||
missing chain links (indicates a `CURRENT_VERSION` bump that forgot
|
||||
its migration), and logs at INFO on each version advance. Currently
|
||||
at `CURRENT_VERSION = 2`, with registered v1 → v2 migrations for
|
||||
`VideoPredictions` and `BenchmarkResult` that add the optional
|
||||
`provenance` field.
|
||||
- **`neuropose.analyzer.features.procrustes_align`** — Kabsch
|
||||
rigid-alignment helper for pose sequences, plus a
|
||||
`ProcrustesMode` literal (`"per_frame"` | `"per_sequence"`) and a
|
||||
frozen `AlignmentDiagnostics` dataclass (`rotation_deg`,
|
||||
`rotation_deg_max`, `translation`, `translation_max`, `scale`,
|
||||
plus the mode that produced them). Per-sequence mode fits one
|
||||
rigid transform across the whole trial; per-frame fits an
|
||||
independent transform per frame. Optional `scale=True` fits a
|
||||
uniform scale factor for cross-subject comparisons. Wired into
|
||||
every DTW entry point in `neuropose.analyzer.dtw` via a new
|
||||
keyword-only `align: AlignMode = "none"` parameter — `"none"`
|
||||
preserves the 0.1 raw-coordinate behaviour, while
|
||||
`"procrustes_per_frame"` and `"procrustes_per_sequence"` route
|
||||
inputs through `procrustes_align` before DTW runs so the returned
|
||||
distance is rotation- and translation-invariant. Paper C's
|
||||
pipeline is expected to set `align="procrustes_per_sequence"`;
|
||||
see `TECHNICAL.md` Phase 0.
|
||||
- **`neuropose.analyzer.dtw.Representation`** and
|
||||
**`neuropose.analyzer.dtw.NanPolicy`** — two new Literal types
|
||||
exposing orthogonal DTW preprocessing knobs on every entry point.
|
||||
`representation` (on `dtw_all` and `dtw_per_joint`) switches the
|
||||
per-frame feature vector between `"coords"` (the 0.1 default) and
|
||||
`"angles"`, which runs `extract_joint_angles` on the supplied
|
||||
`angle_triplets` first — yielding distances that are translation-,
|
||||
rotation-, and scale-invariant by construction, and directly
|
||||
interpretable in clinical terms. `nan_policy` (on all three entry
|
||||
points) selects `"propagate"` (surface fastdtw's ValueError on
|
||||
NaN — the default), `"interpolate"` (linear fill per feature
|
||||
column), or `"drop"` (remove NaN frames before DTW); the
|
||||
policy is applied consistently whether NaN originated from the
|
||||
angles pipeline or from corrupted upstream coordinates.
|
||||
`dtw_relation` stays a standalone convenience entry point for
|
||||
two-joint displacement DTW; users who prefer a unified API can
|
||||
express the same computation via `dtw_all` with an appropriate
|
||||
pair of angle triplets or run `dtw_relation` directly.
|
||||
- **`neuropose.analyzer.pipeline`** (schemas) — declarative
|
||||
analysis-pipeline configuration and output artifact, parseable from
|
||||
YAML or JSON via pydantic. `AnalysisConfig` captures a full
|
||||
experiment: inputs (primary + optional reference predictions
|
||||
files), preprocessing (person index, with room to grow),
|
||||
optional segmentation (`gait_cycles` / `gait_cycles_bilateral` /
|
||||
`extractor` discriminated union), and a required analysis stage
|
||||
(`dtw` / `stats` / `none` discriminated union). `AnalysisReport`
|
||||
is the runtime output: carries the originating config, a
|
||||
`Provenance` envelope with `analysis_config` populated, per-input
|
||||
summaries, produced segmentations, and an analysis-result payload
|
||||
that mirrors the stage choice (`DtwResults`, `StatsResults`, or
|
||||
`NoResults`). Cross-field invariants — `method="dtw_relation"`
|
||||
requires `joint_i`/`joint_j`, `representation="angles"` requires
|
||||
non-empty `angle_triplets`, `analysis.kind="dtw"` requires
|
||||
`inputs.reference`, `analysis.kind="stats"` refuses a reference —
|
||||
are enforced at parse time via `model_validator` so typos fail in
|
||||
milliseconds instead of after a multi-minute predictions load.
|
||||
`AnalysisReport` carries a `schema_version` field defaulting to
|
||||
`CURRENT_VERSION = 2`, with a new
|
||||
`register_analysis_report_migration` decorator and
|
||||
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
||||
for future schema changes. `run_analysis(config)` loads the named
|
||||
predictions files, applies the configured segmentation, dispatches
|
||||
to the selected analysis kind (DTW, stats, or none), and emits a
|
||||
fully populated `AnalysisReport` whose `Provenance` inherits the
|
||||
inference-time envelope from the primary input with
|
||||
`analysis_config` stamped in, so the report is self-describing
|
||||
even if the source YAML is lost. For DTW runs with segmentation,
|
||||
segments are paired one-to-one by index across primary and
|
||||
reference, truncating to `min(len_primary, len_reference)`;
|
||||
bilateral segmentations emit per-side distances under
|
||||
`"left_heel_strikes[i]"` / `"right_heel_strikes[i]"` labels.
|
||||
`load_config(path)` parses YAML, `save_report(path, report)`
|
||||
writes atomically, and `load_report(path)` rehydrates via the
|
||||
migration chain. Wired to the CLI as `neuropose analyze --config
|
||||
<yaml> [--output <json>]` — replaces the placeholder stub that
|
||||
previously returned `EXIT_PENDING`. The CLI surfaces schema
|
||||
violations and YAML parse errors as `EXIT_USAGE=2` with a clear
|
||||
message pointing at the offending file, prints a one-line summary
|
||||
of the run (segmentation counts, analysis kind, per-segment
|
||||
distance count + mean for DTW), and supports `--output`/`-o` to
|
||||
override the report path declared in the config (useful for
|
||||
sweeping a single config over multiple input pairs from a shell
|
||||
loop). Ships three example configs under `examples/analysis/`:
|
||||
`minimal.yaml` (smallest working DTW pipeline), `paper_c_headline.yaml`
|
||||
(representative Paper C config with bilateral gait-cycle
|
||||
segmentation, per-sequence Procrustes, and joint-angle DTW on
|
||||
knee/hip triplets), and `per_joint_debug.yaml` (per-joint DTW
|
||||
breakdown for diagnosing which joint drives an unexpected
|
||||
distance). An integration suite exercises each example against
|
||||
synthetic predictions so schema drift between the YAMLs and the
|
||||
executor fails CI, not silently at run time. Documented in
|
||||
`docs/api/pipeline.md`.
|
||||
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
||||
**`segment_gait_cycles_bilateral`** — clinical convenience
|
||||
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
||||
extractor with gait-appropriate defaults (`joint="rhee"`,
|
||||
`axis="y"`, `min_cycle_seconds=0.4`). The single-side entry point
|
||||
accepts any berkeley_mhad_43 joint name and any spatial axis as a
|
||||
string literal `"x" | "y" | "z"`, plus an `invert` flag for
|
||||
recordings whose vertical axis runs opposite to MeTRAbs's
|
||||
Y-down world-coordinate convention. The bilateral wrapper runs
|
||||
the detection on both `lhee` and `rhee` and returns the two
|
||||
results under `"left_heel_strikes"` / `"right_heel_strikes"`
|
||||
keys — shape-compatible with `VideoPredictions.segmentations` so
|
||||
the dict can be merged in directly. Degrades gracefully on
|
||||
pathological gaits (shuffling, walker-assisted) by returning an
|
||||
empty segments list rather than raising. Closes the gait-cycle
|
||||
segmentation item in `TECHNICAL.md` Phase 0.
|
||||
- **`neuropose.io.Provenance`** — reproducibility envelope for every
|
||||
inference run. Populated automatically by `Estimator.process_video`
|
||||
when the model was loaded via `load_model` (the production path)
|
||||
and attached to the output `VideoPredictions`; propagates from
|
||||
there into `JobResults` (per-video) and `BenchmarkResult` (via the
|
||||
benchmark loop). Captures the MeTRAbs artifact SHA-256 and
|
||||
filename, `tensorflow` / `tensorflow-metal` / `numpy` /
|
||||
`neuropose` / Python versions, and reserved slots for a `seed`,
|
||||
`deterministic` flag (Track 2), and `analysis_config` (Phase 0
|
||||
YAML pipeline). `None` on the injected-model test path where
|
||||
NeuroPose has no way to fingerprint the supplied artifact. Frozen
|
||||
pydantic model with `extra="forbid"` and
|
||||
`protected_namespaces=()` so the `model_*` field names do not
|
||||
collide with pydantic v2's internal namespace. `_model.load_metrabs_model`
|
||||
now returns a `LoadedModel` dataclass bundling the TF handle with
|
||||
the pinned SHA and filename so the estimator can build the
|
||||
`Provenance` without re-hashing the tarball.
|
||||
- **`neuropose.reset`** — pipeline-wide reset utility for the
|
||||
benchmark / iteration loop. `find_neuropose_processes()` scans the
|
||||
OS process table (via `psutil`) for running `neuropose watch` and
|
||||
`neuropose serve` instances and classifies each as `daemon` or
|
||||
`monitor`. `terminate_processes()` SIGINTs them, polls for graceful
|
||||
exit up to a configurable grace period, and optionally escalates
|
||||
to SIGKILL with `force_kill=True`. `wipe_state()` removes the
|
||||
contents of `$data_dir/in/`, `$data_dir/out/` (including
|
||||
`status.json`), `$data_dir/failed/` (unless `keep_failed=True`),
|
||||
the `.neuropose.lock` file, and any leftover `.ingest_<uuid>/`
|
||||
staging dirs from interrupted ingests; container directories
|
||||
themselves are preserved so the daemon does not need to recreate
|
||||
them on next startup. `reset_pipeline()` composes the three with
|
||||
one safety guard: if any process survives termination, the wipe
|
||||
phase is skipped and the returned `ResetReport` flags
|
||||
`wipe_skipped_due_to_survivors`, because removing `$data_dir`
|
||||
out from under an active daemon would corrupt its in-flight
|
||||
writes. Surfaced as `neuropose reset` in the CLI with
|
||||
`--yes/-y`, `--keep-failed`, `--force-kill`, `--grace-seconds`,
|
||||
and `--dry-run/-n` flags; the command always prints a preview
|
||||
before prompting (skipped under `--yes`) and returns
|
||||
`EXIT_USAGE=2` when survivors block the wipe.
|
||||
- **`neuropose.benchmark`** — multi-pass inference benchmarking for
|
||||
a single video. `run_benchmark()` runs `process_video` N times
|
||||
(default 5), always discards the first pass as warmup (graph
|
||||
|
|
@ -385,18 +221,14 @@ be split into per-release sections once tagging begins.
|
|||
imports for the heavy dependencies:
|
||||
- `analyzer.dtw` — three DTW entry points (`dtw_all`,
|
||||
`dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
|
||||
`DTWResult` dataclass and three orthogonal preprocessing knobs
|
||||
(`align`, `representation`, `nan_policy`). See `RESEARCH.md`
|
||||
for the ongoing
|
||||
`DTWResult` dataclass. See `RESEARCH.md` for the ongoing
|
||||
methodology investigation.
|
||||
- `analyzer.features` — `predictions_to_numpy`,
|
||||
`normalize_pose_sequence` (uniform and axis-wise),
|
||||
`pad_sequences` (edge-padding), `procrustes_align` (Kabsch
|
||||
rigid alignment, per-frame or per-sequence, optional uniform
|
||||
scaling), `extract_joint_angles` (NaN on degenerate vectors),
|
||||
`extract_feature_statistics` (`FeatureStatistics` frozen
|
||||
dataclass), and a `find_peaks` thin wrapper around
|
||||
`scipy.signal.find_peaks`.
|
||||
`pad_sequences` (edge-padding), `extract_joint_angles` (NaN on
|
||||
degenerate vectors), `extract_feature_statistics`
|
||||
(`FeatureStatistics` frozen dataclass), and a `find_peaks` thin
|
||||
wrapper around `scipy.signal.find_peaks`.
|
||||
- `analyzer.segment` — repetition segmentation for trials in
|
||||
which a subject performs the same movement several times. A
|
||||
three-layer API: `segment_by_peaks` (pure 1D
|
||||
|
|
@ -406,11 +238,7 @@ be split into per-release sections once tagging begins.
|
|||
time-based parameters to frame counts via `metadata.fps`), and
|
||||
`slice_predictions` (split a `VideoPredictions` into one per
|
||||
detected repetition with re-keyed frame names and a rewritten
|
||||
`frame_count`). Gait-specific convenience wrappers
|
||||
`segment_gait_cycles` (single heel) and
|
||||
`segment_gait_cycles_bilateral` (both heels, returning a dict
|
||||
keyed by `"left_heel_strikes"` / `"right_heel_strikes"`) sit
|
||||
above `segment_predictions` with clinical defaults. Ships four extractor factories —
|
||||
`frame_count`). Ships four extractor factories —
|
||||
`joint_axis`, `joint_pair_distance`, `joint_speed`, and
|
||||
`joint_angle` — plus a `JOINT_NAMES` constant for the
|
||||
berkeley_mhad_43 skeleton with a `joint_index(name)` lookup,
|
||||
|
|
@ -420,7 +248,7 @@ be split into per-release sections once tagging begins.
|
|||
`slow`) loads the real model and asserts the constant still
|
||||
matches, so any upstream skeleton drift fails CI.
|
||||
- **`neuropose.cli`** — Typer-based command-line interface with
|
||||
eight subcommands: `watch` (run the daemon), `process <video>`
|
||||
seven subcommands: `watch` (run the daemon), `process <video>`
|
||||
(run the estimator on a single video), `ingest <archive>` (unzip
|
||||
a video archive into per-video job directories under
|
||||
`$data_dir/in/` with validation-before-write and atomic
|
||||
|
|
@ -431,13 +259,6 @@ be split into per-release sections once tagging begins.
|
|||
KeyboardInterrupt exits with the standard shell-interruption
|
||||
code and an `OSError` at bind time is translated to a clean
|
||||
usage error with the bind target in the message),
|
||||
`reset` (stop the daemon and monitor, then wipe pipeline state
|
||||
for a clean restart — wraps `neuropose.reset` with a confirmation
|
||||
prompt, `--dry-run` preview, `--keep-failed` to preserve the
|
||||
forensic quarantine, `--force-kill` to escalate to SIGKILL after
|
||||
the SIGINT grace period, and `--grace-seconds` to tune the wait;
|
||||
refuses to wipe state while any process survives termination so
|
||||
active writes cannot be corrupted),
|
||||
`segment <results>` (post-hoc repetition segmentation — loads a
|
||||
JobResults or a single VideoPredictions, runs
|
||||
`neuropose.analyzer.segment.segment_predictions` with the chosen
|
||||
|
|
@ -452,9 +273,7 @@ be split into per-release sections once tagging begins.
|
|||
the resulting `poses3d` arrays, and reports throughput speedup
|
||||
and max divergence in mm — the missing Apple Silicon numerical
|
||||
verification answer from `RESEARCH.md`), and
|
||||
`analyze --config <yaml>` (run the declarative analysis
|
||||
pipeline — see the dedicated entry above for scope). The
|
||||
`segment` subcommand accepts
|
||||
`analyze <results>` (stub). The `segment` subcommand accepts
|
||||
joint specifiers as either berkeley_mhad_43 names (`lwri`,
|
||||
`rwri`, …) or integer indices, and refuses to overwrite an
|
||||
existing segmentation of the same name without `--force`.
|
||||
|
|
|
|||
1191
TECHNICAL.md
1191
TECHNICAL.md
File diff suppressed because it is too large
Load Diff
|
|
@ -1,3 +0,0 @@
|
|||
# `neuropose.migrations`
|
||||
|
||||
::: neuropose.migrations
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
# `neuropose.analyzer.pipeline`
|
||||
|
||||
::: neuropose.analyzer.pipeline
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
# `neuropose.reset`
|
||||
|
||||
::: neuropose.reset
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
# Minimal DTW config: full-trial comparison on raw 3D coordinates,
|
||||
# no Procrustes alignment, no segmentation. The simplest working
|
||||
# example; use this as a starting template and add stages as needed.
|
||||
#
|
||||
# Run:
|
||||
# neuropose analyze --config examples/analysis/minimal.yaml \
|
||||
# --output out/minimal_report.json
|
||||
#
|
||||
# (Substitute real paths for `inputs.primary` and `inputs.reference`
|
||||
# before running.)
|
||||
|
||||
config_version: 1
|
||||
|
||||
inputs:
|
||||
primary: data/trial_a.json
|
||||
reference: data/trial_b.json
|
||||
|
||||
analysis:
|
||||
kind: dtw
|
||||
method: dtw_all
|
||||
align: none
|
||||
representation: coords
|
||||
nan_policy: propagate
|
||||
|
||||
output:
|
||||
report: out/minimal_report.json
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
# Paper C headline config: cycle-segmented joint-angle DTW with
|
||||
# per-sequence Procrustes alignment. This is the representative
|
||||
# Paper C pipeline — bilateral gait cycles drive the segmentation
|
||||
# so distances are reported per-stride per-side, and the angle-space
|
||||
# representation makes the distance clinically interpretable
|
||||
# (knee flexion angle, hip extension angle, etc.).
|
||||
#
|
||||
# Joint-triplet indices below target the berkeley_mhad_43 skeleton:
|
||||
# - (27, 31, 32): left hip → left knee → left ankle (left knee flex)
|
||||
# - (35, 39, 40): right hip → right knee → right ankle (right knee flex)
|
||||
# - (34, 27, 31): back hip ← left hip → left knee (left hip flex)
|
||||
# - (34, 35, 39): back hip ← right hip → right knee (right hip flex)
|
||||
#
|
||||
# See neuropose.analyzer.JOINT_NAMES for the full 43-joint table.
|
||||
#
|
||||
# Run:
|
||||
# neuropose analyze --config examples/analysis/paper_c_headline.yaml \
|
||||
# --output out/paper_c_report.json
|
||||
|
||||
config_version: 1
|
||||
|
||||
inputs:
|
||||
primary: data/subject_trial.json
|
||||
reference: data/mocap_reference.json
|
||||
|
||||
preprocessing:
|
||||
person_index: 0
|
||||
|
||||
segmentation:
|
||||
kind: gait_cycles_bilateral
|
||||
axis: y
|
||||
invert: false
|
||||
min_cycle_seconds: 0.4
|
||||
|
||||
analysis:
|
||||
kind: dtw
|
||||
method: dtw_all
|
||||
align: procrustes_per_sequence
|
||||
representation: angles
|
||||
angle_triplets:
|
||||
- [27, 31, 32] # left knee flexion
|
||||
- [35, 39, 40] # right knee flexion
|
||||
- [34, 27, 31] # left hip flexion
|
||||
- [34, 35, 39] # right hip flexion
|
||||
nan_policy: interpolate
|
||||
|
||||
output:
|
||||
report: out/paper_c_report.json
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
# Per-joint debug config: runs dtw_per_joint on raw coordinates so
|
||||
# the resulting report carries a full (segments × joints) distance
|
||||
# breakdown. Useful when one joint is suspected of driving an
|
||||
# otherwise-unexpected DTW distance — the per-joint numbers make it
|
||||
# obvious which joint's trajectory diverges most.
|
||||
#
|
||||
# Raw coordinates (representation: coords) are used rather than
|
||||
# angles because joint-level debugging is most interpretable in the
|
||||
# native measurement space. Procrustes alignment is on so
|
||||
# translation and rotation between trials do not inflate the numbers.
|
||||
#
|
||||
# Run:
|
||||
# neuropose analyze --config examples/analysis/per_joint_debug.yaml \
|
||||
# --output out/per_joint_report.json
|
||||
|
||||
config_version: 1
|
||||
|
||||
inputs:
|
||||
primary: data/trial_a.json
|
||||
reference: data/trial_b.json
|
||||
|
||||
segmentation:
|
||||
kind: gait_cycles
|
||||
joint: rhee
|
||||
axis: y
|
||||
min_cycle_seconds: 0.4
|
||||
|
||||
analysis:
|
||||
kind: dtw
|
||||
method: dtw_per_joint
|
||||
align: procrustes_per_sequence
|
||||
representation: coords
|
||||
nan_policy: propagate
|
||||
|
||||
output:
|
||||
report: out/per_joint_report.json
|
||||
|
|
@ -95,12 +95,9 @@ nav:
|
|||
- neuropose.interfacer: api/interfacer.md
|
||||
- neuropose.ingest: api/ingest.md
|
||||
- neuropose.monitor: api/monitor.md
|
||||
- neuropose.reset: api/reset.md
|
||||
- neuropose.io: api/io.md
|
||||
- neuropose.migrations: api/migrations.md
|
||||
- neuropose.benchmark: api/benchmark.md
|
||||
- neuropose.analyzer.segment: api/segment.md
|
||||
- neuropose.analyzer.pipeline: api/pipeline.md
|
||||
- neuropose.visualize: api/visualize.md
|
||||
- Development: development.md
|
||||
- Deployment: deployment.md
|
||||
|
|
|
|||
|
|
@ -41,34 +41,11 @@ import os
|
|||
import shutil
|
||||
import tarfile
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LoadedModel:
|
||||
"""Result of :func:`load_metrabs_model`.
|
||||
|
||||
Bundles the loaded TensorFlow model with the provenance metadata
|
||||
that identifies which artifact it came from. Callers that only want
|
||||
the model reach for :attr:`model`; callers that build a
|
||||
:class:`~neuropose.io.Provenance` (primarily
|
||||
:class:`~neuropose.estimator.Estimator`) pull :attr:`sha256` and
|
||||
:attr:`filename` too.
|
||||
|
||||
Frozen — once :func:`load_metrabs_model` has produced a
|
||||
``LoadedModel``, nothing downstream should edit the identity of
|
||||
the artifact it describes.
|
||||
"""
|
||||
|
||||
model: Any
|
||||
sha256: str
|
||||
filename: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model artifact: pinned URL and checksum.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -97,7 +74,7 @@ _REQUIRED_MODEL_ATTRS = (
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
||||
def load_metrabs_model(cache_dir: Path | None = None) -> Any:
|
||||
"""Load the MeTRAbs model, downloading and caching on first use.
|
||||
|
||||
Parameters
|
||||
|
|
@ -110,11 +87,9 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
|||
|
||||
Returns
|
||||
-------
|
||||
LoadedModel
|
||||
Bundle containing the TensorFlow SavedModel handle alongside
|
||||
the pinned artifact SHA-256 and filename that identify which
|
||||
model the handle came from. The handle exposes ``detect_poses``
|
||||
and the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
|
||||
object
|
||||
A TensorFlow SavedModel handle exposing ``detect_poses`` and
|
||||
the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
|
||||
attributes used by :class:`neuropose.estimator.Estimator`.
|
||||
|
||||
Raises
|
||||
|
|
@ -124,18 +99,6 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
|||
automatic retry), extraction fails, TensorFlow is not
|
||||
installed, or the loaded model does not expose the expected
|
||||
interface.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The returned ``sha256`` is the module-pinned :data:`_MODEL_SHA256`,
|
||||
not a re-hash of the on-disk tarball. On the cold-cache path this
|
||||
is exactly the hash we verified against before loading. On the
|
||||
warm-cache path the tarball is not re-verified (that would cost a
|
||||
2 GB I/O pass on every daemon startup), so the reported SHA is an
|
||||
attestation of "this is the pinned artifact NeuroPose loads" rather
|
||||
than a direct fingerprint of the on-disk bytes. For the threat
|
||||
model this supports — reproducibility, not tamper-evidence — that
|
||||
is the correct semantics.
|
||||
"""
|
||||
resolved_cache = Path(cache_dir) if cache_dir is not None else _default_cache_dir()
|
||||
resolved_cache.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -152,11 +115,7 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
|||
)
|
||||
shutil.rmtree(model_dir, ignore_errors=True)
|
||||
else:
|
||||
return LoadedModel(
|
||||
model=_tf_load(saved_model_dir),
|
||||
sha256=_MODEL_SHA256,
|
||||
filename=_MODEL_ARCHIVE_NAME,
|
||||
)
|
||||
return _tf_load(saved_model_dir)
|
||||
|
||||
tarball = resolved_cache / _MODEL_ARCHIVE_NAME
|
||||
|
||||
|
|
@ -176,11 +135,7 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
|||
|
||||
_extract_tarball(tarball, model_dir)
|
||||
saved_model_dir = _find_saved_model(model_dir)
|
||||
return LoadedModel(
|
||||
model=_tf_load(saved_model_dir),
|
||||
sha256=_MODEL_SHA256,
|
||||
filename=_MODEL_ARCHIVE_NAME,
|
||||
)
|
||||
return _tf_load(saved_model_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -28,56 +28,23 @@ here for ergonomic access.
|
|||
from __future__ import annotations
|
||||
|
||||
from neuropose.analyzer.dtw import (
|
||||
AlignMode,
|
||||
DTWResult,
|
||||
NanPolicy,
|
||||
Representation,
|
||||
dtw_all,
|
||||
dtw_per_joint,
|
||||
dtw_relation,
|
||||
)
|
||||
from neuropose.analyzer.features import (
|
||||
AlignmentDiagnostics,
|
||||
FeatureStatistics,
|
||||
ProcrustesMode,
|
||||
extract_feature_statistics,
|
||||
extract_joint_angles,
|
||||
find_peaks,
|
||||
normalize_pose_sequence,
|
||||
pad_sequences,
|
||||
predictions_to_numpy,
|
||||
procrustes_align,
|
||||
)
|
||||
from neuropose.analyzer.pipeline import (
|
||||
AnalysisConfig,
|
||||
AnalysisReport,
|
||||
AnalysisResults,
|
||||
AnalysisStage,
|
||||
DtwAnalysis,
|
||||
DtwResults,
|
||||
ExtractorSegmentation,
|
||||
FeatureSummary,
|
||||
GaitCyclesBilateralSegmentation,
|
||||
GaitCyclesSegmentation,
|
||||
InputsConfig,
|
||||
InputSummary,
|
||||
NoAnalysis,
|
||||
NoResults,
|
||||
OutputConfig,
|
||||
PreprocessingConfig,
|
||||
SegmentationStage,
|
||||
StatsAnalysis,
|
||||
StatsResults,
|
||||
analysis_config_to_dict,
|
||||
load_config,
|
||||
load_report,
|
||||
run_analysis,
|
||||
save_report,
|
||||
)
|
||||
from neuropose.analyzer.segment import (
|
||||
JOINT_INDEX,
|
||||
JOINT_NAMES,
|
||||
AxisLetter,
|
||||
extract_signal,
|
||||
joint_angle,
|
||||
joint_axis,
|
||||
|
|
@ -85,8 +52,6 @@ from neuropose.analyzer.segment import (
|
|||
joint_pair_distance,
|
||||
joint_speed,
|
||||
segment_by_peaks,
|
||||
segment_gait_cycles,
|
||||
segment_gait_cycles_bilateral,
|
||||
segment_predictions,
|
||||
slice_predictions,
|
||||
)
|
||||
|
|
@ -94,34 +59,8 @@ from neuropose.analyzer.segment import (
|
|||
__all__ = [
|
||||
"JOINT_INDEX",
|
||||
"JOINT_NAMES",
|
||||
"AlignMode",
|
||||
"AlignmentDiagnostics",
|
||||
"AnalysisConfig",
|
||||
"AnalysisReport",
|
||||
"AnalysisResults",
|
||||
"AnalysisStage",
|
||||
"AxisLetter",
|
||||
"DTWResult",
|
||||
"DtwAnalysis",
|
||||
"DtwResults",
|
||||
"ExtractorSegmentation",
|
||||
"FeatureStatistics",
|
||||
"FeatureSummary",
|
||||
"GaitCyclesBilateralSegmentation",
|
||||
"GaitCyclesSegmentation",
|
||||
"InputSummary",
|
||||
"InputsConfig",
|
||||
"NanPolicy",
|
||||
"NoAnalysis",
|
||||
"NoResults",
|
||||
"OutputConfig",
|
||||
"PreprocessingConfig",
|
||||
"ProcrustesMode",
|
||||
"Representation",
|
||||
"SegmentationStage",
|
||||
"StatsAnalysis",
|
||||
"StatsResults",
|
||||
"analysis_config_to_dict",
|
||||
"dtw_all",
|
||||
"dtw_per_joint",
|
||||
"dtw_relation",
|
||||
|
|
@ -134,17 +73,10 @@ __all__ = [
|
|||
"joint_index",
|
||||
"joint_pair_distance",
|
||||
"joint_speed",
|
||||
"load_config",
|
||||
"load_report",
|
||||
"normalize_pose_sequence",
|
||||
"pad_sequences",
|
||||
"predictions_to_numpy",
|
||||
"procrustes_align",
|
||||
"run_analysis",
|
||||
"save_report",
|
||||
"segment_by_peaks",
|
||||
"segment_gait_cycles",
|
||||
"segment_gait_cycles_bilateral",
|
||||
"segment_predictions",
|
||||
"slice_predictions",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,12 +2,10 @@
|
|||
|
||||
Three entry points, ordered by increasing precision (and increasing cost):
|
||||
|
||||
- :func:`dtw_all` — DTW on the flattened per-frame feature vector. Fast
|
||||
but coarse; collapses every joint axis (or every angle triplet) into
|
||||
a single per-frame vector.
|
||||
- :func:`dtw_per_joint` — DTW on each joint (or angle triplet)
|
||||
independently. Preserves per-unit temporal alignment at the cost of
|
||||
one DTW call per unit.
|
||||
- :func:`dtw_all` — DTW on the flattened per-frame joint vector. Fast but
|
||||
coarse; collapses every joint axis into a single per-frame vector.
|
||||
- :func:`dtw_per_joint` — DTW on each joint independently. Preserves
|
||||
per-joint temporal alignment at the cost of one DTW call per joint.
|
||||
- :func:`dtw_relation` — DTW on the displacement vector between two
|
||||
specific joints. This is the right tool when the research question is
|
||||
about the *relative* motion of a specific pair of joints (e.g. the
|
||||
|
|
@ -18,24 +16,6 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)``
|
|||
numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
|
||||
produces.
|
||||
|
||||
Three orthogonal preprocessing knobs are available on the entry points:
|
||||
|
||||
- **``align``** routes the inputs through
|
||||
:func:`~neuropose.analyzer.features.procrustes_align` before DTW runs,
|
||||
yielding translation- and rotation-invariant distances.
|
||||
``align="none"`` (the default) preserves the raw-coordinate behaviour
|
||||
shipped in 0.1.
|
||||
- **``representation``** (on :func:`dtw_all` and :func:`dtw_per_joint`)
|
||||
selects what each frame is reduced to before DTW. ``"coords"`` uses
|
||||
the raw joint coordinates; ``"angles"`` replaces them with joint
|
||||
angles computed at caller-supplied triplets via
|
||||
:func:`~neuropose.analyzer.features.extract_joint_angles`, giving
|
||||
DTW distances that are directly interpretable as clinical joint-range
|
||||
comparisons.
|
||||
- **``nan_policy``** decides how the DTW path handles non-finite values
|
||||
in its input — typically a concern only for the angle representation,
|
||||
where degenerate (zero-length) vectors produce NaN. See :data:`NanPolicy`.
|
||||
|
||||
Dependency note
|
||||
---------------
|
||||
This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of
|
||||
|
|
@ -48,58 +28,11 @@ called.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from neuropose.analyzer.features import extract_joint_angles, procrustes_align
|
||||
|
||||
AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
|
||||
"""Alignment selector for DTW entry points.
|
||||
|
||||
- ``"none"`` — feed raw coordinates directly to DTW.
|
||||
- ``"procrustes_per_frame"`` — per-frame Kabsch alignment before DTW.
|
||||
- ``"procrustes_per_sequence"`` — single sequence-wide Kabsch
|
||||
alignment before DTW.
|
||||
"""
|
||||
|
||||
Representation = Literal["coords", "angles"]
|
||||
"""Per-frame feature representation for :func:`dtw_all` and :func:`dtw_per_joint`.
|
||||
|
||||
- ``"coords"`` — use the raw joint coordinates (the input's last two
|
||||
axes). Preserves the 0.1 behaviour.
|
||||
- ``"angles"`` — replace joints with joint angles at caller-supplied
|
||||
triplets. Translation- and rotation-invariant by construction,
|
||||
scale-invariant modulo the upstream normalization, and directly
|
||||
interpretable in clinical terms ("knee flexion during swing phase").
|
||||
The ``angle_triplets`` keyword becomes mandatory in this mode.
|
||||
"""
|
||||
|
||||
NanPolicy = Literal["propagate", "interpolate", "drop"]
|
||||
"""Per-feature NaN handling for the DTW input.
|
||||
|
||||
NaN typically appears when ``representation="angles"`` encounters a
|
||||
degenerate (zero-length) vector — the angle is undefined and
|
||||
:func:`extract_joint_angles` propagates NaN rather than quietly returning
|
||||
a stand-in value.
|
||||
|
||||
- ``"propagate"`` (default) — pass NaN straight through to the DTW
|
||||
engine. fastdtw validates its input via
|
||||
:func:`numpy.asarray_chkfinite` and raises :class:`ValueError`
|
||||
the moment a NaN appears, which is the safest default because it
|
||||
makes the problem visible instead of quietly corrupting a
|
||||
distance.
|
||||
- ``"interpolate"`` — linearly interpolate NaN frames along each
|
||||
feature column using neighbouring finite values. Reasonable when a
|
||||
small number of frames are corrupted and the surrounding motion is
|
||||
smooth; inappropriate when long stretches are missing.
|
||||
- ``"drop"`` — remove any frame where *any* feature is NaN before DTW
|
||||
runs. Simple, but compresses the time axis, so warping-path indices
|
||||
refer to the *compacted* sequence rather than the original.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DTWResult:
|
||||
|
|
@ -144,44 +77,20 @@ def _require_fastdtw() -> tuple[Callable, Callable]:
|
|||
return fastdtw, euclidean
|
||||
|
||||
|
||||
def dtw_all(
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
*,
|
||||
align: AlignMode = "none",
|
||||
representation: Representation = "coords",
|
||||
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
|
||||
nan_policy: NanPolicy = "propagate",
|
||||
) -> DTWResult:
|
||||
"""DTW on the flattened per-frame feature vector.
|
||||
def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult:
|
||||
"""DTW on the flattened per-frame joint vector.
|
||||
|
||||
Under the default ``representation="coords"`` each frame's joints
|
||||
are collapsed into a single vector before DTW is applied — fast
|
||||
(one DTW call regardless of joint count) but coarse, since a small
|
||||
Each frame's joints are collapsed into a single vector before DTW
|
||||
is applied. This is fast — one DTW call regardless of the joint
|
||||
count — but loses per-joint temporal structure, so a small
|
||||
timing mismatch on one joint can dominate the distance metric.
|
||||
Switching to ``representation="angles"`` computes joint angles at
|
||||
the supplied triplets first and flattens those instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a, b
|
||||
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
|
||||
sequences do not need to have the same number of frames, but
|
||||
they must have the same number of joints. When ``align`` is not
|
||||
``"none"``, the two sequences must additionally share a frame
|
||||
count (Procrustes requires a 1:1 correspondence).
|
||||
align
|
||||
Procrustes alignment mode applied before DTW. See
|
||||
:data:`AlignMode`.
|
||||
representation
|
||||
Per-frame feature representation. See :data:`Representation`.
|
||||
angle_triplets
|
||||
Required when ``representation="angles"``. Sequence of
|
||||
``(a, b, c)`` joint-index triplets passed through to
|
||||
:func:`~neuropose.analyzer.features.extract_joint_angles`.
|
||||
Ignored otherwise.
|
||||
nan_policy
|
||||
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
|
||||
they must have the same number of joints.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -192,102 +101,48 @@ def dtw_all(
|
|||
Raises
|
||||
------
|
||||
ValueError
|
||||
If ``a`` and ``b`` do not have the same joint count, if
|
||||
``align`` requires a matching frame count that is not present,
|
||||
if ``representation="angles"`` is requested without
|
||||
``angle_triplets``, or if ``nan_policy="interpolate"``
|
||||
encounters an all-NaN column.
|
||||
If ``a`` and ``b`` do not have the same joint count.
|
||||
"""
|
||||
_validate_same_joint_count(a, b)
|
||||
a, b = _maybe_align(a, b, align=align)
|
||||
feat_a = _apply_representation(a, representation, angle_triplets=angle_triplets)
|
||||
feat_b = _apply_representation(b, representation, angle_triplets=angle_triplets)
|
||||
feat_a = _apply_nan_policy(feat_a, nan_policy)
|
||||
feat_b = _apply_nan_policy(feat_b, nan_policy)
|
||||
fastdtw, euclidean = _require_fastdtw()
|
||||
distance, path = fastdtw(feat_a, feat_b, dist=euclidean)
|
||||
a_flat = a.reshape(a.shape[0], -1)
|
||||
b_flat = b.reshape(b.shape[0], -1)
|
||||
distance, path = fastdtw(a_flat, b_flat, dist=euclidean)
|
||||
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
|
||||
|
||||
|
||||
def dtw_per_joint(
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
*,
|
||||
align: AlignMode = "none",
|
||||
representation: Representation = "coords",
|
||||
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
|
||||
nan_policy: NanPolicy = "propagate",
|
||||
) -> list[DTWResult]:
|
||||
"""DTW on each joint (or angle triplet) independently.
|
||||
def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]:
|
||||
"""DTW on each joint independently.
|
||||
|
||||
Performs one DTW computation per unit, yielding a list of
|
||||
:class:`DTWResult` objects in input order. More precise than
|
||||
:func:`dtw_all` because each unit's temporal alignment is optimised
|
||||
separately, at the cost of J times more DTW calls for J units.
|
||||
|
||||
Under the default ``representation="coords"`` a "unit" is one of
|
||||
the input's joints (xyz treated jointly). Under
|
||||
``representation="angles"`` a "unit" is one scalar angle column
|
||||
computed from one ``angle_triplets`` entry.
|
||||
Performs one DTW computation per joint, yielding a list of
|
||||
:class:`DTWResult` objects in joint-index order. More precise than
|
||||
:func:`dtw_all` because each joint's temporal alignment is optimised
|
||||
separately, at the cost of J times more DTW calls for J joints.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a, b
|
||||
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
|
||||
sequences do not need to have the same number of frames but
|
||||
must have the same number of joints. When ``align`` is not
|
||||
``"none"``, they must additionally share a frame count.
|
||||
align
|
||||
Procrustes alignment mode applied before DTW. See
|
||||
:data:`AlignMode`.
|
||||
representation
|
||||
Per-frame feature representation. See :data:`Representation`.
|
||||
angle_triplets
|
||||
Required when ``representation="angles"``; see
|
||||
:func:`dtw_all` for details.
|
||||
nan_policy
|
||||
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
|
||||
must have the same number of joints.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[DTWResult]
|
||||
One DTW result per joint or per angle triplet, in input order.
|
||||
One DTW result per joint, in index order.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Same conditions as :func:`dtw_all`.
|
||||
If ``a`` and ``b`` do not have the same joint count.
|
||||
"""
|
||||
_validate_same_joint_count(a, b)
|
||||
a, b = _maybe_align(a, b, align=align)
|
||||
|
||||
if representation == "coords":
|
||||
feat_a = a
|
||||
feat_b = b
|
||||
# (frames, joints, 3) — one DTW per joint over its (frames, 3) slice.
|
||||
num_units = feat_a.shape[1]
|
||||
slicers: list[Callable[[np.ndarray], np.ndarray]] = [
|
||||
(lambda arr, idx=i: arr[:, idx, :]) for i in range(num_units)
|
||||
]
|
||||
else: # "angles"
|
||||
if angle_triplets is None:
|
||||
raise ValueError("representation='angles' requires angle_triplets")
|
||||
feat_a = extract_joint_angles(a, angle_triplets) # (frames, num_triplets)
|
||||
feat_b = extract_joint_angles(b, angle_triplets)
|
||||
num_units = feat_a.shape[1]
|
||||
slicers = [
|
||||
# Scalar columns become 2D for DTW (fastdtw expects a
|
||||
# sequence of vectors, not a sequence of scalars).
|
||||
(lambda arr, idx=i: arr[:, idx : idx + 1])
|
||||
for i in range(num_units)
|
||||
]
|
||||
|
||||
fastdtw, euclidean = _require_fastdtw()
|
||||
results: list[DTWResult] = []
|
||||
for slicer in slicers:
|
||||
unit_a = _apply_nan_policy(slicer(feat_a), nan_policy)
|
||||
unit_b = _apply_nan_policy(slicer(feat_b), nan_policy)
|
||||
distance, path = fastdtw(unit_a, unit_b, dist=euclidean)
|
||||
for joint_idx in range(a.shape[1]):
|
||||
a_joint = a[:, joint_idx, :]
|
||||
b_joint = b[:, joint_idx, :]
|
||||
distance, path = fastdtw(a_joint, b_joint, dist=euclidean)
|
||||
results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path]))
|
||||
return results
|
||||
|
||||
|
|
@ -297,9 +152,6 @@ def dtw_relation(
|
|||
b: np.ndarray,
|
||||
joint_i: int,
|
||||
joint_j: int,
|
||||
*,
|
||||
align: AlignMode = "none",
|
||||
nan_policy: NanPolicy = "propagate",
|
||||
) -> DTWResult:
|
||||
"""DTW on the displacement vector between two specific joints.
|
||||
|
||||
|
|
@ -318,14 +170,6 @@ def dtw_relation(
|
|||
Indices of the two joints whose relative position should be
|
||||
compared. Must be valid indices into ``a`` and ``b``'s joint
|
||||
axis.
|
||||
align
|
||||
Procrustes alignment mode applied to the full sequences
|
||||
before the displacement vectors are extracted. See
|
||||
:data:`AlignMode`. Note that displacement vectors are already
|
||||
translation-invariant; alignment is still useful for cancelling
|
||||
camera rotation between trials.
|
||||
nan_policy
|
||||
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -335,9 +179,8 @@ def dtw_relation(
|
|||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the sequences have different joint counts, either joint
|
||||
index is out of range, or ``align`` requires a matching frame
|
||||
count that is not present.
|
||||
If the sequences have different joint counts or if either joint
|
||||
index is out of range.
|
||||
"""
|
||||
_validate_same_joint_count(a, b)
|
||||
num_joints = a.shape[1]
|
||||
|
|
@ -345,10 +188,9 @@ def dtw_relation(
|
|||
raise ValueError(
|
||||
f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}"
|
||||
)
|
||||
a, b = _maybe_align(a, b, align=align)
|
||||
disp_a = _apply_nan_policy(a[:, joint_j, :] - a[:, joint_i, :], nan_policy)
|
||||
disp_b = _apply_nan_policy(b[:, joint_j, :] - b[:, joint_i, :], nan_policy)
|
||||
fastdtw, euclidean = _require_fastdtw()
|
||||
disp_a = a[:, joint_j, :] - a[:, joint_i, :]
|
||||
disp_b = b[:, joint_j, :] - b[:, joint_i, :]
|
||||
distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
|
||||
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
|
||||
|
||||
|
|
@ -364,95 +206,3 @@ def _validate_same_joint_count(a: np.ndarray, b: np.ndarray) -> None:
|
|||
f"input arrays disagree on joint count: "
|
||||
f"a has {a.shape[1]} joints, b has {b.shape[1]} joints"
|
||||
)
|
||||
|
||||
|
||||
def _maybe_align(
|
||||
a: np.ndarray,
|
||||
b: np.ndarray,
|
||||
*,
|
||||
align: AlignMode,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Apply Procrustes alignment if ``align`` requests it.
|
||||
|
||||
Procrustes requires a frame-by-frame correspondence, so this
|
||||
helper rejects calls where the two sequences disagree on frame
|
||||
count and ``align`` is not ``"none"``. Pad upstream with
|
||||
:func:`~neuropose.analyzer.features.pad_sequences` if the lengths
|
||||
differ.
|
||||
"""
|
||||
if align == "none":
|
||||
return a, b
|
||||
if a.shape[0] != b.shape[0]:
|
||||
raise ValueError(
|
||||
f"align={align!r} requires matching frame counts; "
|
||||
f"got a with {a.shape[0]} frames and b with {b.shape[0]} frames"
|
||||
)
|
||||
mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence"
|
||||
aligned_a, _target, _diag = procrustes_align(a, b, mode=mode)
|
||||
return aligned_a, b
|
||||
|
||||
|
||||
def _apply_representation(
|
||||
sequence: np.ndarray,
|
||||
representation: Representation,
|
||||
*,
|
||||
angle_triplets: Sequence[tuple[int, int, int]] | None,
|
||||
) -> np.ndarray:
|
||||
"""Reduce a ``(frames, joints, 3)`` sequence to DTW-ready 2D features.
|
||||
|
||||
``"coords"`` reshapes to ``(frames, joints * 3)``; ``"angles"``
|
||||
runs :func:`extract_joint_angles` to produce
|
||||
``(frames, len(angle_triplets))``.
|
||||
"""
|
||||
if representation == "coords":
|
||||
return sequence.reshape(sequence.shape[0], -1)
|
||||
if representation == "angles":
|
||||
if angle_triplets is None:
|
||||
raise ValueError("representation='angles' requires angle_triplets")
|
||||
return extract_joint_angles(sequence, angle_triplets)
|
||||
raise ValueError(f"unknown representation {representation!r}")
|
||||
|
||||
|
||||
def _apply_nan_policy(features: np.ndarray, policy: NanPolicy) -> np.ndarray:
|
||||
"""Handle NaN values in a ``(frames, features)`` array per ``policy``.
|
||||
|
||||
``"propagate"`` is a no-op. ``"interpolate"`` runs 1D linear
|
||||
interpolation along the frame axis within each feature column,
|
||||
leaving finite data untouched. ``"drop"`` removes any frame where
|
||||
*any* feature is NaN.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If ``"interpolate"`` encounters a column that is entirely NaN
|
||||
(no finite anchors to interpolate between), or if ``"drop"``
|
||||
leaves an empty sequence.
|
||||
"""
|
||||
if policy == "propagate":
|
||||
return features
|
||||
if features.ndim == 1:
|
||||
features = features.reshape(-1, 1)
|
||||
if policy == "drop":
|
||||
keep = np.isfinite(features).all(axis=1)
|
||||
dropped = features[keep]
|
||||
if dropped.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"nan_policy='drop' removed every frame; DTW needs a non-empty sequence"
|
||||
)
|
||||
return dropped
|
||||
if policy == "interpolate":
|
||||
out = features.astype(float, copy=True)
|
||||
num_frames = out.shape[0]
|
||||
indices = np.arange(num_frames, dtype=float)
|
||||
for col in range(out.shape[1]):
|
||||
column = out[:, col]
|
||||
finite = np.isfinite(column)
|
||||
if finite.all():
|
||||
continue
|
||||
if not finite.any():
|
||||
raise ValueError(
|
||||
f"nan_policy='interpolate' cannot fill column {col}: all values are NaN"
|
||||
)
|
||||
out[:, col] = np.interp(indices, indices[finite], column[finite])
|
||||
return out
|
||||
raise ValueError(f"unknown nan_policy {policy!r}")
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ The following helpers are provided:
|
|||
fit in the unit cube (either per-axis or uniform).
|
||||
- :func:`pad_sequences` — edge-pad a batch of sequences to a common
|
||||
length, suitable for downstream tensor-based analysis.
|
||||
- :func:`procrustes_align` — rigid-align one pose sequence to another
|
||||
via the Kabsch algorithm, with optional uniform scaling.
|
||||
- :func:`extract_joint_angles` — compute joint angles at specified
|
||||
triplet positions across a pose sequence.
|
||||
- :func:`extract_feature_statistics` — summary statistics
|
||||
|
|
@ -28,19 +26,12 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from neuropose.io import VideoPredictions
|
||||
|
||||
ProcrustesMode = Literal["per_frame", "per_sequence"]
|
||||
"""Mode selector for :func:`procrustes_align`.
|
||||
|
||||
``per_sequence`` computes a single rigid transform over the whole
|
||||
sequence; ``per_frame`` aligns every frame independently.
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VideoPredictions → numpy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -220,247 +211,6 @@ def pad_sequences(
|
|||
return padded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Procrustes alignment (Kabsch)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AlignmentDiagnostics:
|
||||
"""Summary of the rigid transform fitted by :func:`procrustes_align`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
mode
|
||||
Which alignment mode produced this result; mirrors the ``mode``
|
||||
argument passed to :func:`procrustes_align`.
|
||||
rotation_deg
|
||||
Magnitude of the fitted rotation, in degrees, computed as
|
||||
``arccos((trace(R) - 1) / 2)``. For ``per_frame`` mode this is
|
||||
the mean magnitude across frames.
|
||||
rotation_deg_max
|
||||
Worst-case (maximum) rotation magnitude across frames.
|
||||
Equal to :attr:`rotation_deg` in ``per_sequence`` mode.
|
||||
translation
|
||||
Magnitude of the fitted translation vector, in the same units
|
||||
as the input (millimetres for MeTRAbs output). For ``per_frame``
|
||||
mode this is the mean magnitude across frames.
|
||||
translation_max
|
||||
Worst-case (maximum) translation magnitude across frames.
|
||||
Equal to :attr:`translation` in ``per_sequence`` mode.
|
||||
scale
|
||||
Applied uniform scale factor. Always ``1.0`` when
|
||||
``procrustes_align`` was called with ``scale=False``. In
|
||||
``per_frame`` mode this is the mean scale across frames.
|
||||
"""
|
||||
|
||||
mode: ProcrustesMode
|
||||
rotation_deg: float
|
||||
rotation_deg_max: float
|
||||
translation: float
|
||||
translation_max: float
|
||||
scale: float
|
||||
|
||||
|
||||
def _kabsch_single(
|
||||
source: np.ndarray,
|
||||
target: np.ndarray,
|
||||
*,
|
||||
scale: bool,
|
||||
) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]:
|
||||
"""Fit the optimal rigid (+ optional uniform scale) transform.
|
||||
|
||||
Aligns ``source`` to ``target`` via the closed-form Kabsch
|
||||
algorithm and returns ``(aligned_source, R, s, t)`` where
|
||||
``aligned_source = s * (source - centroid_source) @ R.T + centroid_target + t_fine``
|
||||
(with ``t_fine`` absorbed for convenience — aligned points match
|
||||
the target's centroid to within floating-point error).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source
|
||||
``(N, 3)`` point set to align.
|
||||
target
|
||||
``(N, 3)`` reference point set. Must have the same shape as
|
||||
``source``.
|
||||
scale
|
||||
If ``True``, fit a uniform scale factor; otherwise lock to
|
||||
``1.0``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
aligned_source
|
||||
``(N, 3)`` aligned copy of ``source``.
|
||||
R
|
||||
``(3, 3)`` rotation matrix.
|
||||
s
|
||||
Scalar scale factor (``1.0`` when ``scale=False``).
|
||||
t
|
||||
``(3,)`` translation vector in world coordinates such that
|
||||
``aligned_source[i] = s * R @ source[i] + t``.
|
||||
"""
|
||||
centroid_source = source.mean(axis=0)
|
||||
centroid_target = target.mean(axis=0)
|
||||
source_centered = source - centroid_source
|
||||
target_centered = target - centroid_target
|
||||
|
||||
covariance = source_centered.T @ target_centered
|
||||
u_mat, sigma, vt_mat = np.linalg.svd(covariance)
|
||||
reflection_sign = float(np.sign(np.linalg.det(vt_mat.T @ u_mat.T)))
|
||||
# Guard against the degenerate det == 0 case (coplanar points).
|
||||
if reflection_sign == 0.0:
|
||||
reflection_sign = 1.0
|
||||
diag = np.diag([1.0, 1.0, reflection_sign])
|
||||
rotation = vt_mat.T @ diag @ u_mat.T
|
||||
|
||||
if scale:
|
||||
source_var = float((source_centered**2).sum())
|
||||
if source_var <= 0.0:
|
||||
scale_factor = 1.0
|
||||
else:
|
||||
scale_factor = float((sigma * np.array([1.0, 1.0, reflection_sign])).sum() / source_var)
|
||||
else:
|
||||
scale_factor = 1.0
|
||||
|
||||
translation = centroid_target - scale_factor * rotation @ centroid_source
|
||||
aligned = scale_factor * source @ rotation.T + translation
|
||||
return aligned, rotation, scale_factor, translation
|
||||
|
||||
|
||||
def _rotation_magnitude_deg(rotation: np.ndarray) -> float:
|
||||
"""Return the rotation angle (degrees) represented by ``rotation``.
|
||||
|
||||
Uses the axis-angle relation ``cos(theta) = (trace(R) - 1) / 2``.
|
||||
"""
|
||||
cos_theta = (float(np.trace(rotation)) - 1.0) / 2.0
|
||||
cos_theta = max(-1.0, min(1.0, cos_theta))
|
||||
return float(np.degrees(np.arccos(cos_theta)))
|
||||
|
||||
|
||||
def procrustes_align(
|
||||
source: np.ndarray,
|
||||
target: np.ndarray,
|
||||
*,
|
||||
mode: ProcrustesMode = "per_sequence",
|
||||
scale: bool = False,
|
||||
) -> tuple[np.ndarray, np.ndarray, AlignmentDiagnostics]:
|
||||
"""Rigid-align ``source`` to ``target`` via the Kabsch algorithm.
|
||||
|
||||
Fits the optimal rigid transform (optionally including uniform
|
||||
scaling) that minimizes the sum of squared distances between
|
||||
corresponding joints. The transform is always applied to
|
||||
``source``; ``target`` is returned unchanged alongside it for
|
||||
symmetry with downstream DTW callers, which typically consume both
|
||||
aligned arrays as a pair.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source
|
||||
Pose sequence to align, shape ``(frames, joints, 3)``.
|
||||
target
|
||||
Reference pose sequence, shape ``(frames, joints, 3)``. For
|
||||
``per_frame`` mode the frame counts must match; for
|
||||
``per_sequence`` mode they must also match (the correspondence
|
||||
runs frame-by-frame and joint-by-joint). Use
|
||||
:func:`pad_sequences` first if your sequences have different
|
||||
lengths.
|
||||
mode
|
||||
``"per_sequence"`` (default) fits a single rigid transform over
|
||||
the whole sequence — good when the recording geometry is
|
||||
stable across frames. ``"per_frame"`` fits an independent
|
||||
transform per frame — good for matching pose shape while
|
||||
discarding global trajectory.
|
||||
scale
|
||||
If ``True``, also fit a uniform scale factor. Useful for
|
||||
cross-subject comparisons where the reference skeleton has a
|
||||
different overall size.
|
||||
|
||||
Returns
|
||||
-------
|
||||
aligned_source
|
||||
``source`` transformed to align with ``target``, same shape as
|
||||
the input.
|
||||
target
|
||||
The ``target`` array, unchanged.
|
||||
diagnostics
|
||||
:class:`AlignmentDiagnostics` summarising the fitted transform.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If ``source`` and ``target`` have different shapes or the
|
||||
trailing axis is not of size 3.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The Kabsch algorithm (Kabsch 1976, "A solution for the best
|
||||
rotation to relate two sets of vectors") is a closed-form SVD
|
||||
solution and does not iterate. Reflection is explicitly prevented
|
||||
via a sign correction on the smallest singular value; the fitted
|
||||
matrix is always a proper rotation (det = +1).
|
||||
|
||||
In ``per_frame`` mode, rotation, translation, and scale
|
||||
diagnostics are reported as means across frames, with
|
||||
:attr:`AlignmentDiagnostics.rotation_deg_max` and
|
||||
:attr:`AlignmentDiagnostics.translation_max` exposing the worst
|
||||
frame for anomaly detection.
|
||||
"""
|
||||
if source.ndim != 3 or source.shape[-1] != 3:
|
||||
raise ValueError(f"expected (frames, joints, 3); got source shape {source.shape}")
|
||||
if source.shape != target.shape:
|
||||
raise ValueError(
|
||||
f"source and target must have the same shape; got {source.shape} and {target.shape}"
|
||||
)
|
||||
|
||||
source = source.astype(float, copy=False)
|
||||
target = target.astype(float, copy=False)
|
||||
num_frames = source.shape[0]
|
||||
|
||||
if mode == "per_sequence":
|
||||
flat_source = source.reshape(-1, 3)
|
||||
flat_target = target.reshape(-1, 3)
|
||||
aligned_flat, rotation, scale_factor, translation = _kabsch_single(
|
||||
flat_source, flat_target, scale=scale
|
||||
)
|
||||
aligned = aligned_flat.reshape(source.shape)
|
||||
rotation_deg = _rotation_magnitude_deg(rotation)
|
||||
translation_mag = float(np.linalg.norm(translation))
|
||||
diagnostics = AlignmentDiagnostics(
|
||||
mode="per_sequence",
|
||||
rotation_deg=rotation_deg,
|
||||
rotation_deg_max=rotation_deg,
|
||||
translation=translation_mag,
|
||||
translation_max=translation_mag,
|
||||
scale=scale_factor,
|
||||
)
|
||||
return aligned, target, diagnostics
|
||||
|
||||
if mode == "per_frame":
|
||||
aligned = np.empty_like(source)
|
||||
rotation_degs = np.empty(num_frames, dtype=float)
|
||||
translations = np.empty(num_frames, dtype=float)
|
||||
scales = np.empty(num_frames, dtype=float)
|
||||
for frame_idx in range(num_frames):
|
||||
aligned_frame, rotation, scale_factor, translation = _kabsch_single(
|
||||
source[frame_idx], target[frame_idx], scale=scale
|
||||
)
|
||||
aligned[frame_idx] = aligned_frame
|
||||
rotation_degs[frame_idx] = _rotation_magnitude_deg(rotation)
|
||||
translations[frame_idx] = float(np.linalg.norm(translation))
|
||||
scales[frame_idx] = scale_factor
|
||||
diagnostics = AlignmentDiagnostics(
|
||||
mode="per_frame",
|
||||
rotation_deg=float(rotation_degs.mean()) if num_frames else 0.0,
|
||||
rotation_deg_max=float(rotation_degs.max()) if num_frames else 0.0,
|
||||
translation=float(translations.mean()) if num_frames else 0.0,
|
||||
translation_max=float(translations.max()) if num_frames else 0.0,
|
||||
scale=float(scales.mean()) if num_frames else 1.0,
|
||||
)
|
||||
return aligned, target, diagnostics
|
||||
|
||||
raise ValueError(f"unknown mode {mode!r}; expected 'per_frame' or 'per_sequence'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Joint angles
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,932 +0,0 @@
|
|||
"""YAML-configurable analysis pipeline.
|
||||
|
||||
This module unifies the analyzer's individual primitives (Procrustes
|
||||
alignment, gait-cycle segmentation, DTW on coords or angles, feature
|
||||
statistics) behind a single declarative configuration object so an
|
||||
experiment can be reproduced from one file that lives in a git
|
||||
repository, carries through to the :class:`~neuropose.io.Provenance`
|
||||
envelope on the output artifact, and can be cited unambiguously in
|
||||
accompanying papers.
|
||||
|
||||
Two top-level schemas live here:
|
||||
|
||||
- :class:`AnalysisConfig` — what the user writes in YAML. Describes
|
||||
the full pipeline: inputs, preprocessing, optional segmentation,
|
||||
required analysis stage, and output path.
|
||||
- :class:`AnalysisReport` — what :func:`run_analysis` emits. Carries
|
||||
the config, a :class:`~neuropose.io.Provenance` envelope (with the
|
||||
config serialised into :attr:`~neuropose.io.Provenance.analysis_config`),
|
||||
per-input summaries, segmentation results, and the analysis results
|
||||
themselves.
|
||||
|
||||
Both schemas parse from (and serialise to) JSON via pydantic; the
|
||||
config additionally parses from YAML via :func:`load_config`. Cross-field
|
||||
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
||||
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
||||
fast rather than after an expensive multi-minute load.
|
||||
|
||||
Execution
|
||||
---------
|
||||
:func:`run_analysis` is the top-level executor: it loads the
|
||||
predictions files named in the config, applies any configured
|
||||
segmentation stage, dispatches to the configured analysis stage, and
|
||||
returns a fully populated :class:`AnalysisReport`. The executor is
|
||||
intended to be called from the ``neuropose analyze`` CLI but is
|
||||
equally valid as a Python-level entry point for notebook-driven
|
||||
exploration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from neuropose.analyzer.dtw import (
|
||||
AlignMode,
|
||||
DTWResult,
|
||||
NanPolicy,
|
||||
Representation,
|
||||
dtw_all,
|
||||
dtw_per_joint,
|
||||
dtw_relation,
|
||||
)
|
||||
from neuropose.analyzer.features import (
|
||||
extract_feature_statistics,
|
||||
predictions_to_numpy,
|
||||
)
|
||||
from neuropose.analyzer.segment import (
|
||||
AxisLetter,
|
||||
extract_signal,
|
||||
segment_gait_cycles,
|
||||
segment_gait_cycles_bilateral,
|
||||
segment_predictions,
|
||||
)
|
||||
from neuropose.io import (
|
||||
ExtractorSpec,
|
||||
Provenance,
|
||||
Segmentation,
|
||||
VideoPredictions,
|
||||
load_video_predictions,
|
||||
)
|
||||
from neuropose.migrations import CURRENT_VERSION, migrate_analysis_report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InputsConfig(BaseModel):
|
||||
"""Predictions files consumed by the pipeline.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
primary
|
||||
Path to a :class:`~neuropose.io.VideoPredictions` JSON file.
|
||||
Always required.
|
||||
reference
|
||||
Optional second predictions file. When provided,
|
||||
:class:`DtwAnalysis` runs comparative DTW between primary and
|
||||
reference; when absent, analysis stages that require a
|
||||
reference (i.e. DTW) raise a validation error at parse time.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
primary: Path
|
||||
reference: Path | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseModel):
|
||||
"""Per-input preprocessing applied before segmentation and analysis.
|
||||
|
||||
Minimal today — just picks which detected person to extract from
|
||||
each frame. Left as a named stage so future extensions (coordinate
|
||||
normalisation, smoothing) can land here without reshuffling the
|
||||
config shape.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
person_index
|
||||
Which detected person to extract per frame. Defaults to ``0``
|
||||
(the first detected person), matching the single-subject
|
||||
clinical case.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
person_index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Segmentation stage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GaitCyclesSegmentation(BaseModel):
|
||||
"""Single-heel gait-cycle segmentation via peak detection.
|
||||
|
||||
Produces one :class:`~neuropose.io.Segmentation` keyed under the
|
||||
joint name (e.g. ``"rhee_cycles"``). See
|
||||
:func:`~neuropose.analyzer.segment.segment_gait_cycles` for the
|
||||
underlying implementation.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["gait_cycles"]
|
||||
joint: str = "rhee"
|
||||
axis: AxisLetter = "y"
|
||||
invert: bool = False
|
||||
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
|
||||
min_prominence: float | None = None
|
||||
|
||||
|
||||
class GaitCyclesBilateralSegmentation(BaseModel):
|
||||
"""Bilateral (both heels) gait-cycle segmentation.
|
||||
|
||||
Produces two :class:`~neuropose.io.Segmentation` objects keyed as
|
||||
``"left_heel_strikes"`` and ``"right_heel_strikes"``.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["gait_cycles_bilateral"]
|
||||
axis: AxisLetter = "y"
|
||||
invert: bool = False
|
||||
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
|
||||
min_prominence: float | None = None
|
||||
|
||||
|
||||
class ExtractorSegmentation(BaseModel):
|
||||
"""Generic extractor-driven segmentation.
|
||||
|
||||
Wraps :func:`~neuropose.analyzer.segment.segment_predictions` with
|
||||
a caller-supplied :class:`~neuropose.io.ExtractorSpec`. Use this
|
||||
when the signal of interest is not the vertical heel trace — e.g.
|
||||
wrist-hip distance for a reach-and-grasp task, or elbow flexion
|
||||
angle for a range-of-motion trial.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["extractor"]
|
||||
extractor: ExtractorSpec
|
||||
label: str = Field(
|
||||
default="segmentation",
|
||||
description="Key under which the resulting Segmentation is stored.",
|
||||
)
|
||||
person_index: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Overrides preprocessing.person_index for this stage. "
|
||||
"None defers to the global preprocessing setting."
|
||||
),
|
||||
)
|
||||
min_distance_seconds: float | None = Field(default=None, ge=0.0)
|
||||
min_prominence: float | None = None
|
||||
min_height: float | None = None
|
||||
pad_seconds: float = Field(default=0.0, ge=0.0)
|
||||
|
||||
|
||||
SegmentationStage = Annotated[
|
||||
GaitCyclesSegmentation | GaitCyclesBilateralSegmentation | ExtractorSegmentation,
|
||||
Field(discriminator="kind"),
|
||||
]
|
||||
"""Discriminated-union alias for the three segmentation variants.
|
||||
|
||||
Pydantic dispatches on the ``kind`` field. A config without a
|
||||
``segmentation`` key at all skips this stage entirely
|
||||
(see :class:`AnalysisConfig.segmentation`).
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Analysis stage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DtwAnalysis(BaseModel):
|
||||
"""Dynamic Time Warping between the primary and reference inputs.
|
||||
|
||||
Dispatches to one of :func:`~neuropose.analyzer.dtw.dtw_all`,
|
||||
:func:`~neuropose.analyzer.dtw.dtw_per_joint`, or
|
||||
:func:`~neuropose.analyzer.dtw.dtw_relation` per the ``method``
|
||||
field. Cross-field invariants — ``method="dtw_relation"`` requires
|
||||
``joint_i`` and ``joint_j``, ``representation="angles"`` requires
|
||||
a non-empty ``angle_triplets`` — are enforced at parse time.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["dtw"]
|
||||
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] = "dtw_all"
|
||||
align: AlignMode = "none"
|
||||
representation: Representation = "coords"
|
||||
angle_triplets: list[tuple[int, int, int]] | None = None
|
||||
joint_i: int | None = None
|
||||
joint_j: int | None = None
|
||||
nan_policy: NanPolicy = "propagate"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_method_fields(self) -> DtwAnalysis:
|
||||
if self.method == "dtw_relation":
|
||||
if self.joint_i is None or self.joint_j is None:
|
||||
raise ValueError("method='dtw_relation' requires joint_i and joint_j")
|
||||
if self.representation != "coords":
|
||||
raise ValueError(
|
||||
"method='dtw_relation' only supports representation='coords' "
|
||||
"(a two-joint displacement is not a joint-angle signal)"
|
||||
)
|
||||
if self.representation == "angles":
|
||||
if not self.angle_triplets:
|
||||
raise ValueError("representation='angles' requires a non-empty angle_triplets list")
|
||||
if self.method == "dtw_relation":
|
||||
# Guarded by the earlier branch, but make the invariant explicit.
|
||||
raise ValueError(
|
||||
"representation='angles' is incompatible with method='dtw_relation'"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class StatsAnalysis(BaseModel):
|
||||
"""Summary statistics over a scalar signal extracted from the primary input.
|
||||
|
||||
Runs :func:`~neuropose.analyzer.segment.extract_signal` with the
|
||||
caller-supplied :class:`~neuropose.io.ExtractorSpec`, then
|
||||
computes :func:`~neuropose.analyzer.features.extract_feature_statistics`
|
||||
on each segment (or on the full trial if no segmentation stage
|
||||
runs).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["stats"]
|
||||
extractor: ExtractorSpec
|
||||
|
||||
|
||||
class NoAnalysis(BaseModel):
|
||||
"""Terminal stage placeholder; produces no per-segment results.
|
||||
|
||||
Useful when the pipeline's goal is just to segment the input and
|
||||
persist the :class:`~neuropose.io.Segmentation` plus an
|
||||
:class:`AnalysisReport` with provenance — the ``none`` analysis
|
||||
kind makes that explicit rather than requiring the absence of the
|
||||
stage.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["none"]
|
||||
|
||||
|
||||
AnalysisStage = Annotated[
|
||||
DtwAnalysis | StatsAnalysis | NoAnalysis,
|
||||
Field(discriminator="kind"),
|
||||
]
|
||||
"""Discriminated-union alias for the three analysis variants.
|
||||
|
||||
Pydantic dispatches on ``kind``. One of the three must always be
|
||||
present in a valid config.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OutputConfig(BaseModel):
|
||||
"""Where :func:`run_analysis` should write its :class:`AnalysisReport`.
|
||||
|
||||
Kept as a sub-object rather than a bare path so downstream
|
||||
extensions (figure paths, supplementary distance-matrix files)
|
||||
can land here without changing the config's top-level shape.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
report: Path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AnalysisConfig(BaseModel):
|
||||
"""Declarative description of a full analysis run.
|
||||
|
||||
Parsed from YAML (via :func:`load_config`) or JSON (via
|
||||
:meth:`pydantic.BaseModel.model_validate`). Every field is
|
||||
cross-validated at parse time so a typo in a nested sub-field
|
||||
fails in milliseconds rather than after a multi-minute
|
||||
predictions load.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
config_version
|
||||
Schema version for the config itself. Only ``1`` is valid at
|
||||
this release. Future config-format breaks bump this and a
|
||||
sibling migration registry handles legacy YAML in place.
|
||||
inputs
|
||||
Predictions-file paths.
|
||||
preprocessing
|
||||
Per-input preprocessing (person-index selection today).
|
||||
segmentation
|
||||
Optional segmentation stage. ``None`` skips segmentation
|
||||
entirely and analysis runs over each full trial as a single
|
||||
"segment".
|
||||
analysis
|
||||
Required analysis stage. Exactly one of
|
||||
:class:`DtwAnalysis` / :class:`StatsAnalysis` / :class:`NoAnalysis`.
|
||||
output
|
||||
Output paths.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
config_version: Literal[1] = 1
|
||||
inputs: InputsConfig
|
||||
preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
|
||||
segmentation: SegmentationStage | None = None
|
||||
analysis: AnalysisStage
|
||||
output: OutputConfig
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_cross_stage_invariants(self) -> AnalysisConfig:
|
||||
# DTW is comparative — it needs a reference input.
|
||||
if isinstance(self.analysis, DtwAnalysis) and self.inputs.reference is None:
|
||||
raise ValueError("analysis.kind='dtw' requires inputs.reference to be set")
|
||||
# Stats is non-comparative — a reference without a use is
|
||||
# almost certainly an operator error.
|
||||
if isinstance(self.analysis, StatsAnalysis) and self.inputs.reference is not None:
|
||||
raise ValueError(
|
||||
"analysis.kind='stats' operates on inputs.primary only; "
|
||||
"remove inputs.reference or switch analysis.kind to 'dtw'"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report pieces
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InputSummary(BaseModel):
|
||||
"""Capsule of an input predictions file's headline metadata.
|
||||
|
||||
Stored in the :class:`AnalysisReport` so a reader of the report
|
||||
can tell at a glance what was analysed without having to load the
|
||||
underlying predictions JSONs.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
path: Path
|
||||
frame_count: int = Field(ge=0)
|
||||
fps: float = Field(ge=0.0)
|
||||
provenance: Provenance | None = None
|
||||
|
||||
|
||||
class FeatureSummary(BaseModel):
|
||||
"""Pydantic twin of :class:`~neuropose.analyzer.features.FeatureStatistics`.
|
||||
|
||||
The dataclass is used throughout the analyzer for ad-hoc Python
|
||||
consumption; the report path needs a pydantic model for
|
||||
round-tripping through JSON.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
mean: float
|
||||
std: float
|
||||
min: float
|
||||
max: float
|
||||
range: float
|
||||
|
||||
|
||||
class DtwResults(BaseModel):
|
||||
"""DTW results attached to an :class:`AnalysisReport`.
|
||||
|
||||
``distances`` is parallel to ``segment_labels``. For an
|
||||
unsegmented run the lists have length 1 and the label is
|
||||
``"full_trial"``. For a segmented run each label takes the form
|
||||
``"<segmentation_key>[<index>]"`` (e.g. ``"left_heel_strikes[3]"``).
|
||||
``per_joint_distances`` carries a per-unit breakdown for
|
||||
``method="dtw_per_joint"`` only; its outer length matches
|
||||
``distances``, inner length matches either ``num_joints`` (coords)
|
||||
or ``len(angle_triplets)`` (angles).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["dtw"]
|
||||
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"]
|
||||
distances: list[float]
|
||||
paths: list[list[tuple[int, int]]]
|
||||
per_joint_distances: list[list[float]] | None = None
|
||||
segment_labels: list[str]
|
||||
summary: dict[str, float]
|
||||
|
||||
|
||||
class StatsResults(BaseModel):
|
||||
"""Feature-statistics results attached to an :class:`AnalysisReport`.
|
||||
|
||||
``statistics`` is parallel to ``segment_labels``; see
|
||||
:class:`DtwResults` for the labelling convention.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["stats"]
|
||||
statistics: list[FeatureSummary]
|
||||
segment_labels: list[str]
|
||||
|
||||
|
||||
class NoResults(BaseModel):
|
||||
"""Empty results payload for ``analysis.kind='none'`` runs."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
kind: Literal["none"]
|
||||
|
||||
|
||||
AnalysisResults = Annotated[
|
||||
DtwResults | StatsResults | NoResults,
|
||||
Field(discriminator="kind"),
|
||||
]
|
||||
"""Discriminated-union alias for the three analysis-result shapes.
|
||||
|
||||
Mirrors :data:`AnalysisStage` one-for-one: ``DtwAnalysis`` produces
|
||||
:class:`DtwResults`, etc.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AnalysisReport(BaseModel):
|
||||
"""Self-describing output artifact of :func:`run_analysis`.
|
||||
|
||||
Serialised to JSON on disk. Carries the originating config, the
|
||||
:class:`~neuropose.io.Provenance` envelope (with the config
|
||||
serialised into :attr:`~neuropose.io.Provenance.analysis_config`
|
||||
so the report is self-describing even if the YAML is lost), each
|
||||
input's headline metadata plus its own provenance if available,
|
||||
any segmentations produced, and the analysis results themselves.
|
||||
|
||||
Lives in the schema-migration registry under ``"AnalysisReport"``
|
||||
at ``CURRENT_VERSION``; see :mod:`neuropose.migrations`.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
schema_version: int = Field(default=CURRENT_VERSION, ge=1)
|
||||
config: AnalysisConfig
|
||||
provenance: Provenance | None = None
|
||||
primary: InputSummary
|
||||
reference: InputSummary | None = None
|
||||
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
|
||||
results: AnalysisResults
|
||||
|
||||
|
||||
def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
|
||||
"""Serialise an :class:`AnalysisConfig` to a JSON-safe dict.
|
||||
|
||||
Returned shape is identical to what pydantic would produce via
|
||||
:meth:`~pydantic.BaseModel.model_dump` in ``mode="json"`` — paths
|
||||
become strings, tuples become lists, enums become their values.
|
||||
Useful for stamping
|
||||
:attr:`~neuropose.io.Provenance.analysis_config` on the
|
||||
:class:`AnalysisReport`'s provenance envelope.
|
||||
"""
|
||||
return config.model_dump(mode="json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load / save
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_config(path: Path) -> AnalysisConfig:
|
||||
"""Load and validate an :class:`AnalysisConfig` from a YAML file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path
|
||||
Filesystem path to a YAML file conforming to the
|
||||
:class:`AnalysisConfig` schema.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AnalysisConfig
|
||||
The fully validated config. Cross-field invariants have
|
||||
already been checked.
|
||||
|
||||
Raises
|
||||
------
|
||||
pydantic.ValidationError
|
||||
On any schema violation — unknown keys, wrong types, or
|
||||
failed cross-field invariants.
|
||||
yaml.YAMLError
|
||||
On malformed YAML.
|
||||
"""
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
raw = yaml.safe_load(f)
|
||||
if raw is None:
|
||||
raw = {}
|
||||
return AnalysisConfig.model_validate(raw)
|
||||
|
||||
|
||||
def save_report(path: Path, report: AnalysisReport) -> None:
|
||||
"""Serialise an :class:`AnalysisReport` to ``path`` atomically.
|
||||
|
||||
Writes to a sibling ``<path>.tmp`` first, then renames over
|
||||
``path`` so a crash mid-write cannot leave behind a truncated
|
||||
file. The parent directory is created if it does not exist.
|
||||
"""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
payload = report.model_dump(mode="json")
|
||||
with tmp.open("w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2)
|
||||
tmp.replace(path)
|
||||
|
||||
|
||||
def load_report(path: Path) -> AnalysisReport:
|
||||
"""Load and validate an :class:`AnalysisReport` JSON file.
|
||||
|
||||
Runs the payload through :func:`~neuropose.migrations.migrate_analysis_report`
|
||||
before pydantic validation so future schema bumps can upgrade
|
||||
legacy reports transparently.
|
||||
"""
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data: Any = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
data = migrate_analysis_report(data)
|
||||
return AnalysisReport.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_analysis(config: AnalysisConfig) -> AnalysisReport:
|
||||
"""Execute the pipeline described by ``config`` end-to-end.
|
||||
|
||||
Loads the predictions files named in ``config.inputs``, applies
|
||||
the configured preprocessing + segmentation + analysis stages,
|
||||
and returns an :class:`AnalysisReport` whose
|
||||
:attr:`~AnalysisReport.provenance` inherits the inference-time
|
||||
provenance of the primary input with
|
||||
:attr:`~neuropose.io.Provenance.analysis_config` populated so the
|
||||
report is self-describing even if the YAML config is later lost.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config
|
||||
The pre-validated pipeline configuration.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AnalysisReport
|
||||
Fully populated report. Not yet written to disk — the caller
|
||||
passes it to :func:`save_report` (or inspects it directly).
|
||||
|
||||
Notes
|
||||
-----
|
||||
For DTW runs with a segmentation stage, segments are paired
|
||||
one-to-one by index across primary and reference, truncating to
|
||||
``min(len_primary, len_reference)``. Bilateral segmentations
|
||||
produce distances for each side independently, labelled under
|
||||
their segmentation key (e.g. ``"left_heel_strikes[3]"``).
|
||||
"""
|
||||
primary_preds = load_video_predictions(config.inputs.primary)
|
||||
reference_preds: VideoPredictions | None = None
|
||||
if config.inputs.reference is not None:
|
||||
reference_preds = load_video_predictions(config.inputs.reference)
|
||||
|
||||
person_index = config.preprocessing.person_index
|
||||
|
||||
primary_seq = predictions_to_numpy(primary_preds, person_index=person_index)
|
||||
reference_seq: np.ndarray | None = None
|
||||
if reference_preds is not None:
|
||||
reference_seq = predictions_to_numpy(reference_preds, person_index=person_index)
|
||||
|
||||
primary_segmentations: dict[str, Segmentation] = {}
|
||||
reference_segmentations: dict[str, Segmentation] = {}
|
||||
if config.segmentation is not None:
|
||||
primary_segmentations = _run_segmentation(primary_preds, config.segmentation, person_index)
|
||||
if reference_preds is not None:
|
||||
reference_segmentations = _run_segmentation(
|
||||
reference_preds, config.segmentation, person_index
|
||||
)
|
||||
|
||||
results = _run_analysis_stage(
|
||||
config.analysis,
|
||||
primary_seq=primary_seq,
|
||||
reference_seq=reference_seq,
|
||||
primary_segmentations=primary_segmentations,
|
||||
reference_segmentations=reference_segmentations,
|
||||
)
|
||||
|
||||
analysis_config_dump = analysis_config_to_dict(config)
|
||||
report_provenance: Provenance | None = None
|
||||
if primary_preds.provenance is not None:
|
||||
report_provenance = primary_preds.provenance.model_copy(
|
||||
update={"analysis_config": analysis_config_dump}
|
||||
)
|
||||
|
||||
primary_summary = InputSummary(
|
||||
path=config.inputs.primary,
|
||||
frame_count=primary_preds.metadata.frame_count,
|
||||
fps=primary_preds.metadata.fps,
|
||||
provenance=primary_preds.provenance,
|
||||
)
|
||||
reference_summary: InputSummary | None = None
|
||||
if reference_preds is not None and config.inputs.reference is not None:
|
||||
reference_summary = InputSummary(
|
||||
path=config.inputs.reference,
|
||||
frame_count=reference_preds.metadata.frame_count,
|
||||
fps=reference_preds.metadata.fps,
|
||||
provenance=reference_preds.provenance,
|
||||
)
|
||||
|
||||
return AnalysisReport(
|
||||
config=config,
|
||||
provenance=report_provenance,
|
||||
primary=primary_summary,
|
||||
reference=reference_summary,
|
||||
segmentations=primary_segmentations,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal dispatch helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_segmentation(
|
||||
predictions: VideoPredictions,
|
||||
stage: SegmentationStage, # type: ignore[valid-type]
|
||||
person_index: int,
|
||||
) -> dict[str, Segmentation]:
|
||||
"""Apply a segmentation stage to a :class:`VideoPredictions`.
|
||||
|
||||
Returns a dict keyed by a stage-appropriate label: single-side
|
||||
gait cycles use ``"<joint>_cycles"``, bilateral gait cycles use
|
||||
``"left_heel_strikes"`` / ``"right_heel_strikes"``, and extractor
|
||||
segmentation uses the caller-supplied
|
||||
:attr:`~ExtractorSegmentation.label`.
|
||||
"""
|
||||
if isinstance(stage, GaitCyclesSegmentation):
|
||||
seg = segment_gait_cycles(
|
||||
predictions,
|
||||
joint=stage.joint,
|
||||
axis=stage.axis,
|
||||
invert=stage.invert,
|
||||
min_cycle_seconds=stage.min_cycle_seconds,
|
||||
min_prominence=stage.min_prominence,
|
||||
)
|
||||
return {f"{stage.joint}_cycles": seg}
|
||||
if isinstance(stage, GaitCyclesBilateralSegmentation):
|
||||
return segment_gait_cycles_bilateral(
|
||||
predictions,
|
||||
axis=stage.axis,
|
||||
invert=stage.invert,
|
||||
min_cycle_seconds=stage.min_cycle_seconds,
|
||||
min_prominence=stage.min_prominence,
|
||||
)
|
||||
if isinstance(stage, ExtractorSegmentation):
|
||||
effective_person_index = (
|
||||
stage.person_index if stage.person_index is not None else person_index
|
||||
)
|
||||
seg = segment_predictions(
|
||||
predictions,
|
||||
stage.extractor,
|
||||
person_index=effective_person_index,
|
||||
min_distance_seconds=stage.min_distance_seconds,
|
||||
min_prominence=stage.min_prominence,
|
||||
min_height=stage.min_height,
|
||||
pad_seconds=stage.pad_seconds,
|
||||
)
|
||||
return {stage.label: seg}
|
||||
raise TypeError(f"unknown segmentation stage: {type(stage).__name__}")
|
||||
|
||||
|
||||
def _run_analysis_stage(
|
||||
stage: AnalysisStage, # type: ignore[valid-type]
|
||||
*,
|
||||
primary_seq: np.ndarray,
|
||||
reference_seq: np.ndarray | None,
|
||||
primary_segmentations: dict[str, Segmentation],
|
||||
reference_segmentations: dict[str, Segmentation],
|
||||
) -> AnalysisResults: # type: ignore[valid-type]
|
||||
"""Dispatch to the appropriate analysis executor per ``stage.kind``."""
|
||||
if isinstance(stage, DtwAnalysis):
|
||||
if reference_seq is None:
|
||||
# AnalysisConfig's cross-stage validator should prevent
|
||||
# this; duplicate the check here so a direct programmatic
|
||||
# call can't slip through.
|
||||
raise ValueError("DtwAnalysis requires a reference sequence")
|
||||
return _run_dtw(
|
||||
stage,
|
||||
primary_seq=primary_seq,
|
||||
reference_seq=reference_seq,
|
||||
primary_segmentations=primary_segmentations,
|
||||
reference_segmentations=reference_segmentations,
|
||||
)
|
||||
if isinstance(stage, StatsAnalysis):
|
||||
return _run_stats(
|
||||
stage,
|
||||
primary_seq=primary_seq,
|
||||
primary_segmentations=primary_segmentations,
|
||||
)
|
||||
if isinstance(stage, NoAnalysis):
|
||||
return NoResults(kind="none")
|
||||
raise TypeError(f"unknown analysis stage: {type(stage).__name__}")
|
||||
|
||||
|
||||
def _run_dtw(
|
||||
stage: DtwAnalysis,
|
||||
*,
|
||||
primary_seq: np.ndarray,
|
||||
reference_seq: np.ndarray,
|
||||
primary_segmentations: dict[str, Segmentation],
|
||||
reference_segmentations: dict[str, Segmentation],
|
||||
) -> DtwResults:
|
||||
"""Execute a DTW analysis stage, returning :class:`DtwResults`."""
|
||||
labels: list[str] = []
|
||||
distances: list[float] = []
|
||||
paths: list[list[tuple[int, int]]] = []
|
||||
per_joint_distances: list[list[float]] | None = [] if stage.method == "dtw_per_joint" else None
|
||||
|
||||
pairs: list[tuple[str, np.ndarray, np.ndarray]] = []
|
||||
if primary_segmentations:
|
||||
for key, primary_seg in primary_segmentations.items():
|
||||
reference_seg = reference_segmentations.get(key)
|
||||
if reference_seg is None:
|
||||
# Same config was applied to both, so this should not
|
||||
# happen unless the segmentation depends on the input
|
||||
# length in some unexpected way. Skip with a warning
|
||||
# rather than crash the whole run.
|
||||
continue
|
||||
pair_count = min(len(primary_seg.segments), len(reference_seg.segments))
|
||||
for i in range(pair_count):
|
||||
p_seg = primary_seg.segments[i]
|
||||
r_seg = reference_seg.segments[i]
|
||||
pairs.append(
|
||||
(
|
||||
f"{key}[{i}]",
|
||||
primary_seq[p_seg.start : p_seg.end],
|
||||
reference_seq[r_seg.start : r_seg.end],
|
||||
)
|
||||
)
|
||||
else:
|
||||
pairs.append(("full_trial", primary_seq, reference_seq))
|
||||
|
||||
for label, primary_slice, reference_slice in pairs:
|
||||
labels.append(label)
|
||||
if stage.method == "dtw_all":
|
||||
result = dtw_all(
|
||||
primary_slice,
|
||||
reference_slice,
|
||||
align=stage.align,
|
||||
representation=stage.representation,
|
||||
angle_triplets=stage.angle_triplets,
|
||||
nan_policy=stage.nan_policy,
|
||||
)
|
||||
distances.append(result.distance)
|
||||
paths.append(result.path)
|
||||
elif stage.method == "dtw_per_joint":
|
||||
assert per_joint_distances is not None
|
||||
per_joint_results = dtw_per_joint(
|
||||
primary_slice,
|
||||
reference_slice,
|
||||
align=stage.align,
|
||||
representation=stage.representation,
|
||||
angle_triplets=stage.angle_triplets,
|
||||
nan_policy=stage.nan_policy,
|
||||
)
|
||||
# "distance" for a per-joint run is the sum across units;
|
||||
# "per_joint_distances" carries the full breakdown.
|
||||
per_unit = [r.distance for r in per_joint_results]
|
||||
distances.append(float(sum(per_unit)))
|
||||
per_joint_distances.append(per_unit)
|
||||
# Store just the first joint's path as a representative —
|
||||
# per-joint paths are a list of equal length, but
|
||||
# reporting all of them on disk is almost always overkill.
|
||||
paths.append(per_joint_results[0].path if per_joint_results else [])
|
||||
else: # "dtw_relation"
|
||||
assert stage.joint_i is not None
|
||||
assert stage.joint_j is not None
|
||||
result = _invoke_dtw_relation(
|
||||
primary_slice,
|
||||
reference_slice,
|
||||
joint_i=stage.joint_i,
|
||||
joint_j=stage.joint_j,
|
||||
align=stage.align,
|
||||
nan_policy=stage.nan_policy,
|
||||
)
|
||||
distances.append(result.distance)
|
||||
paths.append(result.path)
|
||||
|
||||
return DtwResults(
|
||||
kind="dtw",
|
||||
method=stage.method,
|
||||
distances=distances,
|
||||
paths=paths,
|
||||
per_joint_distances=per_joint_distances,
|
||||
segment_labels=labels,
|
||||
summary=_summarize_distances(distances),
|
||||
)
|
||||
|
||||
|
||||
def _invoke_dtw_relation(
|
||||
primary_slice: np.ndarray,
|
||||
reference_slice: np.ndarray,
|
||||
*,
|
||||
joint_i: int,
|
||||
joint_j: int,
|
||||
align: AlignMode,
|
||||
nan_policy: NanPolicy,
|
||||
) -> DTWResult:
|
||||
"""Isolating thin wrapper so test fakes can replace the call site cleanly."""
|
||||
return dtw_relation(
|
||||
primary_slice,
|
||||
reference_slice,
|
||||
joint_i,
|
||||
joint_j,
|
||||
align=align,
|
||||
nan_policy=nan_policy,
|
||||
)
|
||||
|
||||
|
||||
def _run_stats(
|
||||
stage: StatsAnalysis,
|
||||
*,
|
||||
primary_seq: np.ndarray,
|
||||
primary_segmentations: dict[str, Segmentation],
|
||||
) -> StatsResults:
|
||||
"""Execute a stats analysis stage, returning :class:`StatsResults`."""
|
||||
labels: list[str] = []
|
||||
stats: list[FeatureSummary] = []
|
||||
|
||||
if primary_segmentations:
|
||||
for key, seg in primary_segmentations.items():
|
||||
for i, segment in enumerate(seg.segments):
|
||||
labels.append(f"{key}[{i}]")
|
||||
signal = extract_signal(
|
||||
primary_seq[segment.start : segment.end],
|
||||
stage.extractor,
|
||||
)
|
||||
stats.append(_feature_summary(signal))
|
||||
else:
|
||||
labels.append("full_trial")
|
||||
signal = extract_signal(primary_seq, stage.extractor)
|
||||
stats.append(_feature_summary(signal))
|
||||
|
||||
return StatsResults(kind="stats", statistics=stats, segment_labels=labels)
|
||||
|
||||
|
||||
def _feature_summary(signal: np.ndarray) -> FeatureSummary:
|
||||
"""Wrap :func:`extract_feature_statistics` output in a pydantic model."""
|
||||
raw = extract_feature_statistics(signal)
|
||||
return FeatureSummary(
|
||||
mean=raw.mean,
|
||||
std=raw.std,
|
||||
min=raw.min,
|
||||
max=raw.max,
|
||||
range=raw.range,
|
||||
)
|
||||
|
||||
|
||||
def _summarize_distances(distances: list[float]) -> dict[str, float]:
|
||||
"""Compute mean / p50 / p95 / p99 of a distance list.
|
||||
|
||||
Empty inputs return an empty dict so the report's ``summary``
|
||||
field still round-trips through JSON without special cases.
|
||||
"""
|
||||
if not distances:
|
||||
return {}
|
||||
arr = np.asarray(distances, dtype=float)
|
||||
return {
|
||||
"mean": float(arr.mean()),
|
||||
"p50": float(np.percentile(arr, 50)),
|
||||
"p95": float(np.percentile(arr, 95)),
|
||||
"p99": float(np.percentile(arr, 99)),
|
||||
}
|
||||
|
|
@ -30,12 +30,6 @@ Three layers of API are provided, in increasing order of convenience:
|
|||
:class:`~neuropose.io.ExtractorSpec`, converts time-based parameters
|
||||
to frame counts using ``metadata.fps``, and returns a full
|
||||
:class:`~neuropose.io.Segmentation` ready to attach to the predictions.
|
||||
- :func:`segment_gait_cycles` and :func:`segment_gait_cycles_bilateral`
|
||||
— clinical convenience wrappers over :func:`segment_predictions`
|
||||
that pre-fill a :func:`joint_axis` extractor with gait-appropriate
|
||||
defaults (heel joint, Y axis, 0.4 s minimum cycle). The bilateral
|
||||
variant returns both sides under ``"left_heel_strikes"`` and
|
||||
``"right_heel_strikes"`` keys.
|
||||
- :func:`slice_predictions` — split a :class:`~neuropose.io.VideoPredictions`
|
||||
into one per-repetition :class:`~neuropose.io.VideoPredictions`,
|
||||
useful when downstream code wants per-rep objects rather than windows
|
||||
|
|
@ -72,7 +66,6 @@ installed; a clear :class:`ImportError` surfaces at the first call to
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -90,11 +83,6 @@ from neuropose.io import (
|
|||
VideoPredictions,
|
||||
)
|
||||
|
||||
AxisLetter = Literal["x", "y", "z"]
|
||||
"""Axis selector used by gait-cycle segmentation helpers."""
|
||||
|
||||
_AXIS_INDICES: dict[AxisLetter, int] = {"x": 0, "y": 1, "z": 2}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# berkeley_mhad_43 joint names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -534,156 +522,6 @@ def segment_predictions(
|
|||
return Segmentation(config=config, segments=segments)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gait-cycle segmentation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def segment_gait_cycles(
|
||||
predictions: VideoPredictions,
|
||||
*,
|
||||
joint: str = "rhee",
|
||||
axis: AxisLetter = "y",
|
||||
invert: bool = False,
|
||||
min_cycle_seconds: float = 0.4,
|
||||
min_prominence: float | None = None,
|
||||
) -> Segmentation:
|
||||
"""Segment gait cycles from a single heel's vertical trace.
|
||||
|
||||
Runs valley-to-valley peak detection (the same engine used by
|
||||
:func:`segment_predictions`) on the chosen joint's coordinate along
|
||||
the chosen spatial axis. By default, each detected peak corresponds
|
||||
to one heel-strike — the frame where the heel reaches its lowest
|
||||
point on the Y-down MeTRAbs world-coordinate convention — and the
|
||||
returned :class:`~neuropose.io.Segment` windows span one full gait
|
||||
cycle from the preceding toe-off valley to the following toe-off
|
||||
valley.
|
||||
|
||||
The function is a **thin wrapper** over :func:`segment_predictions`
|
||||
with a :func:`joint_axis` extractor; it exists to give clinical
|
||||
callers a gait-specific entry point with meaningful defaults
|
||||
(``joint="rhee"``, ``axis="y"``, ``min_cycle_seconds=0.4``)
|
||||
rather than forcing them to construct the extractor by hand.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predictions
|
||||
Per-video predictions to segment. ``metadata.fps`` is used to
|
||||
translate ``min_cycle_seconds`` into a sample-count distance
|
||||
threshold.
|
||||
joint
|
||||
Joint name in the berkeley_mhad_43 skeleton — typically
|
||||
``"rhee"`` (right heel) or ``"lhee"`` (left heel). Resolved
|
||||
via :func:`joint_index`.
|
||||
axis
|
||||
Spatial axis to track, as ``"x"``, ``"y"``, or ``"z"``. The
|
||||
default ``"y"`` matches the vertical axis in MeTRAbs's output
|
||||
(Y-down world coordinates).
|
||||
invert
|
||||
If ``True``, negate the extracted signal so that minima
|
||||
become peaks. Needed when the recording convention makes a
|
||||
heel-strike appear as a *decrease* in the chosen coordinate
|
||||
— for example, a camera orientation where the vertical axis
|
||||
runs bottom-to-top instead of MeTRAbs's default top-to-bottom.
|
||||
min_cycle_seconds
|
||||
Minimum gait-cycle duration. Used as scipy's
|
||||
``find_peaks(distance=...)`` parameter after conversion to
|
||||
frame count via ``metadata.fps``. Defaults to ``0.4`` seconds,
|
||||
which rejects noise peaks on even the fastest human gaits
|
||||
(~120 strides/min) while retaining every real cadence.
|
||||
min_prominence
|
||||
Forwarded to :func:`segment_by_peaks` to filter out shallow
|
||||
local maxima that aren't real heel-strikes. In MeTRAbs units
|
||||
(millimetres) a threshold of 20 to 50 mm is typical for
|
||||
able-bodied gait; leave ``None`` to accept every peak scipy
|
||||
identifies.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Segmentation
|
||||
A :class:`~neuropose.io.Segmentation` paired with the full
|
||||
:class:`~neuropose.io.SegmentationConfig` that produced it, so
|
||||
the output is self-describing when persisted. The segments
|
||||
list is **empty** rather than an exception when no peaks are
|
||||
detected — a common outcome for shuffling gaits or
|
||||
walker-assisted trials.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If ``joint`` is not a known berkeley_mhad_43 joint name.
|
||||
ValueError
|
||||
If ``axis`` is not one of ``"x"``, ``"y"``, ``"z"``, or if
|
||||
``predictions`` has zero frames, or if ``metadata.fps`` is
|
||||
non-positive.
|
||||
ImportError
|
||||
If :mod:`scipy` is not installed.
|
||||
"""
|
||||
if axis not in _AXIS_INDICES:
|
||||
raise ValueError(f"axis must be one of 'x', 'y', 'z'; got {axis!r}")
|
||||
joint_idx = joint_index(joint)
|
||||
axis_idx = _AXIS_INDICES[axis]
|
||||
extractor = joint_axis(joint_idx, axis_idx, invert=invert)
|
||||
return segment_predictions(
|
||||
predictions,
|
||||
extractor,
|
||||
min_distance_seconds=min_cycle_seconds,
|
||||
min_prominence=min_prominence,
|
||||
)
|
||||
|
||||
|
||||
def segment_gait_cycles_bilateral(
|
||||
predictions: VideoPredictions,
|
||||
*,
|
||||
axis: AxisLetter = "y",
|
||||
invert: bool = False,
|
||||
min_cycle_seconds: float = 0.4,
|
||||
min_prominence: float | None = None,
|
||||
) -> dict[str, Segmentation]:
|
||||
"""Segment gait cycles for both heels.
|
||||
|
||||
Runs :func:`segment_gait_cycles` twice — once with ``joint="lhee"``
|
||||
and once with ``joint="rhee"`` — and returns the two results under
|
||||
the keys ``"left_heel_strikes"`` and ``"right_heel_strikes"``. The
|
||||
returned dict is shape-compatible with
|
||||
:class:`~neuropose.io.VideoPredictions.segmentations` so it can be
|
||||
merged directly into a predictions object and persisted to
|
||||
``results.json`` via the usual save path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predictions, axis, invert, min_cycle_seconds, min_prominence
|
||||
Forwarded to :func:`segment_gait_cycles`; see that function's
|
||||
docstring for details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Segmentation]
|
||||
Two-keyed mapping with the left and right heel segmentations
|
||||
under ``"left_heel_strikes"`` and ``"right_heel_strikes"``.
|
||||
Either side may carry an empty segments list if its heel's
|
||||
trace contained no detectable strikes.
|
||||
"""
|
||||
return {
|
||||
"left_heel_strikes": segment_gait_cycles(
|
||||
predictions,
|
||||
joint="lhee",
|
||||
axis=axis,
|
||||
invert=invert,
|
||||
min_cycle_seconds=min_cycle_seconds,
|
||||
min_prominence=min_prominence,
|
||||
),
|
||||
"right_heel_strikes": segment_gait_cycles(
|
||||
predictions,
|
||||
joint="rhee",
|
||||
axis=axis,
|
||||
invert=invert,
|
||||
min_cycle_seconds=min_cycle_seconds,
|
||||
min_prominence=min_prominence,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slicing: one VideoPredictions per segment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -105,17 +105,9 @@ def run_benchmark(
|
|||
|
||||
passes: list[PerformanceMetrics] = []
|
||||
reference_predictions: VideoPredictions | None = None
|
||||
# Provenance is identical across every pass of a single run (same
|
||||
# estimator, same model, same environment), so we keep just the
|
||||
# latest one we see. Doing this on every iteration is cheap — it's
|
||||
# one attribute read — and means the benchmark result carries
|
||||
# provenance even when ``capture_reference`` is off.
|
||||
latest_provenance = None
|
||||
for i in range(repeats):
|
||||
result = estimator.process_video(video_path)
|
||||
passes.append(result.metrics)
|
||||
if result.predictions.provenance is not None:
|
||||
latest_provenance = result.predictions.provenance
|
||||
# Only the *last* measured pass needs to be captured for
|
||||
# divergence comparison. Earlier passes would just be
|
||||
# overwritten, so we avoid holding their frame dicts in memory.
|
||||
|
|
@ -130,7 +122,6 @@ def run_benchmark(
|
|||
warmup_pass=passes[0],
|
||||
measured_passes=passes[1:],
|
||||
aggregate=aggregate,
|
||||
provenance=latest_provenance,
|
||||
)
|
||||
return BenchmarkRunOutcome(
|
||||
result=benchmark_result,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""NeuroPose command-line interface.
|
||||
|
||||
Eight subcommands:
|
||||
Seven subcommands:
|
||||
|
||||
- ``neuropose watch`` — run the :class:`~neuropose.interfacer.Interfacer`
|
||||
daemon against the configured input directory.
|
||||
|
|
@ -12,10 +12,6 @@ Eight subcommands:
|
|||
- ``neuropose serve`` — start the :mod:`~neuropose.monitor` localhost
|
||||
HTTP dashboard so collaborators can watch a run's progress in a
|
||||
browser or via ``curl``.
|
||||
- ``neuropose reset`` — stop the daemon and monitor, then wipe pipeline
|
||||
state (input queue, results, status file, lock file, ingest staging
|
||||
dirs) for a clean restart. See :mod:`neuropose.reset` for the layered
|
||||
implementation.
|
||||
- ``neuropose segment <results>`` — post-hoc repetition segmentation of
|
||||
an existing predictions file. Attaches a named
|
||||
:class:`~neuropose.io.Segmentation` to every video it contains and
|
||||
|
|
@ -25,11 +21,8 @@ Eight subcommands:
|
|||
vs CPU numerical-divergence checks. Prints a human report to stdout
|
||||
and (optionally) writes a structured :class:`~neuropose.io.BenchmarkResult`
|
||||
JSON to ``--output``.
|
||||
- ``neuropose analyze --config <yaml>`` — run the declarative analysis
|
||||
pipeline described in a YAML config. Loads the named predictions
|
||||
files, applies segmentation + analysis, writes an
|
||||
:class:`~neuropose.analyzer.pipeline.AnalysisReport` JSON. See
|
||||
``examples/analysis/*.yaml`` for runnable references.
|
||||
- ``neuropose analyze <results>`` — stubbed placeholder pending the
|
||||
analyzer rewrite in commit 10.
|
||||
|
||||
User-facing error handling
|
||||
--------------------------
|
||||
|
|
@ -60,7 +53,6 @@ from pathlib import Path
|
|||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from neuropose import __version__
|
||||
|
|
@ -434,151 +426,6 @@ def serve(
|
|||
raise typer.Exit(code=EXIT_USAGE) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.command()
|
||||
def reset(
|
||||
ctx: typer.Context,
|
||||
yes: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--yes",
|
||||
"-y",
|
||||
help="Skip the interactive confirmation prompt.",
|
||||
),
|
||||
] = False,
|
||||
keep_failed: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--keep-failed",
|
||||
help=(
|
||||
"Preserve $data_dir/failed/ for forensic review. By "
|
||||
"default the failed-job quarantine is wiped along with "
|
||||
"in/ and out/."
|
||||
),
|
||||
),
|
||||
] = False,
|
||||
force_kill: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--force-kill",
|
||||
help=(
|
||||
"Escalate to SIGKILL on any daemon or monitor still "
|
||||
"alive after the SIGINT grace period. Necessary if the "
|
||||
"daemon is mid-inference on a long video and you do "
|
||||
"not want to wait for the current video to finish."
|
||||
),
|
||||
),
|
||||
] = False,
|
||||
grace_seconds: Annotated[
|
||||
float,
|
||||
typer.Option(
|
||||
"--grace-seconds",
|
||||
min=0.0,
|
||||
help=(
|
||||
"Seconds to wait after SIGINT before declaring a "
|
||||
"process a survivor (or escalating to SIGKILL when "
|
||||
"--force-kill is set)."
|
||||
),
|
||||
),
|
||||
] = 10.0,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--dry-run",
|
||||
"-n",
|
||||
help="Show what would be killed and removed without doing it.",
|
||||
),
|
||||
] = False,
|
||||
) -> None:
|
||||
"""Stop the daemon and monitor, then wipe pipeline state.
|
||||
|
||||
Discovers running ``neuropose watch`` and ``neuropose serve``
|
||||
processes, sends SIGINT, waits ``--grace-seconds`` for graceful
|
||||
shutdown (optionally escalating to SIGKILL with ``--force-kill``),
|
||||
then removes the contents of ``$data_dir/in/``, ``$data_dir/out/``
|
||||
(including ``status.json``), ``$data_dir/failed/`` (unless
|
||||
``--keep-failed``), the daemon lock file, and any leftover
|
||||
``.ingest_<uuid>/`` staging directories from interrupted ingests.
|
||||
|
||||
Refuses to wipe state if any process survives the termination
|
||||
phase — wiping the data directory out from under an active daemon
|
||||
would leave it writing into deleted directory entries. Re-run
|
||||
with ``--force-kill`` or stop the survivor manually.
|
||||
"""
|
||||
# Deferred import so reset's psutil scan stays off the watch/process
|
||||
# hot path. psutil is already a runtime dependency for benchmark
|
||||
# metrics, so this import is free at install time.
|
||||
from neuropose.reset import find_neuropose_processes, reset_pipeline, wipe_state
|
||||
|
||||
settings: Settings = ctx.obj
|
||||
|
||||
discovered = find_neuropose_processes()
|
||||
preview = wipe_state(settings, keep_failed=keep_failed, dry_run=True)
|
||||
|
||||
typer.echo(f"data dir: {settings.data_dir}")
|
||||
if discovered:
|
||||
typer.echo(f"would stop: {len(discovered)} process(es)")
|
||||
for rp in discovered:
|
||||
typer.echo(f" pid {rp.pid:>7} {rp.role:<7} {rp.cmdline}")
|
||||
else:
|
||||
typer.echo("would stop: no daemon or monitor running")
|
||||
if preview.removed_paths:
|
||||
size_mb = preview.bytes_freed / (1024 * 1024)
|
||||
typer.echo(f"would remove: {len(preview.removed_paths)} path(s) ({size_mb:.1f} MB)")
|
||||
for path in preview.removed_paths:
|
||||
typer.echo(f" {path}")
|
||||
else:
|
||||
typer.echo("would remove: nothing — data dir is already clean")
|
||||
|
||||
if dry_run:
|
||||
typer.echo("(dry-run; no changes made)")
|
||||
return
|
||||
|
||||
if not discovered and not preview.removed_paths:
|
||||
typer.echo("nothing to do.")
|
||||
return
|
||||
|
||||
if not yes and not typer.confirm("\nproceed?"):
|
||||
typer.echo("aborted.")
|
||||
raise typer.Exit(code=EXIT_USAGE)
|
||||
|
||||
report = reset_pipeline(
|
||||
settings,
|
||||
grace_seconds=grace_seconds,
|
||||
force_kill=force_kill,
|
||||
keep_failed=keep_failed,
|
||||
)
|
||||
|
||||
if report.termination.stopped:
|
||||
typer.echo(f"stopped {len(report.termination.stopped)} process(es) via SIGINT")
|
||||
if report.termination.force_killed:
|
||||
typer.echo(
|
||||
f"force-killed {len(report.termination.force_killed)} process(es) "
|
||||
f"after {grace_seconds:.0f}s grace period"
|
||||
)
|
||||
if report.termination.survivors:
|
||||
typer.echo(
|
||||
f"error: {len(report.termination.survivors)} process(es) did not exit:",
|
||||
err=True,
|
||||
)
|
||||
for rp in report.termination.survivors:
|
||||
typer.echo(f" pid {rp.pid} ({rp.role})", err=True)
|
||||
if report.wipe_skipped_due_to_survivors:
|
||||
typer.echo(
|
||||
" state on disk was NOT wiped — re-run with --force-kill, "
|
||||
"or stop these processes manually first.",
|
||||
err=True,
|
||||
)
|
||||
raise typer.Exit(code=EXIT_USAGE)
|
||||
|
||||
size_mb = report.wipe.bytes_freed / (1024 * 1024)
|
||||
typer.echo(f"removed {len(report.wipe.removed_paths)} path(s) ({size_mb:.1f} MB freed)")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# segment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -1200,94 +1047,26 @@ def benchmark(
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# analyze
|
||||
# analyze (stub)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.command()
|
||||
def analyze(
|
||||
ctx: typer.Context,
|
||||
config: Annotated[
|
||||
results: Annotated[
|
||||
Path,
|
||||
typer.Option(
|
||||
"--config",
|
||||
"-c",
|
||||
help=(
|
||||
"Path to a YAML AnalysisConfig file. See examples/analysis/ "
|
||||
"for runnable references."
|
||||
),
|
||||
),
|
||||
typer.Argument(help="Path to a results.json produced by watch or process."),
|
||||
],
|
||||
output: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--output",
|
||||
"-o",
|
||||
help=(
|
||||
"Override the report path declared in the config's "
|
||||
"output.report field. Useful when running the same config "
|
||||
"against multiple input pairs from a shell loop."
|
||||
),
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Run the declarative analysis pipeline described by a YAML config.
|
||||
|
||||
Loads the config, parses it through
|
||||
:class:`~neuropose.analyzer.pipeline.AnalysisConfig` (so typos fail
|
||||
immediately with a clear error), executes the pipeline via
|
||||
:func:`~neuropose.analyzer.pipeline.run_analysis`, and writes the
|
||||
resulting :class:`~neuropose.analyzer.pipeline.AnalysisReport` to
|
||||
``--output`` (or to ``output.report`` declared in the config).
|
||||
|
||||
Cross-field invariants (for example,
|
||||
``method='dtw_relation'`` requires ``joint_i`` / ``joint_j``) are
|
||||
enforced at parse time, so a typo fails before any predictions
|
||||
are loaded.
|
||||
"""
|
||||
del ctx
|
||||
# Deferred import keeps the CLI module's top-level imports free of
|
||||
# pipeline dependencies so ``watch`` / ``process`` startup stays
|
||||
# cheap.
|
||||
from neuropose.analyzer.pipeline import load_config, run_analysis, save_report
|
||||
|
||||
if not config.exists():
|
||||
typer.echo(f"error: config file not found: {config}", err=True)
|
||||
raise typer.Exit(code=EXIT_USAGE)
|
||||
|
||||
try:
|
||||
analysis_config = load_config(config)
|
||||
except ValidationError as exc:
|
||||
typer.echo(f"error: invalid config {config}:\n{exc}", err=True)
|
||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
||||
except yaml.YAMLError as exc:
|
||||
typer.echo(f"error: could not parse YAML {config}: {exc}", err=True)
|
||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
||||
|
||||
report_path = output if output is not None else analysis_config.output.report
|
||||
|
||||
try:
|
||||
report = run_analysis(analysis_config)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
typer.echo(f"error: analysis failed: {exc}", err=True)
|
||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
||||
|
||||
save_report(report_path, report)
|
||||
|
||||
typer.echo(f"wrote analysis report to {report_path}")
|
||||
if report.segmentations:
|
||||
seg_summary = ", ".join(
|
||||
f"{name}={len(seg.segments)}" for name, seg in report.segmentations.items()
|
||||
)
|
||||
typer.echo(f"segmentations: {seg_summary}")
|
||||
# Emit a one-line summary of the results regardless of kind.
|
||||
typer.echo(f"analysis kind: {report.results.kind}")
|
||||
if report.results.kind == "dtw":
|
||||
n = len(report.results.distances)
|
||||
mean = report.results.summary.get("mean", float("nan"))
|
||||
typer.echo(f"distances computed: {n} (mean={mean:.4f})")
|
||||
elif report.results.kind == "stats":
|
||||
typer.echo(f"statistic blocks computed: {len(report.results.statistics)}")
|
||||
"""Run the analyzer subpackage against a results.json (pending commit 10)."""
|
||||
del ctx, results
|
||||
typer.echo(
|
||||
"error: the analyzer subpackage is pending commit 10. "
|
||||
"Until it lands, use neuropose.io to load results.json from Python.",
|
||||
err=True,
|
||||
)
|
||||
raise typer.Exit(code=EXIT_PENDING)
|
||||
|
||||
|
||||
def run() -> None:
|
||||
|
|
|
|||
|
|
@ -34,25 +34,19 @@ model is present raises :class:`ModelNotLoadedError`.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from importlib.metadata import PackageNotFoundError
|
||||
from importlib.metadata import version as _pkg_version
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import psutil
|
||||
|
||||
from neuropose import __version__ as _neuropose_version
|
||||
from neuropose._model import load_metrabs_model
|
||||
from neuropose.io import (
|
||||
FramePrediction,
|
||||
PerformanceMetrics,
|
||||
Provenance,
|
||||
VideoMetadata,
|
||||
VideoPredictions,
|
||||
)
|
||||
|
|
@ -164,12 +158,6 @@ class Estimator:
|
|||
# successful ``load_model`` below so the next ``process_video`` can
|
||||
# pass the real number through into ``PerformanceMetrics``.
|
||||
self._model_load_seconds: float | None = None
|
||||
# MeTRAbs artifact identity, set only by ``load_model``. When the
|
||||
# model was injected via the constructor we have no way to
|
||||
# fingerprint it, so these remain ``None`` and ``process_video``
|
||||
# leaves the output's ``provenance`` as ``None`` too.
|
||||
self._model_sha256: str | None = None
|
||||
self._model_filename: str | None = None
|
||||
|
||||
# -- model lifecycle ----------------------------------------------------
|
||||
|
||||
|
|
@ -188,21 +176,6 @@ class Estimator:
|
|||
"""Return ``True`` if a model has been supplied or loaded."""
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def model_sha256(self) -> str | None:
|
||||
"""Return the SHA-256 of the loaded MeTRAbs artifact, or ``None``.
|
||||
|
||||
``None`` when the model was injected via ``Estimator(model=...)``
|
||||
rather than loaded via :meth:`load_model`. The value, when
|
||||
present, is the module-pinned SHA from :mod:`neuropose._model`.
|
||||
"""
|
||||
return self._model_sha256
|
||||
|
||||
@property
|
||||
def model_filename(self) -> str | None:
|
||||
"""Return the basename of the MeTRAbs artifact, or ``None`` if injected."""
|
||||
return self._model_filename
|
||||
|
||||
def load_model(self, cache_dir: Path | None = None) -> None:
|
||||
"""Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.
|
||||
|
||||
|
|
@ -223,16 +196,9 @@ class Estimator:
|
|||
return
|
||||
logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
|
||||
start = time.perf_counter()
|
||||
loaded = load_metrabs_model(cache_dir=cache_dir)
|
||||
self._model = load_metrabs_model(cache_dir=cache_dir)
|
||||
self._model_load_seconds = time.perf_counter() - start
|
||||
self._model = loaded.model
|
||||
self._model_sha256 = loaded.sha256
|
||||
self._model_filename = loaded.filename
|
||||
logger.info(
|
||||
"MeTRAbs model loaded in %.2f s (sha256=%s)",
|
||||
self._model_load_seconds,
|
||||
loaded.sha256[:12],
|
||||
)
|
||||
logger.info("MeTRAbs model loaded in %.2f s", self._model_load_seconds)
|
||||
|
||||
# -- inference ----------------------------------------------------------
|
||||
|
||||
|
|
@ -364,53 +330,11 @@ class Estimator:
|
|||
metrics.active_device,
|
||||
)
|
||||
|
||||
provenance = self._build_provenance(device_info=device_info)
|
||||
predictions = VideoPredictions(
|
||||
metadata=metadata,
|
||||
frames=frames,
|
||||
provenance=provenance,
|
||||
)
|
||||
predictions = VideoPredictions(metadata=metadata, frames=frames)
|
||||
return ProcessVideoResult(predictions=predictions, metrics=metrics)
|
||||
|
||||
# -- internals ----------------------------------------------------------
|
||||
|
||||
def _build_provenance(self, *, device_info: _ActiveDeviceInfo) -> Provenance | None:
|
||||
"""Construct a :class:`~neuropose.io.Provenance` for the current run.
|
||||
|
||||
Returns ``None`` when the model was injected via the constructor
|
||||
rather than loaded via :meth:`load_model` — in that case we
|
||||
cannot fingerprint the artifact, and a partial provenance would
|
||||
mislead readers into thinking we could.
|
||||
|
||||
The device-info bundle is shared with the :class:`PerformanceMetrics`
|
||||
construction (one call to :func:`_detect_active_device` per
|
||||
``process_video`` invocation) so that both artifacts see
|
||||
identical TF and Metal state.
|
||||
"""
|
||||
if self._model_sha256 is None or self._model_filename is None:
|
||||
return None
|
||||
|
||||
metal_version: str | None = None
|
||||
if device_info.metal_active:
|
||||
try:
|
||||
metal_version = _pkg_version("tensorflow-metal")
|
||||
except PackageNotFoundError:
|
||||
metal_version = None
|
||||
|
||||
python_version = (
|
||||
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
)
|
||||
|
||||
return Provenance(
|
||||
model_sha256=self._model_sha256,
|
||||
model_filename=self._model_filename,
|
||||
tensorflow_version=device_info.tf_version,
|
||||
tensorflow_metal_version=metal_version,
|
||||
numpy_version=np.__version__,
|
||||
neuropose_version=_neuropose_version,
|
||||
python_version=python_version,
|
||||
)
|
||||
|
||||
def _infer_frame(
|
||||
self,
|
||||
model: Any,
|
||||
|
|
|
|||
|
|
@ -10,14 +10,6 @@ Atomicity: :func:`save_status`, :func:`save_job_results`, and
|
|||
atomically rename, so a crash mid-write will not leave a partially-written
|
||||
file behind. This matches the crash-resilience guarantee the interfacer
|
||||
daemon makes to callers.
|
||||
|
||||
Schema versioning: :class:`VideoPredictions` and :class:`BenchmarkResult`
|
||||
each carry a ``schema_version`` integer. On load, the raw JSON dict is
|
||||
passed through :mod:`neuropose.migrations` before pydantic validation so
|
||||
that files written by earlier versions upgrade transparently. :class:`JobResults`
|
||||
is a ``RootModel`` with no envelope of its own, so its loader runs the
|
||||
per-video migration on each entry of its mapping. See
|
||||
:mod:`neuropose.migrations` for the migration-registration pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -31,13 +23,6 @@ from typing import Annotated, Any, Literal
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
|
||||
|
||||
from neuropose.migrations import (
|
||||
CURRENT_VERSION,
|
||||
migrate_benchmark_result,
|
||||
migrate_job_results,
|
||||
migrate_video_predictions,
|
||||
)
|
||||
|
||||
|
||||
class JobStatus(StrEnum):
|
||||
"""Lifecycle state of a single processing job."""
|
||||
|
|
@ -172,104 +157,6 @@ class PerformanceMetrics(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class Provenance(BaseModel):
|
||||
"""Reproducibility-grade record of the environment that produced a payload.
|
||||
|
||||
Populated by the estimator on every inference run when the MeTRAbs
|
||||
model was loaded through
|
||||
:meth:`neuropose.estimator.Estimator.load_model` (the production
|
||||
path). ``None`` when the model was injected directly via the
|
||||
``Estimator(model=...)`` constructor (the test-fixture path), since
|
||||
NeuroPose has no way to fingerprint a model it did not load itself.
|
||||
|
||||
Paper C's reproducibility story rests on this envelope: two runs
|
||||
that produced equal ``Provenance`` objects against the same input
|
||||
are expected to produce equal output (modulo non-determinism
|
||||
controlled by ``deterministic``). Reviewers who want to re-derive a
|
||||
figure from raw video need exactly these fields.
|
||||
|
||||
Frozen so a captured ``Provenance`` cannot be mutated after it has
|
||||
been attached to a result; this matches the invariant that
|
||||
provenance is a property of the run, not of the reader.
|
||||
|
||||
``protected_namespaces=()`` silences pydantic's ``model_*`` field
|
||||
warning — the ``model_sha256`` / ``model_filename`` names refer to
|
||||
the MeTRAbs model artifact, not to pydantic's internal
|
||||
``model_validate`` / ``model_dump`` namespace, so the collision is
|
||||
cosmetic.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True, protected_namespaces=())
|
||||
|
||||
model_sha256: str = Field(
|
||||
description=(
|
||||
"SHA-256 of the MeTRAbs model tarball (hex-encoded, lowercase). "
|
||||
"Pinned at build time in :mod:`neuropose._model` and verified on "
|
||||
"first download. Identifies the exact model weights used."
|
||||
),
|
||||
)
|
||||
model_filename: str = Field(
|
||||
description=(
|
||||
"Canonical basename of the MeTRAbs tarball, e.g. "
|
||||
"``metrabs_eff2l_y4_384px_800k_28ds.tar.gz``. Human-readable "
|
||||
"companion to ``model_sha256``."
|
||||
),
|
||||
)
|
||||
tensorflow_version: str = Field(
|
||||
description="Value of ``tensorflow.__version__`` at the time of the run.",
|
||||
)
|
||||
tensorflow_metal_version: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Version of the ``tensorflow-metal`` PyPI package when installed; "
|
||||
"``None`` on platforms without Metal GPU acceleration."
|
||||
),
|
||||
)
|
||||
numpy_version: str = Field(
|
||||
description="Value of ``numpy.__version__`` at the time of the run.",
|
||||
)
|
||||
neuropose_version: str = Field(
|
||||
description="Value of ``neuropose.__version__`` at the time of the run.",
|
||||
)
|
||||
python_version: str = Field(
|
||||
description=(
|
||||
"Python version as ``MAJOR.MINOR.MICRO``, e.g. ``3.11.14``. The "
|
||||
"full ``sys.version`` string is intentionally not captured; the "
|
||||
"three-component form is stable across patch builds and avoids "
|
||||
"embedding compiler and build-date metadata."
|
||||
),
|
||||
)
|
||||
seed: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Random seed used for the run if one was set, else ``None``. "
|
||||
"MeTRAbs inference is deterministic on a given device up to "
|
||||
"floating-point associativity, so seeding mostly matters for "
|
||||
"downstream analysis that introduces randomness (bootstraps, "
|
||||
"learned metrics)."
|
||||
),
|
||||
)
|
||||
deterministic: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"``True`` if ``tf.config.experimental.enable_op_determinism()`` "
|
||||
"was active during the run. Track 2 deterministic-inference "
|
||||
"mode; the field exists in Phase 0 so payloads can record "
|
||||
"whether the run *was* deterministic without requiring a "
|
||||
"schema change when the toggle lands."
|
||||
),
|
||||
)
|
||||
analysis_config: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Parsed YAML dict if this payload was produced by ``neuropose "
|
||||
"analyze --config <file>``. ``None`` for direct-library or "
|
||||
"``neuropose watch`` invocations. Reserved for the Phase 0 "
|
||||
"YAML-configurable analysis pipeline."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class BenchmarkAggregate(BaseModel):
|
||||
"""Distributional statistics aggregated across benchmark passes.
|
||||
|
||||
|
|
@ -368,16 +255,6 @@ class BenchmarkResult(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
schema_version: int = Field(
|
||||
default=CURRENT_VERSION,
|
||||
ge=1,
|
||||
description=(
|
||||
"Schema version of this BenchmarkResult payload. Fresh writes "
|
||||
"stamp :data:`neuropose.migrations.CURRENT_VERSION`; older files "
|
||||
"are migrated on load via :mod:`neuropose.migrations` before "
|
||||
"pydantic validation."
|
||||
),
|
||||
)
|
||||
video_name: str = Field(
|
||||
description="Basename of the benchmarked video (no directory components).",
|
||||
)
|
||||
|
|
@ -403,14 +280,6 @@ class BenchmarkResult(BaseModel):
|
|||
)
|
||||
aggregate: BenchmarkAggregate
|
||||
cpu_comparison: CpuComparisonResult | None = None
|
||||
provenance: Provenance | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Reproducibility envelope from the benchmark run. ``None`` on "
|
||||
"tests where the model was injected directly via "
|
||||
"``Estimator(model=...)``."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class JointAxisExtractor(BaseModel):
|
||||
|
|
@ -600,30 +469,9 @@ class VideoPredictions(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
schema_version: int = Field(
|
||||
default=CURRENT_VERSION,
|
||||
ge=1,
|
||||
description=(
|
||||
"Schema version of this VideoPredictions payload. Fresh writes "
|
||||
"stamp :data:`neuropose.migrations.CURRENT_VERSION`; files written "
|
||||
"by older NeuroPose versions are migrated to the current version "
|
||||
"by :mod:`neuropose.migrations` before pydantic validation."
|
||||
),
|
||||
)
|
||||
metadata: VideoMetadata
|
||||
frames: dict[str, FramePrediction]
|
||||
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
|
||||
provenance: Provenance | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Reproducibility envelope populated by the estimator on runs "
|
||||
"where the MeTRAbs model was loaded via "
|
||||
":meth:`neuropose.estimator.Estimator.load_model`. ``None`` on "
|
||||
"test paths where the model was injected via "
|
||||
"``Estimator(model=...)``, because no model SHA is known in "
|
||||
"that case."
|
||||
),
|
||||
)
|
||||
|
||||
def frame_names(self) -> list[str]:
|
||||
"""Return frame identifiers in insertion order."""
|
||||
|
|
@ -775,16 +623,9 @@ class StatusFile(RootModel[dict[str, JobStatusEntry]]):
|
|||
|
||||
|
||||
def load_video_predictions(path: Path) -> VideoPredictions:
|
||||
"""Load and validate a per-video predictions JSON file.
|
||||
|
||||
Runs the payload through :func:`neuropose.migrations.migrate_video_predictions`
|
||||
before pydantic validation so files written by older NeuroPose versions
|
||||
upgrade to the current schema transparently.
|
||||
"""
|
||||
"""Load and validate a per-video predictions JSON file."""
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data: Any = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
data = migrate_video_predictions(data)
|
||||
return VideoPredictions.model_validate(data)
|
||||
|
||||
|
||||
|
|
@ -795,17 +636,9 @@ def save_video_predictions(path: Path, predictions: VideoPredictions) -> None:
|
|||
|
||||
|
||||
def load_job_results(path: Path) -> JobResults:
|
||||
"""Load and validate an aggregated per-job results JSON file.
|
||||
|
||||
Runs each video's payload through
|
||||
:func:`neuropose.migrations.migrate_video_predictions` before pydantic
|
||||
validation. :class:`JobResults` is a ``RootModel`` with no envelope of
|
||||
its own, so migration happens per-entry rather than at the top level.
|
||||
"""
|
||||
"""Load and validate an aggregated per-job results JSON file."""
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data: Any = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
data = migrate_job_results(data)
|
||||
return JobResults.model_validate(data)
|
||||
|
||||
|
||||
|
|
@ -816,16 +649,9 @@ def save_job_results(path: Path, results: JobResults) -> None:
|
|||
|
||||
|
||||
def load_benchmark_result(path: Path) -> BenchmarkResult:
|
||||
"""Load and validate a benchmark-result JSON file.
|
||||
|
||||
Runs the payload through :func:`neuropose.migrations.migrate_benchmark_result`
|
||||
before pydantic validation so files written by older NeuroPose versions
|
||||
upgrade transparently.
|
||||
"""
|
||||
"""Load and validate a benchmark-result JSON file."""
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data: Any = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
data = migrate_benchmark_result(data)
|
||||
return BenchmarkResult.model_validate(data)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,318 +0,0 @@
|
|||
"""Schema migration infrastructure for serialised NeuroPose payloads.
|
||||
|
||||
Every top-level JSON schema that NeuroPose persists to disk
|
||||
(:class:`~neuropose.io.VideoPredictions`,
|
||||
:class:`~neuropose.io.JobResults`, and
|
||||
:class:`~neuropose.io.BenchmarkResult`) carries a ``schema_version``
|
||||
integer. When those files are read back, the raw dict is passed
|
||||
through :func:`migrate_video_predictions` /
|
||||
:func:`migrate_job_results` / :func:`migrate_benchmark_result`
|
||||
*before* pydantic validation runs, so each schema version can be
|
||||
brought up to the current one transparently.
|
||||
|
||||
The pattern is deliberately small: one integer version counter shared
|
||||
across all top-level schemas, plus a per-schema registry of
|
||||
``{from_version: migration_fn}``. Each migration is a pure function
|
||||
``dict -> dict`` responsible for stamping the new ``schema_version``
|
||||
on its output. The framework chains them.
|
||||
|
||||
This module is intentionally separate from :mod:`neuropose.io` so
|
||||
that migration registrations cannot accidentally import the pydantic
|
||||
models they migrate — migrations must operate on raw dicts to be
|
||||
robust to schema drift (a field a migration references may not exist
|
||||
on the pydantic model by the time CURRENT_VERSION has moved past
|
||||
it).
|
||||
|
||||
Adding a new migration
|
||||
----------------------
|
||||
When a schema change lands:
|
||||
|
||||
1. Bump :data:`CURRENT_VERSION`.
|
||||
2. Register a migration from the *previous* version to the new one
|
||||
via :func:`register_video_predictions_migration` (or the sibling
|
||||
for benchmark results). The function receives the raw dict at the
|
||||
old version and must return a dict at the new version *including*
|
||||
the updated ``schema_version`` stamp.
|
||||
3. Update the pydantic model in :mod:`neuropose.io` to reflect the
|
||||
new field set.
|
||||
4. Add a unit test verifying that a fixture at the old version
|
||||
round-trips through ``load_*`` to the expected new-version shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CURRENT_VERSION = 2
|
||||
"""The current schema version for all NeuroPose-persisted JSON payloads.
|
||||
|
||||
Shared across :class:`~neuropose.io.VideoPredictions`,
|
||||
:class:`~neuropose.io.JobResults`, and
|
||||
:class:`~neuropose.io.BenchmarkResult` so that coordinated schema
|
||||
changes (for example, adding a ``provenance`` field to all three at
|
||||
once) bump a single counter rather than three parallel ones.
|
||||
|
||||
Version history
|
||||
---------------
|
||||
- **v1:** initial schema, pre-Phase-0.
|
||||
- **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions`
|
||||
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope).
|
||||
:class:`~neuropose.analyzer.pipeline.AnalysisReport` also enters the registry at v2
|
||||
(no legacy v1 payloads ever existed for it, so no migration is registered)."""
|
||||
|
||||
|
||||
class MigrationError(Exception):
|
||||
"""Base class for schema-migration failures."""
|
||||
|
||||
|
||||
class FutureSchemaError(MigrationError):
|
||||
"""Raised when a payload's ``schema_version`` exceeds :data:`CURRENT_VERSION`.
|
||||
|
||||
Produced when a newer NeuroPose version writes a file and an older
|
||||
version tries to read it. The fix is upgrading NeuroPose; silently
|
||||
stripping fields would corrupt the payload.
|
||||
"""
|
||||
|
||||
|
||||
class MigrationNotFoundError(MigrationError):
|
||||
"""Raised when no migration is registered for an intermediate version.
|
||||
|
||||
Indicates a bug — :data:`CURRENT_VERSION` was bumped past a version
|
||||
for which no migration function was registered. Should only surface
|
||||
in tests or on a corrupted install.
|
||||
"""
|
||||
|
||||
|
||||
# Per-schema registries. Keys are the *source* version of the migration;
|
||||
# the value is a callable that takes a dict at that version and returns a
|
||||
# dict at ``source + 1``.
|
||||
_VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||
_BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||
_ANALYSIS_REPORT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||
|
||||
|
||||
def register_video_predictions_migration(
|
||||
from_version: int,
|
||||
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
||||
"""Register a :class:`~neuropose.io.VideoPredictions` migration.
|
||||
|
||||
Usage::
|
||||
|
||||
@register_video_predictions_migration(from_version=1)
|
||||
def _v1_to_v2(payload: dict) -> dict:
|
||||
payload = dict(payload)
|
||||
payload["provenance"] = None
|
||||
payload["schema_version"] = 2
|
||||
return payload
|
||||
|
||||
The decorator registers the function into the per-schema migration
|
||||
registry and returns it unchanged, so it can still be called
|
||||
directly from tests.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
||||
if from_version in _VIDEO_PREDICTIONS_MIGRATIONS:
|
||||
raise RuntimeError(
|
||||
f"video-predictions migration already registered from version {from_version}"
|
||||
)
|
||||
_VIDEO_PREDICTIONS_MIGRATIONS[from_version] = fn
|
||||
return fn
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def register_benchmark_result_migration(
|
||||
from_version: int,
|
||||
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
||||
"""Register a :class:`~neuropose.io.BenchmarkResult` migration.
|
||||
|
||||
See :func:`register_video_predictions_migration` for usage — this
|
||||
is the same pattern for the benchmark-result registry.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
||||
if from_version in _BENCHMARK_RESULT_MIGRATIONS:
|
||||
raise RuntimeError(
|
||||
f"benchmark-result migration already registered from version {from_version}"
|
||||
)
|
||||
_BENCHMARK_RESULT_MIGRATIONS[from_version] = fn
|
||||
return fn
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def register_analysis_report_migration(
|
||||
from_version: int,
|
||||
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
||||
"""Register a :class:`~neuropose.analyzer.pipeline.AnalysisReport` migration.
|
||||
|
||||
See :func:`register_video_predictions_migration` for usage — this
|
||||
is the same pattern for the analysis-report registry. Unlike the
|
||||
other two schemas, :class:`AnalysisReport` first appeared at
|
||||
:data:`CURRENT_VERSION = 2`, so no ``from_version=1`` migration
|
||||
exists (and none is expected).
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
||||
if from_version in _ANALYSIS_REPORT_MIGRATIONS:
|
||||
raise RuntimeError(
|
||||
f"analysis-report migration already registered from version {from_version}"
|
||||
)
|
||||
_ANALYSIS_REPORT_MIGRATIONS[from_version] = fn
|
||||
return fn
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def migrate_video_predictions(payload: dict) -> dict:
|
||||
"""Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
payload
|
||||
Raw JSON-loaded dict. Must not yet have been through pydantic
|
||||
validation. A missing ``schema_version`` key is interpreted as
|
||||
version ``1`` (the earliest tracked version, shipped before
|
||||
the migration infrastructure existed).
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The payload at :data:`CURRENT_VERSION`. Ready to be passed to
|
||||
``VideoPredictions.model_validate``.
|
||||
|
||||
Raises
|
||||
------
|
||||
FutureSchemaError
|
||||
If the payload declares a ``schema_version`` higher than
|
||||
:data:`CURRENT_VERSION`.
|
||||
MigrationNotFoundError
|
||||
If an intermediate migration is missing from the registry.
|
||||
"""
|
||||
return _migrate(payload, _VIDEO_PREDICTIONS_MIGRATIONS, schema_name="VideoPredictions")
|
||||
|
||||
|
||||
def migrate_benchmark_result(payload: dict) -> dict:
|
||||
"""Migrate a raw :class:`~neuropose.io.BenchmarkResult` dict to current.
|
||||
|
||||
See :func:`migrate_video_predictions` for semantics. This is the
|
||||
sibling function for benchmark-result payloads.
|
||||
"""
|
||||
return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult")
|
||||
|
||||
|
||||
def migrate_analysis_report(payload: dict) -> dict:
|
||||
"""Migrate a raw :class:`~neuropose.analyzer.pipeline.AnalysisReport` dict.
|
||||
|
||||
See :func:`migrate_video_predictions` for semantics. Because
|
||||
:class:`AnalysisReport` first shipped at schema_version 2, a
|
||||
payload missing the key still defaults to 1 (and would require a
|
||||
not-yet-registered v1 → v2 migration); this is only reachable for
|
||||
deliberately malformed inputs.
|
||||
"""
|
||||
return _migrate(payload, _ANALYSIS_REPORT_MIGRATIONS, schema_name="AnalysisReport")
|
||||
|
||||
|
||||
def migrate_job_results(payload: dict) -> dict:
|
||||
"""Migrate a :class:`~neuropose.io.JobResults` root dict to current.
|
||||
|
||||
``JobResults`` is a ``RootModel`` whose root is a mapping of video
|
||||
name to :class:`~neuropose.io.VideoPredictions` payload. It has no
|
||||
envelope of its own, so the migration is "run
|
||||
:func:`migrate_video_predictions` on every value in the mapping."
|
||||
|
||||
Parameters
|
||||
----------
|
||||
payload
|
||||
Raw JSON-loaded dict of ``{video_name: VideoPredictions-shaped dict}``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The same mapping with each video payload migrated to the
|
||||
current schema version.
|
||||
"""
|
||||
return {name: migrate_video_predictions(video) for name, video in payload.items()}
|
||||
|
||||
|
||||
def _migrate(
|
||||
payload: dict,
|
||||
migrations: dict[int, Callable[[dict], dict]],
|
||||
*,
|
||||
schema_name: str,
|
||||
) -> dict:
|
||||
"""Walk the migration chain to :data:`CURRENT_VERSION` and return the migrated payload.
|
||||
|
||||
Shared driver for :func:`migrate_video_predictions` and
|
||||
:func:`migrate_benchmark_result`. Looks up the incoming
|
||||
``schema_version`` (defaulting to 1 when absent), walks the migration
|
||||
chain until reaching :data:`CURRENT_VERSION`, and returns the
|
||||
migrated payload. Logs at INFO each time it actually advances a
|
||||
version so operators see the upgrade happen.
|
||||
"""
|
||||
version = payload.get("schema_version", 1)
|
||||
if not isinstance(version, int) or version < 1:
|
||||
raise MigrationError(
|
||||
f"{schema_name} payload has invalid schema_version {version!r}; must be an integer >= 1"
|
||||
)
|
||||
if version > CURRENT_VERSION:
|
||||
raise FutureSchemaError(
|
||||
f"{schema_name} payload declares schema_version {version}, which is newer "
|
||||
f"than this build's CURRENT_VERSION ({CURRENT_VERSION}). Upgrade NeuroPose."
|
||||
)
|
||||
while version < CURRENT_VERSION:
|
||||
if version not in migrations:
|
||||
raise MigrationNotFoundError(
|
||||
f"no {schema_name} migration registered from schema_version {version}"
|
||||
)
|
||||
logger.info(
|
||||
"Migrating %s payload from schema_version %d to %d",
|
||||
schema_name,
|
||||
version,
|
||||
version + 1,
|
||||
)
|
||||
payload = migrations[version](payload)
|
||||
version += 1
|
||||
return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registered migrations
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Keep registrations *below* the driver so the module's public API surfaces
|
||||
# at the top and the version-specific diffs live together at the bottom where
|
||||
# they are easiest to audit chronologically.
|
||||
|
||||
|
||||
@register_video_predictions_migration(from_version=1)
|
||||
def _video_predictions_v1_to_v2(payload: dict) -> dict:
|
||||
"""v1 → v2: add the optional ``provenance`` field (Phase 0).
|
||||
|
||||
Phase 0 introduces the :class:`~neuropose.io.Provenance` envelope
|
||||
for Paper C reproducibility. v1 files predate it, so we stamp
|
||||
``provenance = None`` on load — the field is optional on the
|
||||
pydantic model and ``None`` correctly indicates "we don't have
|
||||
provenance metadata for this payload."
|
||||
"""
|
||||
payload = dict(payload)
|
||||
payload.setdefault("provenance", None)
|
||||
payload["schema_version"] = 2
|
||||
return payload
|
||||
|
||||
|
||||
@register_benchmark_result_migration(from_version=1)
|
||||
def _benchmark_result_v1_to_v2(payload: dict) -> dict:
|
||||
"""v1 → v2: add the optional ``provenance`` field (Phase 0).
|
||||
|
||||
Sibling of :func:`_video_predictions_v1_to_v2` for benchmark
|
||||
payloads; same rationale.
|
||||
"""
|
||||
payload = dict(payload)
|
||||
payload.setdefault("provenance", None)
|
||||
payload["schema_version"] = 2
|
||||
return payload
|
||||
|
|
@ -1,388 +0,0 @@
|
|||
"""Pipeline-wide reset utility.
|
||||
|
||||
Tear down a running NeuroPose deployment back to a clean state: stop
|
||||
the watch daemon and the monitor server, then wipe the job queue,
|
||||
results, status file, and ingest staging directories. Intended for
|
||||
the rapid iteration loop that comes up during benchmark and validation
|
||||
work, where you want an empty ``$data_dir/in`` without manually
|
||||
``rm -rf``-ing five separate paths and pkill'ing two processes.
|
||||
|
||||
The module is split into three independently-callable layers so each
|
||||
piece is testable in isolation and reusable from non-CLI contexts:
|
||||
|
||||
- :func:`find_neuropose_processes` enumerates running ``neuropose
|
||||
watch`` and ``neuropose serve`` processes by scanning the OS process
|
||||
table. Pure read; no side effects.
|
||||
- :func:`terminate_processes` signals the discovered processes (SIGINT
|
||||
first, optionally SIGKILL after a grace period) and reports
|
||||
survivors.
|
||||
- :func:`wipe_state` removes the data-directory paths that the daemon
|
||||
and monitor produce. Idempotent; safe against a fresh install.
|
||||
|
||||
The top-level :func:`reset_pipeline` orchestrates all three and
|
||||
returns a :class:`ResetReport` summarizing what happened.
|
||||
|
||||
Safety
|
||||
------
|
||||
:func:`reset_pipeline` refuses to wipe state while *any* discovered
|
||||
process is still alive after the termination phase. Wiping
|
||||
``$data_dir`` out from under an active daemon would leave the daemon
|
||||
writing into deleted directory entries — a guaranteed mess. Callers
|
||||
that hit a survivor must either raise the grace period, opt into
|
||||
``force_kill=True`` (SIGKILL), or kill the survivor manually before
|
||||
re-running.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import psutil
|
||||
|
||||
from neuropose.config import Settings
|
||||
from neuropose.interfacer import LOCK_FILENAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Cmdline substrings that identify the daemon and monitor.
|
||||
# Matched against the joined argv so ``uv run neuropose watch -v`` and
|
||||
# ``python -m neuropose watch`` both hit. The substring choice
|
||||
# deliberately includes the subcommand name so a generic
|
||||
# ``neuropose --help`` shell doesn't get caught.
|
||||
_DAEMON_MARKER = "neuropose watch"
|
||||
_MONITOR_MARKER = "neuropose serve"
|
||||
|
||||
DEFAULT_GRACE_SECONDS = 10.0
|
||||
"""How long :func:`terminate_processes` waits after SIGINT before
|
||||
declaring a process a survivor (or escalating to SIGKILL when
|
||||
``force_kill`` is set). Long enough for an idle daemon to finish its
|
||||
current poll iteration; short enough that an interactive ``reset``
|
||||
invocation doesn't feel hung. Override per-call when waiting on a
|
||||
multi-minute inference."""
|
||||
|
||||
_POLL_INTERVAL_SECONDS = 0.2
|
||||
|
||||
ProcessRole = Literal["daemon", "monitor"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunningProcess:
|
||||
"""A neuropose process discovered in the OS process table."""
|
||||
|
||||
pid: int
|
||||
role: ProcessRole
|
||||
cmdline: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TerminationReport:
|
||||
"""Outcome of trying to stop a set of running processes."""
|
||||
|
||||
stopped: list[RunningProcess] = field(default_factory=list)
|
||||
survivors: list[RunningProcess] = field(default_factory=list)
|
||||
force_killed: list[RunningProcess] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WipeReport:
|
||||
"""Outcome of wiping data-directory state."""
|
||||
|
||||
removed_paths: list[Path] = field(default_factory=list)
|
||||
bytes_freed: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetReport:
|
||||
"""Aggregate report from a full pipeline reset."""
|
||||
|
||||
discovered: list[RunningProcess]
|
||||
termination: TerminationReport
|
||||
wipe: WipeReport
|
||||
dry_run: bool
|
||||
wipe_skipped_due_to_survivors: bool = False
|
||||
|
||||
|
||||
def find_neuropose_processes(*, exclude_self: bool = True) -> list[RunningProcess]:
|
||||
"""Scan the process table for ``neuropose watch`` / ``neuropose serve``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
exclude_self
|
||||
Skip the current process. The ``neuropose reset`` command
|
||||
itself has ``"neuropose"`` in its argv and would otherwise see
|
||||
itself in the result. Set to ``False`` only in tests where the
|
||||
caller has constructed a process table that should match
|
||||
verbatim.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[RunningProcess]
|
||||
Processes whose joined argv contains either marker substring.
|
||||
Empty list when nothing matches.
|
||||
"""
|
||||
self_pid = os.getpid()
|
||||
found: list[RunningProcess] = []
|
||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||
try:
|
||||
pid = proc.info["pid"]
|
||||
cmdline = proc.info["cmdline"] or []
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
if exclude_self and pid == self_pid:
|
||||
continue
|
||||
joined = " ".join(cmdline)
|
||||
# Daemon check first because "neuropose serve" and
|
||||
# "neuropose watch" cannot both appear in a single process.
|
||||
if _DAEMON_MARKER in joined:
|
||||
found.append(RunningProcess(pid=pid, role="daemon", cmdline=joined))
|
||||
elif _MONITOR_MARKER in joined:
|
||||
found.append(RunningProcess(pid=pid, role="monitor", cmdline=joined))
|
||||
return found
|
||||
|
||||
|
||||
def terminate_processes(
|
||||
processes: list[RunningProcess],
|
||||
*,
|
||||
grace_seconds: float = DEFAULT_GRACE_SECONDS,
|
||||
force_kill: bool = False,
|
||||
) -> TerminationReport:
|
||||
"""Send SIGINT to each process; optionally escalate to SIGKILL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
processes
|
||||
The processes to stop. Pass an empty list for a no-op.
|
||||
grace_seconds
|
||||
Maximum time to wait for processes to exit after SIGINT.
|
||||
force_kill
|
||||
When ``True``, any process still alive after ``grace_seconds``
|
||||
is sent SIGKILL. When ``False``, survivors are reported back
|
||||
to the caller untouched.
|
||||
|
||||
Notes
|
||||
-----
|
||||
SIGTERM is *not* used as an intermediate escalation step. The
|
||||
interfacer's signal handler treats SIGINT and SIGTERM identically
|
||||
(both call :meth:`Interfacer.stop`), so SIGTERM accomplishes
|
||||
nothing that SIGINT did not already attempt. The only escalation
|
||||
that actually forces a stuck daemon to exit is SIGKILL, which
|
||||
bypasses the handler entirely.
|
||||
"""
|
||||
report = TerminationReport()
|
||||
if not processes:
|
||||
return report
|
||||
|
||||
for rp in processes:
|
||||
with contextlib.suppress(ProcessLookupError, PermissionError):
|
||||
os.kill(rp.pid, signal.SIGINT)
|
||||
logger.info("sent SIGINT to pid %d (%s)", rp.pid, rp.role)
|
||||
|
||||
survivors = _wait_for_exit(processes, grace_seconds)
|
||||
stopped = [p for p in processes if p not in survivors]
|
||||
report.stopped.extend(stopped)
|
||||
|
||||
if not survivors:
|
||||
return report
|
||||
|
||||
if not force_kill:
|
||||
report.survivors.extend(survivors)
|
||||
return report
|
||||
|
||||
for rp in survivors:
|
||||
with contextlib.suppress(ProcessLookupError, PermissionError):
|
||||
os.kill(rp.pid, signal.SIGKILL)
|
||||
logger.warning("escalated to SIGKILL for pid %d (%s)", rp.pid, rp.role)
|
||||
|
||||
# SIGKILL is delivered synchronously enough that a short final
|
||||
# poll is sufficient — any remaining "survivor" at this point is
|
||||
# a permission error or a kernel-side hang, not graceful shutdown.
|
||||
final_survivors = _wait_for_exit(survivors, grace_seconds=2.0)
|
||||
killed = [p for p in survivors if p not in final_survivors]
|
||||
report.force_killed.extend(killed)
|
||||
report.survivors.extend(final_survivors)
|
||||
return report
|
||||
|
||||
|
||||
def _wait_for_exit(
|
||||
processes: list[RunningProcess],
|
||||
grace_seconds: float,
|
||||
) -> list[RunningProcess]:
|
||||
"""Poll until every process exits or the deadline passes."""
|
||||
deadline = time.monotonic() + grace_seconds
|
||||
while time.monotonic() < deadline:
|
||||
survivors = [p for p in processes if _is_alive(p.pid)]
|
||||
if not survivors:
|
||||
return []
|
||||
time.sleep(_POLL_INTERVAL_SECONDS)
|
||||
return [p for p in processes if _is_alive(p.pid)]
|
||||
|
||||
|
||||
def _is_alive(pid: int) -> bool:
|
||||
"""Return ``True`` if ``pid`` is still running and not a zombie."""
|
||||
try:
|
||||
proc = psutil.Process(pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return False
|
||||
try:
|
||||
return proc.status() != psutil.STATUS_ZOMBIE
|
||||
except psutil.NoSuchProcess:
|
||||
return False
|
||||
|
||||
|
||||
def wipe_state(
|
||||
settings: Settings,
|
||||
*,
|
||||
keep_failed: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> WipeReport:
|
||||
"""Remove data-directory paths produced by the daemon and monitor.
|
||||
|
||||
Removes the *contents* of ``in/``, ``out/``, and (unless
|
||||
``keep_failed`` is set) ``failed/``, plus the daemon lock file and
|
||||
any leftover ``.ingest_<uuid>/`` staging directories. The
|
||||
container directories themselves are preserved so the daemon does
|
||||
not need to recreate them on next startup.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
settings
|
||||
Resolved :class:`~neuropose.config.Settings`. Determines all
|
||||
target paths via the ``input_dir`` / ``output_dir`` /
|
||||
``failed_dir`` properties.
|
||||
keep_failed
|
||||
Preserve ``$data_dir/failed/`` for forensic review of past
|
||||
crashes. Default removes it along with the rest of the
|
||||
pipeline state.
|
||||
dry_run
|
||||
Compute the report without actually deleting anything. Useful
|
||||
for previewing the blast radius before confirming a reset.
|
||||
"""
|
||||
report = WipeReport()
|
||||
|
||||
targets: list[Path] = []
|
||||
if settings.input_dir.exists():
|
||||
targets.extend(settings.input_dir.iterdir())
|
||||
if settings.output_dir.exists():
|
||||
targets.extend(settings.output_dir.iterdir())
|
||||
if not keep_failed and settings.failed_dir.exists():
|
||||
targets.extend(settings.failed_dir.iterdir())
|
||||
|
||||
lock_path = settings.data_dir / LOCK_FILENAME
|
||||
if lock_path.exists():
|
||||
targets.append(lock_path)
|
||||
|
||||
if settings.data_dir.exists():
|
||||
targets.extend(settings.data_dir.glob(".ingest_*"))
|
||||
|
||||
for target in targets:
|
||||
size = _path_size(target)
|
||||
if not dry_run:
|
||||
_remove(target)
|
||||
report.removed_paths.append(target)
|
||||
report.bytes_freed += size
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def _path_size(path: Path) -> int:
|
||||
"""Return the cumulative size of ``path``, recursing into directories."""
|
||||
if path.is_symlink() or path.is_file():
|
||||
try:
|
||||
return path.stat().st_size
|
||||
except OSError:
|
||||
return 0
|
||||
total = 0
|
||||
for sub in path.rglob("*"):
|
||||
try:
|
||||
if sub.is_file() and not sub.is_symlink():
|
||||
total += sub.stat().st_size
|
||||
except OSError:
|
||||
continue
|
||||
return total
|
||||
|
||||
|
||||
def _remove(path: Path) -> None:
|
||||
"""Remove ``path`` whether file, symlink, or directory."""
|
||||
if path.is_dir() and not path.is_symlink():
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
|
||||
|
||||
def reset_pipeline(
|
||||
settings: Settings,
|
||||
*,
|
||||
grace_seconds: float = DEFAULT_GRACE_SECONDS,
|
||||
force_kill: bool = False,
|
||||
keep_failed: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> ResetReport:
|
||||
"""Stop daemon + monitor, then wipe pipeline state.
|
||||
|
||||
Composes :func:`find_neuropose_processes`,
|
||||
:func:`terminate_processes`, and :func:`wipe_state` into a single
|
||||
operation. The wipe phase is *skipped* if any process survives
|
||||
termination — wiping ``$data_dir`` out from under an active
|
||||
daemon would corrupt its in-flight writes. The returned
|
||||
:class:`ResetReport` flags this case via
|
||||
``wipe_skipped_due_to_survivors``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
settings
|
||||
Resolved :class:`~neuropose.config.Settings`.
|
||||
grace_seconds
|
||||
Maximum time to wait for SIGINT to take effect.
|
||||
force_kill
|
||||
Escalate to SIGKILL on any process still alive after
|
||||
``grace_seconds``. Necessary when the daemon is mid-inference
|
||||
on a long video and you do not want to wait for the current
|
||||
video to finish.
|
||||
keep_failed
|
||||
Preserve ``$data_dir/failed/`` during the wipe.
|
||||
dry_run
|
||||
Discover and report without killing anything or removing any
|
||||
paths. Termination phase is skipped entirely.
|
||||
"""
|
||||
discovered = find_neuropose_processes()
|
||||
|
||||
if dry_run:
|
||||
wipe = wipe_state(settings, keep_failed=keep_failed, dry_run=True)
|
||||
return ResetReport(
|
||||
discovered=discovered,
|
||||
termination=TerminationReport(),
|
||||
wipe=wipe,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
termination = terminate_processes(
|
||||
discovered,
|
||||
grace_seconds=grace_seconds,
|
||||
force_kill=force_kill,
|
||||
)
|
||||
|
||||
if termination.survivors:
|
||||
return ResetReport(
|
||||
discovered=discovered,
|
||||
termination=termination,
|
||||
wipe=WipeReport(),
|
||||
dry_run=False,
|
||||
wipe_skipped_due_to_survivors=True,
|
||||
)
|
||||
|
||||
wipe = wipe_state(settings, keep_failed=keep_failed, dry_run=False)
|
||||
return ResetReport(
|
||||
discovered=discovered,
|
||||
termination=termination,
|
||||
wipe=wipe,
|
||||
dry_run=False,
|
||||
)
|
||||
|
|
@ -1,205 +0,0 @@
|
|||
"""Example-config sanity integration tests.
|
||||
|
||||
Every YAML config under ``examples/analysis/`` is exercised here in
|
||||
two ways:
|
||||
|
||||
1. ``test_<name>_parses`` — :func:`load_config` accepts the YAML,
|
||||
i.e. the file matches the current :class:`AnalysisConfig` schema
|
||||
including cross-field invariants. Catches silent drift between the
|
||||
example configs and the schema they claim to exercise.
|
||||
2. ``test_<name>_runs`` — the example's pipeline runs end-to-end
|
||||
against synthetic predictions. Paths in the YAML are overwritten
|
||||
with test fixtures before :func:`run_analysis` is invoked;
|
||||
everything else (stages, thresholds, triplets) is used verbatim.
|
||||
|
||||
These tests are deliberately not marked ``slow`` — they use synthetic
|
||||
fixtures and do not touch the MeTRAbs SavedModel, so they run in the
|
||||
default unit-test suite.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from neuropose.analyzer.pipeline import (
|
||||
AnalysisConfig,
|
||||
AnalysisReport,
|
||||
DtwResults,
|
||||
load_config,
|
||||
run_analysis,
|
||||
)
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
from neuropose.io import VideoPredictions, save_video_predictions
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).resolve().parents[2] / "examples" / "analysis"
|
||||
|
||||
NUM_JOINTS = 43
|
||||
|
||||
|
||||
def _sinusoid(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
|
||||
total = num_cycles * frames_per_cycle
|
||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
||||
return (np.sin(t) * amplitude + amplitude).astype(float)
|
||||
|
||||
|
||||
def _write_trial(
|
||||
path: Path,
|
||||
*,
|
||||
num_cycles: int = 4,
|
||||
frames_per_cycle: int = 30,
|
||||
seed: int = 0,
|
||||
) -> Path:
|
||||
"""Write a synthetic VideoPredictions with every joint oscillating.
|
||||
|
||||
Joint ``0`` gets a reproducible RNG-driven trace so Procrustes has
|
||||
something non-degenerate to align. All other joints get their own
|
||||
phase-shifted sinusoid so joint-angle triplets and per-joint DTW
|
||||
have signal to act on.
|
||||
"""
|
||||
rng = np.random.default_rng(seed)
|
||||
base = _sinusoid(num_cycles, frames_per_cycle)
|
||||
total = base.shape[0]
|
||||
frames: dict[str, dict] = {}
|
||||
for frame_idx in range(total):
|
||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
||||
for j in range(NUM_JOINTS):
|
||||
# Unique per-joint position so no triplet is degenerate.
|
||||
phase = rng.uniform(0.0, 2.0 * math.pi)
|
||||
amplitude = 30.0 + 10.0 * (j % 5)
|
||||
offset = float(j) * 15.0
|
||||
poses[0][j][0] = offset + amplitude * math.cos(
|
||||
2.0 * math.pi * frame_idx / frames_per_cycle + phase
|
||||
)
|
||||
poses[0][j][1] = offset * 0.5 + base[frame_idx] + 5.0 * j
|
||||
poses[0][j][2] = 3.0 * j
|
||||
frames[f"frame_{frame_idx:06d}"] = {
|
||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||
"poses3d": poses,
|
||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
||||
}
|
||||
preds = VideoPredictions.model_validate(
|
||||
{
|
||||
"metadata": {
|
||||
"frame_count": total,
|
||||
"fps": float(frames_per_cycle),
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
},
|
||||
"frames": frames,
|
||||
}
|
||||
)
|
||||
save_video_predictions(path, preds)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_fixtures(tmp_path: Path) -> tuple[Path, Path, Path]:
|
||||
"""Return (primary_path, reference_path, report_path) under ``tmp_path``."""
|
||||
primary = _write_trial(tmp_path / "primary.json", seed=1)
|
||||
reference = _write_trial(tmp_path / "reference.json", seed=2)
|
||||
report = tmp_path / "report.json"
|
||||
return primary, reference, report
|
||||
|
||||
|
||||
def _rewrite_paths(
|
||||
example_path: Path, primary: Path, reference: Path, report: Path
|
||||
) -> AnalysisConfig:
|
||||
"""Load an example YAML and rewrite inputs/output paths to fixtures.
|
||||
|
||||
Tests run against synthetic predictions in ``tmp_path``; the
|
||||
example YAML's hardcoded ``data/*.json`` paths would never resolve
|
||||
otherwise.
|
||||
"""
|
||||
config = load_config(example_path)
|
||||
update: dict = {
|
||||
"inputs": config.inputs.model_copy(update={"primary": primary, "reference": reference}),
|
||||
"output": config.output.model_copy(update={"report": report}),
|
||||
}
|
||||
return config.model_copy(update=update)
|
||||
|
||||
|
||||
class TestMinimalExample:
|
||||
def test_minimal_parses(self) -> None:
|
||||
config = load_config(EXAMPLES_DIR / "minimal.yaml")
|
||||
assert isinstance(config, AnalysisConfig)
|
||||
assert config.analysis.kind == "dtw"
|
||||
assert config.segmentation is None
|
||||
|
||||
def test_minimal_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
|
||||
primary, reference, report = example_fixtures
|
||||
config = _rewrite_paths(EXAMPLES_DIR / "minimal.yaml", primary, reference, report)
|
||||
result = run_analysis(config)
|
||||
assert isinstance(result, AnalysisReport)
|
||||
assert isinstance(result.results, DtwResults)
|
||||
# Unsegmented → one distance.
|
||||
assert result.results.segment_labels == ["full_trial"]
|
||||
assert len(result.results.distances) == 1
|
||||
|
||||
|
||||
class TestPaperCExample:
|
||||
def test_paper_c_parses(self) -> None:
|
||||
config = load_config(EXAMPLES_DIR / "paper_c_headline.yaml")
|
||||
assert config.analysis.kind == "dtw"
|
||||
assert config.segmentation is not None
|
||||
assert config.segmentation.kind == "gait_cycles_bilateral"
|
||||
# Joint triplets must be in range for berkeley_mhad_43.
|
||||
assert config.analysis.kind == "dtw"
|
||||
angle_triplets = config.analysis.angle_triplets # type: ignore[union-attr]
|
||||
assert angle_triplets is not None
|
||||
for a, b, c in angle_triplets:
|
||||
for idx in (a, b, c):
|
||||
assert 0 <= idx < NUM_JOINTS
|
||||
|
||||
def test_paper_c_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
|
||||
primary, reference, report = example_fixtures
|
||||
config = _rewrite_paths(EXAMPLES_DIR / "paper_c_headline.yaml", primary, reference, report)
|
||||
result = run_analysis(config)
|
||||
assert isinstance(result.results, DtwResults)
|
||||
# Bilateral segmentation → distances labelled per side.
|
||||
assert any(lbl.startswith("left_heel_strikes") for lbl in result.results.segment_labels)
|
||||
assert any(lbl.startswith("right_heel_strikes") for lbl in result.results.segment_labels)
|
||||
|
||||
def test_paper_c_uses_documented_knee_triplets(self) -> None:
|
||||
"""The Paper C config must target knee-flexion joint triplets.
|
||||
|
||||
Safety net: if someone edits the YAML and breaks the joint
|
||||
references, this test catches it before the example silently
|
||||
starts computing the wrong angles.
|
||||
"""
|
||||
config = load_config(EXAMPLES_DIR / "paper_c_headline.yaml")
|
||||
assert config.analysis.kind == "dtw"
|
||||
triplets = config.analysis.angle_triplets # type: ignore[union-attr]
|
||||
assert triplets is not None
|
||||
# Left knee flex = hip → knee → ankle.
|
||||
assert (
|
||||
JOINT_INDEX["lhipb"],
|
||||
JOINT_INDEX["lkne"],
|
||||
JOINT_INDEX["lank"],
|
||||
) in triplets or (
|
||||
JOINT_INDEX["lhipf"],
|
||||
JOINT_INDEX["lkne"],
|
||||
JOINT_INDEX["lank"],
|
||||
) in triplets
|
||||
|
||||
|
||||
class TestPerJointDebugExample:
|
||||
def test_per_joint_debug_parses(self) -> None:
|
||||
config = load_config(EXAMPLES_DIR / "per_joint_debug.yaml")
|
||||
assert config.analysis.kind == "dtw"
|
||||
assert config.analysis.method == "dtw_per_joint" # type: ignore[union-attr]
|
||||
|
||||
def test_per_joint_debug_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
|
||||
primary, reference, report = example_fixtures
|
||||
config = _rewrite_paths(EXAMPLES_DIR / "per_joint_debug.yaml", primary, reference, report)
|
||||
result = run_analysis(config)
|
||||
assert isinstance(result.results, DtwResults)
|
||||
# dtw_per_joint → per_joint_distances populated.
|
||||
assert result.results.per_joint_distances is not None
|
||||
# Inner length must match num_joints for the coords
|
||||
# representation.
|
||||
for per_seg in result.results.per_joint_distances:
|
||||
assert len(per_seg) == NUM_JOINTS
|
||||
|
|
@ -81,29 +81,26 @@ class TestMetrabsLoader:
|
|||
"""Exercises the loader's download → verify → extract → load path."""
|
||||
|
||||
def test_download_and_load(self, shared_model_cache_dir: Path) -> None:
|
||||
loaded = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
assert loaded.model is not None
|
||||
assert loaded.sha256
|
||||
assert loaded.filename
|
||||
model = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
assert model is not None
|
||||
for attr in ("detect_poses", "per_skeleton_joint_names", "per_skeleton_joint_edges"):
|
||||
assert hasattr(loaded.model, attr), f"loaded model is missing {attr}"
|
||||
assert hasattr(model, attr), f"loaded model is missing {attr}"
|
||||
|
||||
def test_second_call_uses_cache(self, shared_model_cache_dir: Path) -> None:
|
||||
"""Idempotent: second call should return the cached model cheaply."""
|
||||
loaded_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
loaded_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
model_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
model_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
# tf.saved_model.load returns a new Python object each call, so
|
||||
# identity comparison doesn't work — but both should still
|
||||
# expose the MeTRAbs interface, and the SHA should match.
|
||||
assert hasattr(loaded_a.model, "detect_poses")
|
||||
assert hasattr(loaded_b.model, "detect_poses")
|
||||
assert loaded_a.sha256 == loaded_b.sha256
|
||||
# expose the MeTRAbs interface.
|
||||
assert hasattr(model_a, "detect_poses")
|
||||
assert hasattr(model_b, "detect_poses")
|
||||
|
||||
def test_berkeley_mhad_skeleton_is_present(self, shared_model_cache_dir: Path) -> None:
|
||||
"""The estimator pins skeleton='berkeley_mhad_43'; verify it exists."""
|
||||
loaded = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
joint_names = loaded.model.per_skeleton_joint_names["berkeley_mhad_43"]
|
||||
joint_edges = loaded.model.per_skeleton_joint_edges["berkeley_mhad_43"]
|
||||
model = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||
joint_names = model.per_skeleton_joint_names["berkeley_mhad_43"]
|
||||
joint_edges = model.per_skeleton_joint_edges["berkeley_mhad_43"]
|
||||
# MeTRAbs exposes these as tf.Tensor objects; just verify we
|
||||
# can pull a shape out.
|
||||
assert joint_names.shape[0] == 43
|
||||
|
|
|
|||
|
|
@ -50,8 +50,8 @@ def test_joint_names_match_pinned_model(metrabs_model_cache_dir: Path) -> None:
|
|||
commit that bumps the model pin in :mod:`neuropose._model`.
|
||||
2. Cross-check any CLI or docs that embed hardcoded joint names.
|
||||
"""
|
||||
loaded = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
|
||||
tensor = loaded.model.per_skeleton_joint_names["berkeley_mhad_43"]
|
||||
model = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
|
||||
tensor = model.per_skeleton_joint_names["berkeley_mhad_43"]
|
||||
model_names = tuple(tensor.numpy().astype(str).tolist())
|
||||
assert model_names == JOINT_NAMES, (
|
||||
"JOINT_NAMES drift detected — the hardcoded tuple in "
|
||||
|
|
|
|||
|
|
@ -131,250 +131,3 @@ class TestDtwRelation:
|
|||
b = np.zeros((3, 2, 3))
|
||||
with pytest.raises(ValueError, match="joint count"):
|
||||
dtw_relation(a, b, joint_i=0, joint_j=1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# representation="angles"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
|
||||
c, s = np.cos(angle_rad), np.sin(angle_rad)
|
||||
return np.array(
|
||||
[
|
||||
[c, -s, 0.0],
|
||||
[s, c, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
|
||||
"""A three-joint arm opening from a right angle to straight.
|
||||
|
||||
Joints laid out as [shoulder, elbow, wrist], forming an angle at
|
||||
the elbow that linearly opens from pi/2 to pi across ``num_frames``.
|
||||
"""
|
||||
sequence = np.zeros((num_frames, 3, 3))
|
||||
angles = np.linspace(np.pi / 2, np.pi, num_frames)
|
||||
for i, theta in enumerate(angles):
|
||||
sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder
|
||||
sequence[i, 1] = [0.0, 0.0, 0.0] # elbow
|
||||
sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist
|
||||
return sequence
|
||||
|
||||
|
||||
class TestDtwAllAngles:
|
||||
def test_angles_identical_sequences_distance_zero(self) -> None:
|
||||
seq = _three_joint_arm()
|
||||
result = dtw_all(
|
||||
seq,
|
||||
seq,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
def test_angles_invariant_to_global_rotation(self) -> None:
|
||||
"""Angle-space DTW must not change under a global rotation."""
|
||||
seq = _three_joint_arm()
|
||||
rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T
|
||||
baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)])
|
||||
under_rotation = dtw_all(
|
||||
seq,
|
||||
rotated,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6)
|
||||
|
||||
def test_angles_translation_invariant(self) -> None:
|
||||
seq = _three_joint_arm()
|
||||
translated = seq + np.array([10.0, -5.0, 2.0])
|
||||
result = dtw_all(
|
||||
seq,
|
||||
translated,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
def test_angles_detects_different_motion(self) -> None:
|
||||
# A sequence whose angle is constant vs. one that opens.
|
||||
constant = np.zeros((6, 3, 3))
|
||||
constant[:, 0] = [-1.0, 0.0, 0.0]
|
||||
constant[:, 1] = [0.0, 0.0, 0.0]
|
||||
constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout
|
||||
opening = _three_joint_arm()
|
||||
result = dtw_all(
|
||||
constant,
|
||||
opening,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
assert result.distance > 0.0
|
||||
|
||||
def test_angles_without_triplets_rejected(self) -> None:
|
||||
seq = _three_joint_arm()
|
||||
with pytest.raises(ValueError, match="angle_triplets"):
|
||||
dtw_all(seq, seq, representation="angles")
|
||||
|
||||
|
||||
class TestDtwPerJointAngles:
|
||||
def test_returns_one_result_per_triplet(self) -> None:
|
||||
seq = _three_joint_arm()
|
||||
triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose
|
||||
results = dtw_per_joint(
|
||||
seq,
|
||||
seq,
|
||||
representation="angles",
|
||||
angle_triplets=triplets,
|
||||
)
|
||||
assert len(results) == 2
|
||||
for result in results:
|
||||
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
def test_per_triplet_distinct_paths(self) -> None:
|
||||
# Two triplets covering different angles; with different motion
|
||||
# per triplet, the per-unit results should differ.
|
||||
seq_a = np.zeros((5, 4, 3))
|
||||
seq_b = np.zeros((5, 4, 3))
|
||||
# joint 0: pivot, joint 1/2/3: arm endpoints
|
||||
for i in range(5):
|
||||
seq_a[i, 0] = [0.0, 0.0, 0.0]
|
||||
seq_a[i, 1] = [1.0, 0.0, 0.0]
|
||||
seq_a[i, 2] = [0.0, 1.0, 0.0]
|
||||
seq_a[i, 3] = [0.0, 0.0, 1.0]
|
||||
seq_b[i, 0] = [0.0, 0.0, 0.0]
|
||||
seq_b[i, 1] = [1.0, 0.0, 0.0]
|
||||
seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating
|
||||
seq_b[i, 3] = [0.0, 0.0, 1.0]
|
||||
results = dtw_per_joint(
|
||||
seq_a,
|
||||
seq_b,
|
||||
representation="angles",
|
||||
angle_triplets=[(1, 0, 2), (1, 0, 3)],
|
||||
)
|
||||
assert len(results) == 2
|
||||
# First triplet tracks the rotation, second is stationary.
|
||||
assert results[0].distance > 0.0
|
||||
assert results[1].distance == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# nan_policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
|
||||
"""Three collinear joints — the angle at the middle joint is degenerate."""
|
||||
seq = np.zeros((num_frames, 3, 3))
|
||||
seq[:, 0] = [-1.0, 0.0, 0.0]
|
||||
# Middle joint at (0,0,0); but because the outer joints are collinear
|
||||
# through the origin, we need one joint overlapping with the middle
|
||||
# to force a zero-length vector. Place joint 2 AT joint 1 to trigger
|
||||
# the degenerate case in extract_joint_angles.
|
||||
seq[:, 1] = [0.0, 0.0, 0.0]
|
||||
seq[:, 2] = [0.0, 0.0, 0.0]
|
||||
return seq
|
||||
|
||||
|
||||
class TestNanPolicy:
|
||||
def test_propagate_surfaces_error(self) -> None:
|
||||
# Degenerate triplet produces NaN angles for every frame.
|
||||
# With nan_policy="propagate" the NaN reaches fastdtw, which
|
||||
# validates via numpy.asarray_chkfinite and raises ValueError —
|
||||
# the intended behaviour ("make the problem visible").
|
||||
seq = _collinear_sequence(num_frames=4)
|
||||
other = _three_joint_arm(num_frames=4)
|
||||
with pytest.raises(ValueError, match="infs or NaNs"):
|
||||
dtw_all(
|
||||
seq,
|
||||
other,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
nan_policy="propagate",
|
||||
)
|
||||
|
||||
def test_interpolate_fills_isolated_nan(self) -> None:
|
||||
# One bad frame in a 5-frame sequence — the other four are
|
||||
# finite anchors to interpolate between.
|
||||
good = _three_joint_arm(num_frames=5)
|
||||
# Inject a degenerate middle frame.
|
||||
good[2, 2] = good[2, 1] # force zero-length vector → NaN angle
|
||||
# Reference is the same arm without injection.
|
||||
reference = _three_joint_arm(num_frames=5)
|
||||
result = dtw_all(
|
||||
good,
|
||||
reference,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
nan_policy="interpolate",
|
||||
)
|
||||
assert not np.isnan(result.distance)
|
||||
|
||||
def test_interpolate_all_nan_column_rejected(self) -> None:
|
||||
seq = _collinear_sequence(num_frames=5)
|
||||
other = _three_joint_arm(num_frames=5)
|
||||
with pytest.raises(ValueError, match="all values are NaN"):
|
||||
dtw_all(
|
||||
seq,
|
||||
other,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
nan_policy="interpolate",
|
||||
)
|
||||
|
||||
def test_drop_removes_nan_frames(self) -> None:
|
||||
good = _three_joint_arm(num_frames=6)
|
||||
good[2, 2] = good[2, 1] # inject NaN at frame 2
|
||||
good[4, 2] = good[4, 1] # inject NaN at frame 4
|
||||
reference = _three_joint_arm(num_frames=6)
|
||||
result = dtw_all(
|
||||
good,
|
||||
reference,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
nan_policy="drop",
|
||||
)
|
||||
# The 4 remaining finite frames should align cleanly with
|
||||
# their counterparts in the reference.
|
||||
assert not np.isnan(result.distance)
|
||||
|
||||
def test_drop_empties_sequence_rejected(self) -> None:
|
||||
seq = _collinear_sequence(num_frames=5)
|
||||
other = _three_joint_arm(num_frames=5)
|
||||
with pytest.raises(ValueError, match="every frame"):
|
||||
dtw_all(
|
||||
seq,
|
||||
other,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
nan_policy="drop",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# align + representation composition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAlignWithAngles:
|
||||
def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
|
||||
"""Procrustes on angle-space DTW should be redundant but safe."""
|
||||
seq = _three_joint_arm()
|
||||
rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T
|
||||
with_align = dtw_all(
|
||||
seq,
|
||||
rotated,
|
||||
align="procrustes_per_sequence",
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
without_align = dtw_all(
|
||||
seq,
|
||||
rotated,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from neuropose.analyzer.features import (
|
||||
AlignmentDiagnostics,
|
||||
FeatureStatistics,
|
||||
extract_feature_statistics,
|
||||
extract_joint_angles,
|
||||
|
|
@ -16,7 +15,6 @@ from neuropose.analyzer.features import (
|
|||
normalize_pose_sequence,
|
||||
pad_sequences,
|
||||
predictions_to_numpy,
|
||||
procrustes_align,
|
||||
)
|
||||
from neuropose.io import VideoPredictions
|
||||
|
||||
|
|
@ -299,177 +297,3 @@ class TestFindPeaks:
|
|||
def test_rejects_2d_input(self) -> None:
|
||||
with pytest.raises(ValueError, match="1D"):
|
||||
find_peaks(np.zeros((5, 5)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# procrustes_align
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
|
||||
"""Rotation matrix about the Z axis."""
|
||||
c, s = np.cos(angle_rad), np.sin(angle_rad)
|
||||
return np.array(
|
||||
[
|
||||
[c, -s, 0.0],
|
||||
[s, c, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _skeleton(num_joints: int = 8, seed: int = 0) -> np.ndarray:
|
||||
"""A deterministic, non-degenerate single-frame skeleton."""
|
||||
rng = np.random.default_rng(seed)
|
||||
return rng.standard_normal((num_joints, 3))
|
||||
|
||||
|
||||
class TestProcrustesAlignPerSequence:
|
||||
def test_identical_sequences_yield_identity_transform(self) -> None:
|
||||
sequence = _skeleton()[np.newaxis, :, :].repeat(3, axis=0) # (3, 8, 3)
|
||||
aligned, target, diag = procrustes_align(sequence, sequence, mode="per_sequence")
|
||||
np.testing.assert_allclose(aligned, sequence, atol=1e-10)
|
||||
np.testing.assert_array_equal(target, sequence)
|
||||
assert diag.mode == "per_sequence"
|
||||
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-6)
|
||||
assert diag.translation == pytest.approx(0.0, abs=1e-9)
|
||||
assert diag.scale == pytest.approx(1.0)
|
||||
|
||||
def test_recovers_known_rotation(self) -> None:
|
||||
# Build a reference sequence; construct the source by rotating it
|
||||
# about Z, then verify alignment returns the reference up to
|
||||
# floating-point error.
|
||||
rotation = _rotation_matrix_z(np.deg2rad(37.0))
|
||||
reference = _skeleton(num_joints=10)[np.newaxis, :, :].repeat(4, axis=0)
|
||||
source = reference @ rotation.T
|
||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
|
||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
||||
# The recovered rotation's magnitude should be the original 37°.
|
||||
assert diag.rotation_deg == pytest.approx(37.0, abs=1e-4)
|
||||
|
||||
def test_recovers_known_translation(self) -> None:
|
||||
reference = _skeleton()[np.newaxis, :, :].repeat(5, axis=0)
|
||||
translation = np.array([10.0, -4.5, 2.25])
|
||||
source = reference + translation
|
||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
|
||||
np.testing.assert_allclose(aligned, reference, atol=1e-9)
|
||||
# rotation_deg may be numerically tiny but not exactly 0.
|
||||
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-4)
|
||||
assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-6)
|
||||
|
||||
def test_recovers_combined_rotation_and_translation(self) -> None:
|
||||
rotation = _rotation_matrix_z(np.deg2rad(-12.0))
|
||||
translation = np.array([1.0, 2.0, 3.0])
|
||||
reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(3, axis=0)
|
||||
source = reference @ rotation.T + translation
|
||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
|
||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
||||
assert diag.rotation_deg == pytest.approx(12.0, abs=1e-4)
|
||||
assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-4)
|
||||
|
||||
def test_scale_flag_recovers_known_scale(self) -> None:
|
||||
reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
|
||||
source = reference * 0.5
|
||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=True)
|
||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
||||
assert diag.scale == pytest.approx(2.0, rel=1e-6)
|
||||
|
||||
def test_scale_flag_off_leaves_scale_at_one(self) -> None:
|
||||
reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
|
||||
source = reference * 0.5
|
||||
_, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=False)
|
||||
assert diag.scale == pytest.approx(1.0)
|
||||
|
||||
def test_rejects_mismatched_shapes(self) -> None:
|
||||
a = np.zeros((4, 8, 3))
|
||||
b = np.zeros((4, 7, 3))
|
||||
with pytest.raises(ValueError, match="same shape"):
|
||||
procrustes_align(a, b)
|
||||
|
||||
def test_rejects_wrong_trailing_axis(self) -> None:
|
||||
a = np.zeros((4, 8, 2))
|
||||
b = np.zeros((4, 8, 2))
|
||||
with pytest.raises(ValueError, match="joints, 3"):
|
||||
procrustes_align(a, b)
|
||||
|
||||
def test_rejects_unknown_mode(self) -> None:
|
||||
a = np.zeros((2, 4, 3))
|
||||
with pytest.raises(ValueError, match="unknown mode"):
|
||||
procrustes_align(a, a, mode="nope") # type: ignore[arg-type]
|
||||
|
||||
def test_does_not_mutate_inputs(self) -> None:
|
||||
source = _skeleton()[np.newaxis, :, :].repeat(3, axis=0).copy()
|
||||
target = (source @ _rotation_matrix_z(np.deg2rad(10.0)).T).copy()
|
||||
source_before = source.copy()
|
||||
target_before = target.copy()
|
||||
procrustes_align(source, target, mode="per_sequence")
|
||||
np.testing.assert_array_equal(source, source_before)
|
||||
np.testing.assert_array_equal(target, target_before)
|
||||
|
||||
def test_returns_alignment_diagnostics_dataclass(self) -> None:
|
||||
a = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
|
||||
_, _, diag = procrustes_align(a, a)
|
||||
assert isinstance(diag, AlignmentDiagnostics)
|
||||
|
||||
|
||||
class TestProcrustesAlignPerFrame:
|
||||
def test_per_frame_recovers_varying_rotations(self) -> None:
|
||||
# Each frame is rotated by a different angle; per_frame alignment
|
||||
# should recover each frame independently.
|
||||
num_frames = 4
|
||||
reference_frame = _skeleton(num_joints=6)
|
||||
angles = np.deg2rad([5.0, -10.0, 20.0, 45.0])
|
||||
reference = np.stack([reference_frame for _ in range(num_frames)], axis=0)
|
||||
source = np.stack([reference_frame @ _rotation_matrix_z(a).T for a in angles], axis=0)
|
||||
aligned, _, diag = procrustes_align(source, reference, mode="per_frame")
|
||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
||||
assert diag.mode == "per_frame"
|
||||
# The max rotation across frames should be 45°.
|
||||
assert diag.rotation_deg_max == pytest.approx(45.0, abs=1e-4)
|
||||
# The mean rotation across frames should be 20°.
|
||||
assert diag.rotation_deg == pytest.approx(20.0, abs=1e-4)
|
||||
|
||||
def test_per_frame_with_identical_sequences_yields_zero(self) -> None:
|
||||
sequence = _skeleton(num_joints=5)[np.newaxis, :, :].repeat(3, axis=0)
|
||||
aligned, _, diag = procrustes_align(sequence, sequence, mode="per_frame")
|
||||
np.testing.assert_allclose(aligned, sequence, atol=1e-10)
|
||||
# Per-frame SVD on a symmetric covariance is numerically ambiguous
|
||||
# in axis selection, so the fitted rotation can be a few micro-
|
||||
# degrees off zero; the residual positions are still exact.
|
||||
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-3)
|
||||
assert diag.rotation_deg_max == pytest.approx(0.0, abs=1e-3)
|
||||
assert diag.translation == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DTW with align= (integration)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDtwAlignIntegration:
|
||||
"""Smoke tests: align= routes through procrustes_align correctly.
|
||||
|
||||
Depth tests of the DTW path itself live in test_analyzer_dtw.
|
||||
"""
|
||||
|
||||
def test_dtw_all_with_alignment_cancels_rigid_offset(self) -> None:
|
||||
pytest.importorskip("fastdtw")
|
||||
from neuropose.analyzer.dtw import dtw_all
|
||||
|
||||
rotation = _rotation_matrix_z(np.deg2rad(30.0))
|
||||
translation = np.array([5.0, -2.0, 1.0])
|
||||
reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(4, axis=0)
|
||||
source = reference @ rotation.T + translation
|
||||
baseline = dtw_all(source, reference, align="none")
|
||||
aligned_result = dtw_all(source, reference, align="procrustes_per_sequence")
|
||||
assert baseline.distance > 0.0
|
||||
assert aligned_result.distance == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
def test_dtw_align_rejects_mismatched_frame_counts(self) -> None:
|
||||
pytest.importorskip("fastdtw")
|
||||
from neuropose.analyzer.dtw import dtw_all
|
||||
|
||||
a = np.zeros((5, 3, 3))
|
||||
b = np.zeros((6, 3, 3))
|
||||
with pytest.raises(ValueError, match="matching frame counts"):
|
||||
dtw_all(a, b, align="procrustes_per_sequence")
|
||||
|
|
|
|||
|
|
@ -1,881 +0,0 @@
|
|||
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
||||
|
||||
Covers both halves of the pipeline:
|
||||
|
||||
- **Schemas** — :class:`AnalysisConfig` parsing (discriminated unions
|
||||
for segmentation and analysis stages, cross-field invariants),
|
||||
:class:`AnalysisReport` construction + JSON round-trip (including
|
||||
the migration hook on ``schema_version``), and
|
||||
:func:`analysis_config_to_dict` JSON-safety.
|
||||
- **Executor** — :func:`run_analysis` dispatches to each analysis kind
|
||||
(dtw / stats / none) with and without segmentation; provenance is
|
||||
inherited from the primary input with ``analysis_config``
|
||||
populated; :func:`load_config`, :func:`save_report`, and
|
||||
:func:`load_report` round-trip.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from neuropose.analyzer.pipeline import (
|
||||
AnalysisConfig,
|
||||
AnalysisReport,
|
||||
DtwAnalysis,
|
||||
DtwResults,
|
||||
ExtractorSegmentation,
|
||||
FeatureSummary,
|
||||
GaitCyclesBilateralSegmentation,
|
||||
GaitCyclesSegmentation,
|
||||
InputsConfig,
|
||||
InputSummary,
|
||||
NoAnalysis,
|
||||
NoResults,
|
||||
OutputConfig,
|
||||
PreprocessingConfig,
|
||||
StatsAnalysis,
|
||||
StatsResults,
|
||||
analysis_config_to_dict,
|
||||
load_config,
|
||||
load_report,
|
||||
run_analysis,
|
||||
save_report,
|
||||
)
|
||||
from neuropose.io import Provenance, VideoPredictions, save_video_predictions
|
||||
from neuropose.migrations import CURRENT_VERSION
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
|
||||
"""A minimal AnalysisConfig dict with dtw_all + reference."""
|
||||
return {
|
||||
"inputs": {
|
||||
"primary": str(tmp_path / "primary.json"),
|
||||
"reference": str(tmp_path / "reference.json"),
|
||||
},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "report.json")},
|
||||
}
|
||||
|
||||
|
||||
def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
|
||||
return {
|
||||
"inputs": {"primary": str(tmp_path / "primary.json")},
|
||||
"analysis": {
|
||||
"kind": "stats",
|
||||
"extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False},
|
||||
},
|
||||
"output": {"report": str(tmp_path / "report.json")},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InputsConfig / PreprocessingConfig / OutputConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInputsConfig:
|
||||
def test_primary_only(self, tmp_path: Path) -> None:
|
||||
cfg = InputsConfig(primary=tmp_path / "a.json")
|
||||
assert cfg.reference is None
|
||||
|
||||
def test_primary_and_reference(self, tmp_path: Path) -> None:
|
||||
cfg = InputsConfig(
|
||||
primary=tmp_path / "a.json",
|
||||
reference=tmp_path / "b.json",
|
||||
)
|
||||
assert cfg.reference == tmp_path / "b.json"
|
||||
|
||||
def test_extra_field_rejected(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(ValidationError, match="Extra inputs"):
|
||||
InputsConfig.model_validate(
|
||||
{
|
||||
"primary": str(tmp_path / "a.json"),
|
||||
"extra": "field",
|
||||
}
|
||||
)
|
||||
|
||||
def test_frozen(self, tmp_path: Path) -> None:
|
||||
cfg = InputsConfig(primary=tmp_path / "a.json")
|
||||
with pytest.raises(ValidationError, match="frozen"):
|
||||
cfg.primary = tmp_path / "b.json" # type: ignore[misc]
|
||||
|
||||
|
||||
class TestPreprocessingConfig:
|
||||
def test_default_person_index_zero(self) -> None:
|
||||
cfg = PreprocessingConfig()
|
||||
assert cfg.person_index == 0
|
||||
|
||||
def test_negative_person_index_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
PreprocessingConfig(person_index=-1)
|
||||
|
||||
|
||||
class TestOutputConfig:
|
||||
def test_report_path(self, tmp_path: Path) -> None:
|
||||
cfg = OutputConfig(report=tmp_path / "out.json")
|
||||
assert cfg.report == tmp_path / "out.json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Segmentation stage discriminated union
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSegmentationStage:
|
||||
def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["segmentation"] = {
|
||||
"kind": "gait_cycles",
|
||||
"joint": "lhee",
|
||||
"axis": "y",
|
||||
"min_cycle_seconds": 0.5,
|
||||
}
|
||||
cfg = AnalysisConfig.model_validate(config_dict)
|
||||
assert isinstance(cfg.segmentation, GaitCyclesSegmentation)
|
||||
assert cfg.segmentation.joint == "lhee"
|
||||
assert cfg.segmentation.min_cycle_seconds == 0.5
|
||||
|
||||
def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["segmentation"] = {"kind": "gait_cycles_bilateral"}
|
||||
cfg = AnalysisConfig.model_validate(config_dict)
|
||||
assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation)
|
||||
|
||||
def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["segmentation"] = {
|
||||
"kind": "extractor",
|
||||
"extractor": {
|
||||
"kind": "joint_axis",
|
||||
"joint": 15,
|
||||
"axis": 1,
|
||||
"invert": False,
|
||||
},
|
||||
"label": "wrist_cycles",
|
||||
"min_distance_seconds": 0.5,
|
||||
}
|
||||
cfg = AnalysisConfig.model_validate(config_dict)
|
||||
assert isinstance(cfg.segmentation, ExtractorSegmentation)
|
||||
assert cfg.segmentation.label == "wrist_cycles"
|
||||
|
||||
def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["segmentation"] = {"kind": "unknown_method"}
|
||||
with pytest.raises(ValidationError):
|
||||
AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
|
||||
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||
assert cfg.segmentation is None
|
||||
|
||||
def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["segmentation"] = {
|
||||
"kind": "gait_cycles",
|
||||
"min_cycle_seconds": 0.0, # must be > 0
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Analysis stage discriminated union + cross-field invariants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDtwAnalysisValidation:
|
||||
def test_dtw_relation_requires_joints(self) -> None:
|
||||
with pytest.raises(ValidationError, match="joint_i and joint_j"):
|
||||
DtwAnalysis(kind="dtw", method="dtw_relation")
|
||||
|
||||
def test_dtw_relation_rejects_angles_representation(self) -> None:
|
||||
with pytest.raises(ValidationError, match="only supports representation='coords'"):
|
||||
DtwAnalysis(
|
||||
kind="dtw",
|
||||
method="dtw_relation",
|
||||
joint_i=0,
|
||||
joint_j=1,
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2)],
|
||||
)
|
||||
|
||||
def test_angles_requires_triplets(self) -> None:
|
||||
with pytest.raises(ValidationError, match="angle_triplets"):
|
||||
DtwAnalysis(
|
||||
kind="dtw",
|
||||
method="dtw_all",
|
||||
representation="angles",
|
||||
)
|
||||
|
||||
def test_angles_with_empty_triplets_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="angle_triplets"):
|
||||
DtwAnalysis(
|
||||
kind="dtw",
|
||||
method="dtw_all",
|
||||
representation="angles",
|
||||
angle_triplets=[],
|
||||
)
|
||||
|
||||
def test_happy_path_dtw_all_coords(self) -> None:
|
||||
analysis = DtwAnalysis(
|
||||
kind="dtw",
|
||||
method="dtw_all",
|
||||
align="procrustes_per_sequence",
|
||||
nan_policy="interpolate",
|
||||
)
|
||||
assert analysis.align == "procrustes_per_sequence"
|
||||
|
||||
def test_happy_path_dtw_all_angles(self) -> None:
|
||||
analysis = DtwAnalysis(
|
||||
kind="dtw",
|
||||
method="dtw_all",
|
||||
representation="angles",
|
||||
angle_triplets=[(0, 1, 2), (3, 4, 5)],
|
||||
)
|
||||
assert len(analysis.angle_triplets or []) == 2
|
||||
|
||||
def test_happy_path_dtw_relation(self) -> None:
|
||||
analysis = DtwAnalysis(
|
||||
kind="dtw",
|
||||
method="dtw_relation",
|
||||
joint_i=15,
|
||||
joint_j=23,
|
||||
)
|
||||
assert analysis.joint_i == 15
|
||||
|
||||
|
||||
class TestAnalysisCrossStage:
|
||||
def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
|
||||
config_dict = {
|
||||
"inputs": {"primary": str(tmp_path / "a.json")},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "out.json")},
|
||||
}
|
||||
with pytest.raises(ValidationError, match=r"inputs\.reference"):
|
||||
AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_stats_config(tmp_path)
|
||||
config_dict["inputs"]["reference"] = str(tmp_path / "b.json")
|
||||
with pytest.raises(ValidationError, match="primary only"):
|
||||
AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
def test_none_analysis_requires_no_reference(self, tmp_path: Path) -> None:
|
||||
# NoAnalysis is fine with either reference present or absent.
|
||||
config_dict = {
|
||||
"inputs": {"primary": str(tmp_path / "a.json")},
|
||||
"analysis": {"kind": "none"},
|
||||
"output": {"report": str(tmp_path / "out.json")},
|
||||
}
|
||||
cfg = AnalysisConfig.model_validate(config_dict)
|
||||
assert isinstance(cfg.analysis, NoAnalysis)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level AnalysisConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnalysisConfig:
|
||||
def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
|
||||
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||
assert cfg.config_version == 1
|
||||
assert isinstance(cfg.analysis, DtwAnalysis)
|
||||
assert cfg.preprocessing.person_index == 0 # default
|
||||
|
||||
def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
|
||||
cfg = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
|
||||
assert isinstance(cfg.analysis, StatsAnalysis)
|
||||
|
||||
def test_config_version_must_be_1(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["config_version"] = 99
|
||||
with pytest.raises(ValidationError):
|
||||
AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
def test_round_trip_json(self, tmp_path: Path) -> None:
|
||||
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||
serialised = original.model_dump_json()
|
||||
restored = AnalysisConfig.model_validate_json(serialised)
|
||||
assert restored == original
|
||||
|
||||
def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
|
||||
config_dict = _minimal_dtw_config(tmp_path)
|
||||
config_dict["unknown_key"] = "typo"
|
||||
with pytest.raises(ValidationError):
|
||||
AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# analysis_config_to_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnalysisConfigToDict:
|
||||
def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
|
||||
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||
dumped = analysis_config_to_dict(cfg)
|
||||
# Paths must have become strings.
|
||||
assert isinstance(dumped["inputs"]["primary"], str)
|
||||
assert isinstance(dumped["output"]["report"], str)
|
||||
|
||||
def test_round_trips_through_dict(self, tmp_path: Path) -> None:
|
||||
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||
dumped = analysis_config_to_dict(original)
|
||||
restored = AnalysisConfig.model_validate(dumped)
|
||||
assert restored == original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result sub-schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDtwResults:
|
||||
def test_minimal_construction(self) -> None:
|
||||
res = DtwResults(
|
||||
kind="dtw",
|
||||
method="dtw_all",
|
||||
distances=[0.5],
|
||||
paths=[[(0, 0), (1, 1)]],
|
||||
segment_labels=["full_trial"],
|
||||
summary={"mean": 0.5},
|
||||
)
|
||||
assert res.kind == "dtw"
|
||||
|
||||
def test_per_joint_distances_shape_is_free(self) -> None:
|
||||
# No validator enforces that per_joint_distances outer length
|
||||
# matches distances — that's run-time semantics of the
|
||||
# executor. Still, verify the field round-trips.
|
||||
res = DtwResults(
|
||||
kind="dtw",
|
||||
method="dtw_per_joint",
|
||||
distances=[0.1, 0.2],
|
||||
paths=[[(0, 0)], [(0, 0)]],
|
||||
per_joint_distances=[[0.05, 0.05], [0.1, 0.1]],
|
||||
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
|
||||
summary={"mean": 0.15},
|
||||
)
|
||||
assert res.per_joint_distances is not None
|
||||
|
||||
|
||||
class TestStatsResults:
|
||||
def test_round_trip(self) -> None:
|
||||
res = StatsResults(
|
||||
kind="stats",
|
||||
statistics=[
|
||||
FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
|
||||
FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
|
||||
],
|
||||
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
|
||||
)
|
||||
dumped = res.model_dump_json()
|
||||
restored = StatsResults.model_validate_json(dumped)
|
||||
assert restored == res
|
||||
|
||||
|
||||
class TestNoResults:
|
||||
def test_construction(self) -> None:
|
||||
res = NoResults(kind="none")
|
||||
assert res.kind == "none"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AnalysisReport
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_report(tmp_path: Path) -> AnalysisReport:
|
||||
config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||
return AnalysisReport(
|
||||
config=config,
|
||||
primary=InputSummary(
|
||||
path=tmp_path / "primary.json",
|
||||
frame_count=300,
|
||||
fps=30.0,
|
||||
),
|
||||
reference=InputSummary(
|
||||
path=tmp_path / "reference.json",
|
||||
frame_count=300,
|
||||
fps=30.0,
|
||||
),
|
||||
results=DtwResults(
|
||||
kind="dtw",
|
||||
method="dtw_all",
|
||||
distances=[0.42],
|
||||
paths=[[(0, 0), (1, 1)]],
|
||||
segment_labels=["full_trial"],
|
||||
summary={"mean": 0.42, "p50": 0.42},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestAnalysisReport:
|
||||
def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
|
||||
report = _make_report(tmp_path)
|
||||
assert report.schema_version == CURRENT_VERSION
|
||||
|
||||
def test_round_trip_json(self, tmp_path: Path) -> None:
|
||||
report = _make_report(tmp_path)
|
||||
serialised = report.model_dump_json()
|
||||
restored = AnalysisReport.model_validate_json(serialised)
|
||||
assert restored == report
|
||||
|
||||
def test_empty_segmentations_default(self, tmp_path: Path) -> None:
|
||||
report = _make_report(tmp_path)
|
||||
assert report.segmentations == {}
|
||||
|
||||
def test_extra_field_rejected(self, tmp_path: Path) -> None:
|
||||
report = _make_report(tmp_path)
|
||||
dumped = report.model_dump(mode="json")
|
||||
dumped["mystery_field"] = 1
|
||||
with pytest.raises(ValidationError):
|
||||
AnalysisReport.model_validate(dumped)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
NUM_JOINTS = 43
|
||||
|
||||
|
||||
def _heel_signal(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
|
||||
"""Clean sinusoid stand-in for a heel's vertical trace."""
|
||||
import math
|
||||
|
||||
total = num_cycles * frames_per_cycle
|
||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
||||
return (np.sin(t) * amplitude + amplitude).astype(float)
|
||||
|
||||
|
||||
def _build_predictions(
|
||||
signal: np.ndarray,
|
||||
joint: int,
|
||||
*,
|
||||
axis: int = 1,
|
||||
fps: float = 30.0,
|
||||
provenance: Provenance | None = None,
|
||||
) -> VideoPredictions:
|
||||
"""Build a VideoPredictions whose ``joint``'s ``axis`` follows ``signal``."""
|
||||
frames = {}
|
||||
for i, value in enumerate(signal):
|
||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
||||
poses[0][joint][axis] = float(value)
|
||||
frames[f"frame_{i:06d}"] = {
|
||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||
"poses3d": poses,
|
||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
||||
}
|
||||
return VideoPredictions.model_validate(
|
||||
{
|
||||
"metadata": {
|
||||
"frame_count": len(signal),
|
||||
"fps": fps,
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
},
|
||||
"frames": frames,
|
||||
"provenance": provenance.model_dump() if provenance is not None else None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _write_heel_trial(
|
||||
tmp_path: Path,
|
||||
filename: str,
|
||||
*,
|
||||
joint: int,
|
||||
num_cycles: int = 4,
|
||||
frames_per_cycle: int = 30,
|
||||
amplitude: float = 100.0,
|
||||
provenance: Provenance | None = None,
|
||||
) -> Path:
|
||||
"""Write a heel-trace VideoPredictions JSON and return its path."""
|
||||
signal = _heel_signal(num_cycles, frames_per_cycle, amplitude=amplitude)
|
||||
preds = _build_predictions(signal, joint=joint, provenance=provenance)
|
||||
path = tmp_path / filename
|
||||
save_video_predictions(path, preds)
|
||||
return path
|
||||
|
||||
|
||||
def _fake_provenance(sha: str = "a" * 64) -> Provenance:
|
||||
return Provenance(
|
||||
model_sha256=sha,
|
||||
model_filename="fake_model.tar.gz",
|
||||
tensorflow_version="2.18.0",
|
||||
numpy_version="1.26.0",
|
||||
neuropose_version="0.0.0",
|
||||
python_version="3.11.0",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor: run_analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunAnalysisDtwFullTrial:
|
||||
def test_dtw_all_unsegmented_yields_one_distance(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||
report_path = tmp_path / "report.json"
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(report_path)},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, DtwResults)
|
||||
assert report.results.segment_labels == ["full_trial"]
|
||||
assert len(report.results.distances) == 1
|
||||
# Identical inputs → distance 0.
|
||||
assert report.results.distances[0] == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
def test_dtw_all_different_trials_positive_distance(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], amplitude=100.0)
|
||||
reference = _write_heel_trial(
|
||||
tmp_path, "b.json", joint=JOINT_INDEX["rhee"], amplitude=200.0
|
||||
)
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, DtwResults)
|
||||
assert report.results.distances[0] > 0.0
|
||||
assert "mean" in report.results.summary
|
||||
|
||||
|
||||
class TestRunAnalysisDtwSegmented:
|
||||
def test_dtw_with_gait_cycles_produces_per_segment_distances(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, DtwResults)
|
||||
# 4 cycles detected on both → 4 paired distances.
|
||||
assert len(report.results.distances) == 4
|
||||
assert all(label.startswith("rhee_cycles[") for label in report.results.segment_labels)
|
||||
|
||||
def test_dtw_bilateral_produces_distances_per_side(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
# Put the heel trace on both lhee and rhee.
|
||||
rng_signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||
frames = {}
|
||||
for i, value in enumerate(rng_signal):
|
||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
||||
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
|
||||
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
|
||||
frames[f"frame_{i:06d}"] = {
|
||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||
"poses3d": poses,
|
||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
||||
}
|
||||
preds = VideoPredictions.model_validate(
|
||||
{
|
||||
"metadata": {
|
||||
"frame_count": len(rng_signal),
|
||||
"fps": 30.0,
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
},
|
||||
"frames": frames,
|
||||
}
|
||||
)
|
||||
primary = tmp_path / "a.json"
|
||||
reference = tmp_path / "b.json"
|
||||
save_video_predictions(primary, preds)
|
||||
save_video_predictions(reference, preds)
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"segmentation": {"kind": "gait_cycles_bilateral"},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, DtwResults)
|
||||
# 3 cycles * 2 sides.
|
||||
assert len(report.results.distances) == 6
|
||||
left = [lbl for lbl in report.results.segment_labels if lbl.startswith("left_heel_strikes")]
|
||||
right = [
|
||||
lbl for lbl in report.results.segment_labels if lbl.startswith("right_heel_strikes")
|
||||
]
|
||||
assert len(left) == 3
|
||||
assert len(right) == 3
|
||||
# Identical primary and reference → all distances zero.
|
||||
for d in report.results.distances:
|
||||
assert d == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
def test_dtw_per_joint_populates_per_joint_distances(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_per_joint"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, DtwResults)
|
||||
assert report.results.per_joint_distances is not None
|
||||
assert len(report.results.per_joint_distances) == 1 # unsegmented → one pair
|
||||
assert len(report.results.per_joint_distances[0]) == NUM_JOINTS
|
||||
|
||||
|
||||
class TestRunAnalysisStats:
|
||||
def test_stats_unsegmented_single_block(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary)},
|
||||
"analysis": {
|
||||
"kind": "stats",
|
||||
"extractor": {
|
||||
"kind": "joint_axis",
|
||||
"joint": JOINT_INDEX["rhee"],
|
||||
"axis": 1,
|
||||
"invert": False,
|
||||
},
|
||||
},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, StatsResults)
|
||||
assert report.results.segment_labels == ["full_trial"]
|
||||
assert len(report.results.statistics) == 1
|
||||
stat = report.results.statistics[0]
|
||||
assert isinstance(stat, FeatureSummary)
|
||||
assert stat.max > stat.min # Signal oscillates.
|
||||
|
||||
def test_stats_with_segmentation_emits_per_segment(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary)},
|
||||
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
||||
"analysis": {
|
||||
"kind": "stats",
|
||||
"extractor": {
|
||||
"kind": "joint_axis",
|
||||
"joint": JOINT_INDEX["rhee"],
|
||||
"axis": 1,
|
||||
"invert": False,
|
||||
},
|
||||
},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, StatsResults)
|
||||
assert len(report.results.statistics) == 3
|
||||
|
||||
|
||||
class TestRunAnalysisNone:
|
||||
def test_none_analysis_returns_no_results(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary)},
|
||||
"analysis": {"kind": "none"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, NoResults)
|
||||
|
||||
def test_none_with_segmentation_still_emits_segmentations(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary)},
|
||||
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
||||
"analysis": {"kind": "none"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert isinstance(report.results, NoResults)
|
||||
assert "rhee_cycles" in report.segmentations
|
||||
assert len(report.segmentations["rhee_cycles"].segments) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provenance inheritance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunAnalysisProvenance:
|
||||
def test_inherits_primary_provenance_and_stamps_config(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
provenance = _fake_provenance()
|
||||
primary = _write_heel_trial(
|
||||
tmp_path, "a.json", joint=JOINT_INDEX["rhee"], provenance=provenance
|
||||
)
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert report.provenance is not None
|
||||
# Model SHA inherited from primary.
|
||||
assert report.provenance.model_sha256 == provenance.model_sha256
|
||||
# analysis_config populated with the serialised config.
|
||||
assert report.provenance.analysis_config is not None
|
||||
assert report.provenance.analysis_config["config_version"] == 1
|
||||
|
||||
def test_no_primary_provenance_yields_none_report_provenance(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert report.provenance is None
|
||||
|
||||
def test_input_summaries_track_paths_and_metadata(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=5)
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
assert report.primary.path == primary
|
||||
assert report.primary.frame_count == 5 * 30
|
||||
assert report.reference is not None
|
||||
assert report.reference.frame_count == 3 * 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_config / save_report / load_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadSave:
|
||||
def test_load_config_parses_yaml(self, tmp_path: Path) -> None:
|
||||
config_dict = {
|
||||
"inputs": {
|
||||
"primary": str(tmp_path / "a.json"),
|
||||
"reference": str(tmp_path / "b.json"),
|
||||
},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
yaml_path = tmp_path / "exp.yaml"
|
||||
yaml_path.write_text(yaml.safe_dump(config_dict))
|
||||
loaded = load_config(yaml_path)
|
||||
assert isinstance(loaded.analysis, DtwAnalysis)
|
||||
|
||||
def test_load_config_empty_file_fails_cleanly(self, tmp_path: Path) -> None:
|
||||
empty = tmp_path / "empty.yaml"
|
||||
empty.write_text("")
|
||||
with pytest.raises(ValidationError):
|
||||
load_config(empty)
|
||||
|
||||
def test_load_config_rejects_malformed_yaml(self, tmp_path: Path) -> None:
|
||||
bad = tmp_path / "bad.yaml"
|
||||
# Unclosed flow-style mapping — yaml.safe_load raises here.
|
||||
bad.write_text("inputs: {primary: foo\n")
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
load_config(bad)
|
||||
|
||||
def test_save_report_round_trip(self, tmp_path: Path) -> None:
|
||||
from neuropose.analyzer.segment import JOINT_INDEX
|
||||
|
||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||
config = AnalysisConfig.model_validate(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "report.json")},
|
||||
}
|
||||
)
|
||||
report = run_analysis(config)
|
||||
report_path = tmp_path / "report.json"
|
||||
save_report(report_path, report)
|
||||
assert report_path.exists()
|
||||
|
||||
restored = load_report(report_path)
|
||||
assert restored == report
|
||||
|
||||
def test_save_report_is_atomic(self, tmp_path: Path) -> None:
|
||||
"""The saver writes via a sibling .tmp path and renames."""
|
||||
report = _make_report(tmp_path)
|
||||
report_path = tmp_path / "subdir" / "report.json"
|
||||
save_report(report_path, report)
|
||||
# Parent directory was created.
|
||||
assert report_path.exists()
|
||||
# No .tmp sibling left behind.
|
||||
assert not (report_path.with_suffix(report_path.suffix + ".tmp")).exists()
|
||||
|
||||
def test_load_report_rejects_future_schema(self, tmp_path: Path) -> None:
|
||||
"""Future schema_version surfaces as a migration error."""
|
||||
from neuropose.migrations import FutureSchemaError
|
||||
|
||||
future = {"schema_version": CURRENT_VERSION + 1}
|
||||
path = tmp_path / "future.json"
|
||||
path.write_text(json.dumps(future))
|
||||
with pytest.raises(FutureSchemaError):
|
||||
load_report(path)
|
||||
|
|
@ -33,8 +33,6 @@ from neuropose.analyzer.segment import (
|
|||
joint_pair_distance,
|
||||
joint_speed,
|
||||
segment_by_peaks,
|
||||
segment_gait_cycles,
|
||||
segment_gait_cycles_bilateral,
|
||||
segment_predictions,
|
||||
slice_predictions,
|
||||
)
|
||||
|
|
@ -431,139 +429,3 @@ class TestSlicePredictions:
|
|||
sliced = slice_predictions(preds, segments)[0]
|
||||
# frame_000000 of the slice must equal frame_000050 of the source
|
||||
assert sliced["frame_000000"].poses3d == preds["frame_000050"].poses3d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# segment_gait_cycles / segment_gait_cycles_bilateral
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _heel_signal(num_cycles: int, frames_per_cycle: int) -> np.ndarray:
|
||||
"""A clean sinusoid standing in for a heel's vertical trace."""
|
||||
total = num_cycles * frames_per_cycle
|
||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
||||
# Amplitude chosen so min_prominence tests have a non-trivial range.
|
||||
return (np.sin(t) * 100.0 + 100.0).astype(float)
|
||||
|
||||
|
||||
class TestSegmentGaitCycles:
|
||||
def test_detects_expected_number_of_cycles(self) -> None:
|
||||
# 5 cycles at 30 fps, 30 frames per cycle = 1.0 s per stride.
|
||||
# Well inside the default min_cycle_seconds=0.4 gate.
|
||||
signal = _heel_signal(num_cycles=5, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
seg = segment_gait_cycles(preds, joint="rhee", axis="y")
|
||||
assert len(seg.segments) == 5
|
||||
|
||||
def test_config_records_inputs(self) -> None:
|
||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
seg = segment_gait_cycles(
|
||||
preds,
|
||||
joint="rhee",
|
||||
axis="y",
|
||||
min_cycle_seconds=0.5,
|
||||
min_prominence=10.0,
|
||||
)
|
||||
assert isinstance(seg, Segmentation)
|
||||
assert isinstance(seg.config.extractor, JointAxisExtractor)
|
||||
assert seg.config.extractor.joint == JOINT_INDEX["rhee"]
|
||||
assert seg.config.extractor.axis == 1 # "y" → 1
|
||||
assert seg.config.extractor.invert is False
|
||||
assert seg.config.min_distance_seconds == 0.5
|
||||
assert seg.config.min_prominence == 10.0
|
||||
|
||||
def test_axis_selection(self) -> None:
|
||||
# Put the signal on the X axis instead of Y.
|
||||
signal = _heel_signal(num_cycles=4, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"], axis=0)
|
||||
seg_y = segment_gait_cycles(preds, joint="rhee", axis="y")
|
||||
seg_x = segment_gait_cycles(preds, joint="rhee", axis="x")
|
||||
# Y is all-zeros (flat → no peaks), X carries the signal.
|
||||
assert len(seg_y.segments) == 0
|
||||
assert len(seg_x.segments) == 4
|
||||
|
||||
def test_invert_flips_peaks_and_valleys(self) -> None:
|
||||
# Invert the heel trace; with invert=True, the original valleys
|
||||
# become the peaks detected as heel-strikes.
|
||||
signal = _heel_signal(num_cycles=4, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
seg_plain = segment_gait_cycles(preds, joint="rhee", axis="y", invert=False)
|
||||
seg_inverted = segment_gait_cycles(preds, joint="rhee", axis="y", invert=True)
|
||||
# Both detect four distinct events (peaks in either the signal
|
||||
# or its negation). Peaks differ by roughly half a cycle.
|
||||
assert len(seg_plain.segments) == 4
|
||||
assert len(seg_inverted.segments) == 4
|
||||
plain_peaks = [s.peak for s in seg_plain.segments]
|
||||
inverted_peaks = [s.peak for s in seg_inverted.segments]
|
||||
assert plain_peaks != inverted_peaks
|
||||
|
||||
def test_pathological_flat_signal_returns_empty(self) -> None:
|
||||
# A subject whose heel never leaves the ground — no peaks.
|
||||
signal = np.zeros(120)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
seg = segment_gait_cycles(preds, joint="rhee", axis="y")
|
||||
assert seg.segments == []
|
||||
|
||||
def test_min_cycle_seconds_rejects_close_peaks(self) -> None:
|
||||
# 10 cycles in 60 frames @ 30 fps = 0.2 s per cycle.
|
||||
# min_cycle_seconds=0.4 should reject all but every-other peak.
|
||||
signal = _heel_signal(num_cycles=10, frames_per_cycle=6)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
seg_permissive = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.0)
|
||||
seg_strict = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.4)
|
||||
# Strict mode drops peaks that are too close together.
|
||||
assert len(seg_strict.segments) < len(seg_permissive.segments)
|
||||
|
||||
def test_unknown_joint_raises_key_error(self) -> None:
|
||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
with pytest.raises(KeyError, match="unknown joint"):
|
||||
segment_gait_cycles(preds, joint="left_heel") # wrong name
|
||||
|
||||
def test_invalid_axis_raises_value_error(self) -> None:
|
||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
with pytest.raises(ValueError, match="axis must be one of"):
|
||||
segment_gait_cycles(preds, joint="rhee", axis="w") # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestSegmentGaitCyclesBilateral:
|
||||
def test_returns_both_keys(self) -> None:
|
||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||
# Put the same signal on both heels so both sides find cycles.
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
# Rebuild predictions with lhee populated too.
|
||||
frames = {}
|
||||
for i, value in enumerate(signal):
|
||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
||||
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
|
||||
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
|
||||
frames[f"frame_{i:06d}"] = {
|
||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||
"poses3d": poses,
|
||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
||||
}
|
||||
preds = VideoPredictions.model_validate(
|
||||
{
|
||||
"metadata": {
|
||||
"frame_count": len(signal),
|
||||
"fps": 30.0,
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
},
|
||||
"frames": frames,
|
||||
}
|
||||
)
|
||||
result = segment_gait_cycles_bilateral(preds)
|
||||
assert set(result.keys()) == {"left_heel_strikes", "right_heel_strikes"}
|
||||
assert len(result["left_heel_strikes"].segments) == 3
|
||||
assert len(result["right_heel_strikes"].segments) == 3
|
||||
|
||||
def test_pathological_one_side_returns_empty_for_that_side(self) -> None:
|
||||
# Only the right heel carries a signal; left heel is flat.
|
||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
||||
result = segment_gait_cycles_bilateral(preds)
|
||||
assert len(result["right_heel_strikes"].segments) == 3
|
||||
assert result["left_heel_strikes"].segments == []
|
||||
|
|
|
|||
|
|
@ -683,15 +683,9 @@ def stub_estimator_with_metrics(monkeypatch: pytest.MonkeyPatch):
|
|||
"poses2d": np.array([[[0.0, 0.0], [1.0, 1.0]]]),
|
||||
}
|
||||
|
||||
from neuropose._model import LoadedModel
|
||||
|
||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
||||
def fake_loader(cache_dir: Path | None = None) -> object:
|
||||
del cache_dir
|
||||
return LoadedModel(
|
||||
model=RecordingFake(),
|
||||
sha256="smoke_sha",
|
||||
filename="metrabs_smoke.tar.gz",
|
||||
)
|
||||
return RecordingFake()
|
||||
|
||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||
|
||||
|
|
@ -776,142 +770,17 @@ class TestBenchmarkSubcommand:
|
|||
|
||||
|
||||
class TestAnalyze:
|
||||
"""Covers the ``neuropose analyze --config <yaml>`` subcommand.
|
||||
|
||||
Execution happy path is exercised in detail in
|
||||
:mod:`tests.unit.test_analyzer_pipeline` — this file focuses on
|
||||
the CLI wiring: argument parsing, config-loading error modes, and
|
||||
end-to-end smoke.
|
||||
"""
|
||||
|
||||
def _make_predictions_file(self, tmp_path: Path, name: str, num_frames: int = 30) -> Path:
|
||||
"""Write a trivial VideoPredictions file to disk for the CLI to load."""
|
||||
import math
|
||||
|
||||
from neuropose.io import VideoPredictions, save_video_predictions
|
||||
|
||||
num_joints = 43
|
||||
frames = {}
|
||||
for i in range(num_frames):
|
||||
poses = [[[0.0, 0.0, 0.0] for _ in range(num_joints)]]
|
||||
poses[0][41][1] = float(math.sin(i * 0.3)) * 100.0 # rhee Y
|
||||
frames[f"frame_{i:06d}"] = {
|
||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||
"poses3d": poses,
|
||||
"poses2d": [[[0.0, 0.0]] * num_joints],
|
||||
}
|
||||
preds = VideoPredictions.model_validate(
|
||||
{
|
||||
"metadata": {
|
||||
"frame_count": num_frames,
|
||||
"fps": 30.0,
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
},
|
||||
"frames": frames,
|
||||
}
|
||||
)
|
||||
path = tmp_path / name
|
||||
save_video_predictions(path, preds)
|
||||
return path
|
||||
|
||||
def _write_dtw_config(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
*,
|
||||
primary: Path,
|
||||
reference: Path,
|
||||
report: Path,
|
||||
) -> Path:
|
||||
import yaml as _yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
_yaml.safe_dump(
|
||||
{
|
||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(report)},
|
||||
}
|
||||
)
|
||||
)
|
||||
return config_path
|
||||
|
||||
def test_missing_config_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
|
||||
result = runner.invoke(app, ["analyze", "--config", str(tmp_path / "nope.yaml")])
|
||||
assert result.exit_code == EXIT_USAGE
|
||||
assert "config file not found" in result.output
|
||||
|
||||
def test_missing_config_flag_is_usage_error(self, runner: CliRunner) -> None:
|
||||
result = runner.invoke(app, ["analyze"])
|
||||
assert result.exit_code == EXIT_USAGE
|
||||
|
||||
def test_invalid_yaml_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
|
||||
bad = tmp_path / "bad.yaml"
|
||||
bad.write_text("inputs: {primary: foo\n") # unclosed flow mapping
|
||||
result = runner.invoke(app, ["analyze", "--config", str(bad)])
|
||||
assert result.exit_code == EXIT_USAGE
|
||||
assert "could not parse YAML" in result.output
|
||||
|
||||
def test_schema_violation_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
|
||||
import yaml as _yaml
|
||||
|
||||
bad = tmp_path / "schema.yaml"
|
||||
bad.write_text(
|
||||
_yaml.safe_dump(
|
||||
{
|
||||
"inputs": {"primary": str(tmp_path / "a.json")},
|
||||
# dtw without reference — violates cross-field invariant.
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "r.json")},
|
||||
}
|
||||
)
|
||||
)
|
||||
result = runner.invoke(app, ["analyze", "--config", str(bad)])
|
||||
assert result.exit_code == EXIT_USAGE
|
||||
assert "invalid config" in result.output
|
||||
|
||||
def test_happy_path_writes_report(self, runner: CliRunner, tmp_path: Path) -> None:
|
||||
primary = self._make_predictions_file(tmp_path, "a.json")
|
||||
reference = self._make_predictions_file(tmp_path, "b.json")
|
||||
report_path = tmp_path / "report.json"
|
||||
config = self._write_dtw_config(
|
||||
tmp_path, primary=primary, reference=reference, report=report_path
|
||||
)
|
||||
result = runner.invoke(app, ["analyze", "--config", str(config)])
|
||||
assert result.exit_code == EXIT_OK, result.output
|
||||
assert report_path.exists()
|
||||
assert "wrote analysis report" in result.output
|
||||
assert "analysis kind: dtw" in result.output
|
||||
|
||||
def test_output_option_overrides_config_path(self, runner: CliRunner, tmp_path: Path) -> None:
|
||||
primary = self._make_predictions_file(tmp_path, "a.json")
|
||||
reference = self._make_predictions_file(tmp_path, "b.json")
|
||||
# Config points at one report path ...
|
||||
config = self._write_dtw_config(
|
||||
tmp_path,
|
||||
primary=primary,
|
||||
reference=reference,
|
||||
report=tmp_path / "declared.json",
|
||||
)
|
||||
# ... but --output overrides.
|
||||
override = tmp_path / "override.json"
|
||||
result = runner.invoke(app, ["analyze", "--config", str(config), "--output", str(override)])
|
||||
assert result.exit_code == EXIT_OK, result.output
|
||||
assert override.exists()
|
||||
assert not (tmp_path / "declared.json").exists()
|
||||
|
||||
def test_missing_predictions_file_is_usage_error(
|
||||
def test_analyze_stub_exits_with_pending_message(
|
||||
self, runner: CliRunner, tmp_path: Path
|
||||
) -> None:
|
||||
# Config points at a primary that does not exist.
|
||||
config = self._write_dtw_config(
|
||||
tmp_path,
|
||||
primary=tmp_path / "missing_primary.json",
|
||||
reference=tmp_path / "missing_reference.json",
|
||||
report=tmp_path / "report.json",
|
||||
)
|
||||
result = runner.invoke(app, ["analyze", "--config", str(config)])
|
||||
results_path = tmp_path / "results.json"
|
||||
results_path.write_text("{}")
|
||||
result = runner.invoke(app, ["analyze", str(results_path)])
|
||||
assert result.exit_code == EXIT_PENDING
|
||||
assert "commit 10" in result.output
|
||||
|
||||
def test_analyze_requires_an_argument(self, runner: CliRunner) -> None:
|
||||
result = runner.invoke(app, ["analyze"])
|
||||
assert result.exit_code == EXIT_USAGE
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -70,21 +70,17 @@ class TestModelGuard:
|
|||
network: the loader is monkeypatched to return a sentinel, and we
|
||||
assert it ends up as the estimator's model.
|
||||
"""
|
||||
from neuropose._model import LoadedModel
|
||||
|
||||
sentinel = object()
|
||||
called_with: list[Path | None] = []
|
||||
|
||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
||||
def fake_loader(cache_dir: Path | None = None) -> object:
|
||||
called_with.append(cache_dir)
|
||||
return LoadedModel(model=sentinel, sha256="deadbeef", filename="fake.tar.gz")
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||
estimator = Estimator()
|
||||
estimator.load_model(cache_dir=Path("/tmp/fake-cache"))
|
||||
assert estimator.model is sentinel
|
||||
assert estimator.model_sha256 == "deadbeef"
|
||||
assert estimator.model_filename == "fake.tar.gz"
|
||||
assert called_with == [Path("/tmp/fake-cache")]
|
||||
|
||||
def test_load_model_is_idempotent_when_already_loaded(
|
||||
|
|
@ -282,15 +278,9 @@ class TestPerformanceMetrics:
|
|||
"poses2d": np.array([[[0.0, 0.0]]]),
|
||||
}
|
||||
|
||||
from neuropose._model import LoadedModel
|
||||
|
||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
||||
def fake_loader(cache_dir: Path | None = None) -> object:
|
||||
del cache_dir
|
||||
return LoadedModel(
|
||||
model=Recorder(),
|
||||
sha256="fake_sha",
|
||||
filename="metrabs_fake.tar.gz",
|
||||
)
|
||||
return Recorder()
|
||||
|
||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||
estimator = Estimator()
|
||||
|
|
@ -322,88 +312,6 @@ class TestPerformanceMetrics:
|
|||
assert result.metrics.tensorflow_version not in {"", "unknown"}
|
||||
|
||||
|
||||
class TestProvenance:
|
||||
"""Provenance attachment to VideoPredictions.
|
||||
|
||||
Covers the two relevant paths: the injected-model path (no SHA
|
||||
known → ``provenance=None`` on output) and the ``load_model`` path
|
||||
(SHA is known → full ``Provenance`` populated and attached).
|
||||
"""
|
||||
|
||||
def test_injected_model_produces_no_provenance(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
fake_metrabs_model,
|
||||
) -> None:
|
||||
estimator = Estimator(model=fake_metrabs_model)
|
||||
result = estimator.process_video(synthetic_video)
|
||||
assert result.predictions.provenance is None
|
||||
assert estimator.model_sha256 is None
|
||||
assert estimator.model_filename is None
|
||||
|
||||
def test_loaded_model_populates_provenance(
|
||||
self,
|
||||
synthetic_video: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
import numpy as np
|
||||
|
||||
from neuropose._model import LoadedModel
|
||||
|
||||
class Recorder:
|
||||
def detect_poses(self, image, **kwargs):
|
||||
del image, kwargs
|
||||
return {
|
||||
"boxes": np.array([[0.0, 0.0, 1.0, 1.0, 0.9]]),
|
||||
"poses3d": np.array([[[0.0, 0.0, 0.0]]]),
|
||||
"poses2d": np.array([[[0.0, 0.0]]]),
|
||||
}
|
||||
|
||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
||||
del cache_dir
|
||||
return LoadedModel(
|
||||
model=Recorder(),
|
||||
sha256="e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
||||
filename="metrabs_stub.tar.gz",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||
estimator = Estimator()
|
||||
estimator.load_model()
|
||||
result = estimator.process_video(synthetic_video)
|
||||
|
||||
prov = result.predictions.provenance
|
||||
assert prov is not None
|
||||
assert prov.model_sha256.startswith("e3b0c44")
|
||||
assert prov.model_filename == "metrabs_stub.tar.gz"
|
||||
assert prov.numpy_version == np.__version__
|
||||
assert prov.python_version.count(".") == 2 # MAJOR.MINOR.MICRO
|
||||
# neuropose_version should match the package's __version__
|
||||
from neuropose import __version__ as pkg_version
|
||||
|
||||
assert prov.neuropose_version == pkg_version
|
||||
# tensorflow_version should also be real (TF is in dev deps).
|
||||
assert prov.tensorflow_version not in {"", "unknown"}
|
||||
|
||||
def test_model_sha256_and_filename_properties_after_load(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from neuropose._model import LoadedModel
|
||||
|
||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
||||
del cache_dir
|
||||
return LoadedModel(model=object(), sha256="abcd", filename="x.tar.gz")
|
||||
|
||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||
estimator = Estimator()
|
||||
assert estimator.model_sha256 is None
|
||||
assert estimator.model_filename is None
|
||||
estimator.load_model()
|
||||
assert estimator.model_sha256 == "abcd"
|
||||
assert estimator.model_filename == "x.tar.gz"
|
||||
|
||||
|
||||
class TestErrors:
|
||||
def test_missing_video(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from neuropose.io import (
|
|||
JointPairDistanceExtractor,
|
||||
JointSpeedExtractor,
|
||||
PerformanceMetrics,
|
||||
Provenance,
|
||||
Segment,
|
||||
Segmentation,
|
||||
SegmentationConfig,
|
||||
|
|
@ -279,102 +278,6 @@ class TestPerformanceMetricsModel:
|
|||
m.total_seconds = 2.0
|
||||
|
||||
|
||||
def _minimal_provenance() -> Provenance:
|
||||
return Provenance(
|
||||
model_sha256="a" * 64,
|
||||
model_filename="metrabs_fake.tar.gz",
|
||||
tensorflow_version="2.18.1",
|
||||
numpy_version="2.0.2",
|
||||
neuropose_version="0.1.0.dev0",
|
||||
python_version="3.11.14",
|
||||
)
|
||||
|
||||
|
||||
class TestProvenanceModel:
|
||||
"""Schema-level behaviour of :class:`neuropose.io.Provenance`."""
|
||||
|
||||
def test_roundtrip_through_json(self) -> None:
|
||||
p = Provenance(
|
||||
model_sha256="a" * 64,
|
||||
model_filename="metrabs_fake.tar.gz",
|
||||
tensorflow_version="2.18.1",
|
||||
tensorflow_metal_version="1.2.0",
|
||||
numpy_version="2.0.2",
|
||||
neuropose_version="0.1.0.dev0",
|
||||
python_version="3.11.14",
|
||||
seed=42,
|
||||
deterministic=True,
|
||||
analysis_config={"step": "dtw", "nan_policy": "propagate"},
|
||||
)
|
||||
rehydrated = Provenance.model_validate(p.model_dump(mode="json"))
|
||||
assert rehydrated == p
|
||||
|
||||
def test_optional_fields_default_to_none_and_false(self) -> None:
|
||||
p = _minimal_provenance()
|
||||
assert p.tensorflow_metal_version is None
|
||||
assert p.seed is None
|
||||
assert p.deterministic is False
|
||||
assert p.analysis_config is None
|
||||
|
||||
def test_is_frozen(self) -> None:
|
||||
p = _minimal_provenance()
|
||||
with pytest.raises(ValidationError):
|
||||
p.model_sha256 = "different"
|
||||
|
||||
def test_extra_fields_forbidden(self) -> None:
|
||||
# Construct via model_validate so pyright doesn't have to prove the
|
||||
# keyword doesn't exist on the class at static-type time.
|
||||
with pytest.raises(ValidationError):
|
||||
Provenance.model_validate(
|
||||
{
|
||||
"model_sha256": "x" * 64,
|
||||
"model_filename": "x.tar.gz",
|
||||
"tensorflow_version": "2.18",
|
||||
"numpy_version": "2.0",
|
||||
"neuropose_version": "0.1",
|
||||
"python_version": "3.11.14",
|
||||
"unknown_field": "bogus",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestVideoPredictionsProvenance:
|
||||
"""``provenance`` field on :class:`VideoPredictions` round-trips."""
|
||||
|
||||
def test_default_is_none(self) -> None:
|
||||
vp = VideoPredictions(
|
||||
metadata=VideoMetadata(frame_count=0, fps=30.0, width=32, height=32),
|
||||
frames={},
|
||||
)
|
||||
assert vp.provenance is None
|
||||
|
||||
def test_roundtrip_with_provenance(self, tmp_path: Path) -> None:
|
||||
prov = Provenance(
|
||||
model_sha256="f" * 64,
|
||||
model_filename="metrabs.tar.gz",
|
||||
tensorflow_version="2.18.1",
|
||||
numpy_version="2.0.2",
|
||||
neuropose_version="0.1.0.dev0",
|
||||
python_version="3.11.14",
|
||||
)
|
||||
vp = VideoPredictions(
|
||||
metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32),
|
||||
frames={
|
||||
"frame_000000": FramePrediction(
|
||||
boxes=[[0.0, 0.0, 32.0, 32.0, 0.9]],
|
||||
poses3d=[[[1.0, 2.0, 3.0]]],
|
||||
poses2d=[[[10.0, 20.0]]],
|
||||
)
|
||||
},
|
||||
provenance=prov,
|
||||
)
|
||||
path = tmp_path / "vp.json"
|
||||
save_video_predictions(path, vp)
|
||||
loaded = load_video_predictions(path)
|
||||
assert loaded == vp
|
||||
assert loaded.provenance == prov
|
||||
|
||||
|
||||
class TestBenchmarkResultPersistence:
|
||||
def test_roundtrip_to_disk(self, tmp_path: Path) -> None:
|
||||
result = BenchmarkResult(
|
||||
|
|
|
|||
|
|
@ -1,488 +0,0 @@
|
|||
"""Tests for :mod:`neuropose.migrations`.
|
||||
|
||||
Covers both the low-level migration driver (version walking, future/missing
|
||||
errors, INFO logging) and its integration through the
|
||||
:mod:`neuropose.io` load helpers (legacy payloads round-trip; future
|
||||
payloads fail with a clear message).
|
||||
|
||||
The migration driver is tested by monkey-patching ``CURRENT_VERSION`` and
|
||||
the per-schema migration registries, so the tests exercise the full
|
||||
chain-walking machinery without needing the codebase to actually be on a
|
||||
non-initial schema version.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from neuropose import migrations
|
||||
from neuropose.io import (
|
||||
BenchmarkResult,
|
||||
FramePrediction,
|
||||
JobResults,
|
||||
VideoMetadata,
|
||||
VideoPredictions,
|
||||
load_benchmark_result,
|
||||
load_job_results,
|
||||
load_video_predictions,
|
||||
save_benchmark_result,
|
||||
save_job_results,
|
||||
save_video_predictions,
|
||||
)
|
||||
from neuropose.migrations import (
|
||||
CURRENT_VERSION,
|
||||
FutureSchemaError,
|
||||
MigrationError,
|
||||
MigrationNotFoundError,
|
||||
migrate_analysis_report,
|
||||
migrate_benchmark_result,
|
||||
migrate_job_results,
|
||||
migrate_video_predictions,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _minimal_video_predictions_payload() -> dict:
|
||||
"""A valid VideoPredictions payload at the current schema version."""
|
||||
return {
|
||||
"schema_version": CURRENT_VERSION,
|
||||
"metadata": {
|
||||
"frame_count": 1,
|
||||
"fps": 30.0,
|
||||
"width": 32,
|
||||
"height": 32,
|
||||
},
|
||||
"frames": {
|
||||
"frame_000000": {
|
||||
"boxes": [[0.0, 0.0, 32.0, 32.0, 0.95]],
|
||||
"poses3d": [[[1.0, 2.0, 3.0]]],
|
||||
"poses2d": [[[10.0, 20.0]]],
|
||||
}
|
||||
},
|
||||
"segmentations": {},
|
||||
}
|
||||
|
||||
|
||||
def _minimal_video_predictions_object() -> VideoPredictions:
|
||||
"""Same payload, as a validated pydantic object."""
|
||||
return VideoPredictions(
|
||||
metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32),
|
||||
frames={
|
||||
"frame_000000": FramePrediction(
|
||||
boxes=[[0.0, 0.0, 32.0, 32.0, 0.95]],
|
||||
poses3d=[[[1.0, 2.0, 3.0]]],
|
||||
poses2d=[[[10.0, 20.0]]],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_two_version_chain(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Patch the module to look like CURRENT_VERSION=2 with a v1->v2 migration.
|
||||
|
||||
Lets the tests exercise the full migration loop even though the real
|
||||
codebase is still at CURRENT_VERSION=1.
|
||||
"""
|
||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 2)
|
||||
|
||||
def _v1_to_v2(payload: dict) -> dict:
|
||||
payload = dict(payload)
|
||||
payload["schema_version"] = 2
|
||||
payload["added_in_v2"] = "hello"
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {1: _v1_to_v2})
|
||||
monkeypatch.setattr(migrations, "_BENCHMARK_RESULT_MIGRATIONS", {1: _v1_to_v2})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_three_version_chain(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Patch the module to look like CURRENT_VERSION=3 with v1->v2 and v2->v3.
|
||||
|
||||
Exercises multi-step migration chaining.
|
||||
"""
|
||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 3)
|
||||
|
||||
def _v1_to_v2(payload: dict) -> dict:
|
||||
payload = dict(payload)
|
||||
payload["schema_version"] = 2
|
||||
payload["added_in_v2"] = "alpha"
|
||||
return payload
|
||||
|
||||
def _v2_to_v3(payload: dict) -> dict:
|
||||
payload = dict(payload)
|
||||
payload["schema_version"] = 3
|
||||
payload["added_in_v3"] = "beta"
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(
|
||||
migrations,
|
||||
"_VIDEO_PREDICTIONS_MIGRATIONS",
|
||||
{1: _v1_to_v2, 2: _v2_to_v3},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# migrate_video_predictions — driver behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigrateVideoPredictions:
|
||||
def test_current_version_payload_is_noop(self) -> None:
|
||||
payload = {"schema_version": CURRENT_VERSION, "hello": "world"}
|
||||
result = migrate_video_predictions(payload)
|
||||
assert result == payload
|
||||
|
||||
def test_missing_version_key_treated_as_v1(self) -> None:
|
||||
"""A payload with no schema_version is treated as version 1.
|
||||
|
||||
With CURRENT_VERSION == 2, the legacy payload is run through
|
||||
the registered v1 → v2 migration (which stamps ``provenance =
|
||||
None``) on the way to the current version.
|
||||
"""
|
||||
payload = {"hello": "world"}
|
||||
result = migrate_video_predictions(payload)
|
||||
assert result["hello"] == "world"
|
||||
assert result["schema_version"] == CURRENT_VERSION
|
||||
assert result["provenance"] is None
|
||||
|
||||
def test_future_version_raises(self) -> None:
|
||||
payload = {"schema_version": CURRENT_VERSION + 99}
|
||||
with pytest.raises(FutureSchemaError, match="newer than"):
|
||||
migrate_video_predictions(payload)
|
||||
|
||||
def test_non_integer_version_raises(self) -> None:
|
||||
payload = {"schema_version": "1.0"}
|
||||
with pytest.raises(MigrationError, match="invalid schema_version"):
|
||||
migrate_video_predictions(payload)
|
||||
|
||||
def test_zero_version_raises(self) -> None:
|
||||
payload = {"schema_version": 0}
|
||||
with pytest.raises(MigrationError, match="invalid schema_version"):
|
||||
migrate_video_predictions(payload)
|
||||
|
||||
def test_single_step_migration(self, fake_two_version_chain: None) -> None:
|
||||
del fake_two_version_chain
|
||||
payload = {"schema_version": 1, "original_field": "keep_me"}
|
||||
result = migrate_video_predictions(payload)
|
||||
assert result == {
|
||||
"schema_version": 2,
|
||||
"original_field": "keep_me",
|
||||
"added_in_v2": "hello",
|
||||
}
|
||||
|
||||
def test_missing_version_under_patched_chain_migrates_from_v1(
|
||||
self, fake_two_version_chain: None
|
||||
) -> None:
|
||||
del fake_two_version_chain
|
||||
# Legacy file with no version stamp: should be treated as v1 and
|
||||
# upgraded to v2.
|
||||
payload = {"legacy": True}
|
||||
result = migrate_video_predictions(payload)
|
||||
assert result["schema_version"] == 2
|
||||
assert result["added_in_v2"] == "hello"
|
||||
assert result["legacy"] is True
|
||||
|
||||
def test_multi_step_migration_chains(self, fake_three_version_chain: None) -> None:
|
||||
del fake_three_version_chain
|
||||
payload = {"schema_version": 1, "original": "yes"}
|
||||
result = migrate_video_predictions(payload)
|
||||
assert result == {
|
||||
"schema_version": 3,
|
||||
"original": "yes",
|
||||
"added_in_v2": "alpha",
|
||||
"added_in_v3": "beta",
|
||||
}
|
||||
|
||||
def test_missing_intermediate_migration_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""If CURRENT advances past a version with no registered migration, fail loud."""
|
||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 3)
|
||||
# Only v1 -> v2 registered; v2 -> v3 is the missing link.
|
||||
monkeypatch.setattr(
|
||||
migrations,
|
||||
"_VIDEO_PREDICTIONS_MIGRATIONS",
|
||||
{1: lambda p: {**p, "schema_version": 2}},
|
||||
)
|
||||
with pytest.raises(MigrationNotFoundError, match="from schema_version 2"):
|
||||
migrate_video_predictions({"schema_version": 1})
|
||||
|
||||
def test_logs_at_info_on_migration(
|
||||
self,
|
||||
fake_two_version_chain: None,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
del fake_two_version_chain
|
||||
caplog.set_level(logging.INFO, logger="neuropose.migrations")
|
||||
migrate_video_predictions({"schema_version": 1})
|
||||
assert any("Migrating VideoPredictions" in record.message for record in caplog.records)
|
||||
|
||||
def test_starting_from_current_logs_nothing(
|
||||
self,
|
||||
fake_two_version_chain: None,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
del fake_two_version_chain
|
||||
caplog.set_level(logging.INFO, logger="neuropose.migrations")
|
||||
migrate_video_predictions({"schema_version": 2})
|
||||
assert not any("Migrating" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# migrate_benchmark_result — same driver, sibling registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigrateBenchmarkResult:
|
||||
def test_uses_benchmark_registry_not_video_registry(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Each schema has its own migration registry; they must not cross-pollinate."""
|
||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 2)
|
||||
# Register video migration but NOT benchmark migration.
|
||||
monkeypatch.setattr(
|
||||
migrations,
|
||||
"_VIDEO_PREDICTIONS_MIGRATIONS",
|
||||
{1: lambda p: {**p, "schema_version": 2, "from_video_registry": True}},
|
||||
)
|
||||
monkeypatch.setattr(migrations, "_BENCHMARK_RESULT_MIGRATIONS", {})
|
||||
# Video migration works:
|
||||
assert migrate_video_predictions({"schema_version": 1})["from_video_registry"] is True
|
||||
# Benchmark migration should fail — no entry in its registry.
|
||||
with pytest.raises(MigrationNotFoundError):
|
||||
migrate_benchmark_result({"schema_version": 1})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# migrate_job_results — per-entry dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigrateJobResults:
|
||||
def test_empty_dict_is_noop(self) -> None:
|
||||
assert migrate_job_results({}) == {}
|
||||
|
||||
def test_each_video_is_migrated(self, fake_two_version_chain: None) -> None:
|
||||
del fake_two_version_chain
|
||||
payload = {
|
||||
"video_a.mp4": {"schema_version": 1, "content_a": True},
|
||||
"video_b.mp4": {"schema_version": 1, "content_b": True},
|
||||
}
|
||||
result = migrate_job_results(payload)
|
||||
assert result["video_a.mp4"]["schema_version"] == 2
|
||||
assert result["video_a.mp4"]["content_a"] is True
|
||||
assert result["video_a.mp4"]["added_in_v2"] == "hello"
|
||||
assert result["video_b.mp4"]["schema_version"] == 2
|
||||
assert result["video_b.mp4"]["content_b"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# migrate_analysis_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigrateAnalysisReport:
|
||||
def test_current_version_is_noop(self) -> None:
|
||||
"""A payload already at CURRENT_VERSION passes through unchanged."""
|
||||
payload = {"schema_version": CURRENT_VERSION, "foo": "bar"}
|
||||
assert migrate_analysis_report(payload) == payload
|
||||
|
||||
def test_missing_version_defaults_to_v1_and_fails(self) -> None:
|
||||
"""AnalysisReport first shipped at CURRENT_VERSION, so a payload
|
||||
without schema_version (defaulting to 1) would require a
|
||||
non-existent v1→v2 migration and fail with a clear error."""
|
||||
with pytest.raises(MigrationNotFoundError, match="AnalysisReport"):
|
||||
migrate_analysis_report({})
|
||||
|
||||
def test_future_version_rejected(self) -> None:
|
||||
with pytest.raises(FutureSchemaError, match="AnalysisReport"):
|
||||
migrate_analysis_report({"schema_version": CURRENT_VERSION + 5})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_video_predictions_migration / register_benchmark_result_migration
|
||||
# / register_analysis_report_migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistration:
|
||||
def test_duplicate_registration_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {})
|
||||
|
||||
@migrations.register_video_predictions_migration(from_version=1)
|
||||
def _first(p: dict) -> dict:
|
||||
return p
|
||||
|
||||
with pytest.raises(RuntimeError, match="already registered"):
|
||||
|
||||
@migrations.register_video_predictions_migration(from_version=1)
|
||||
def _second(p: dict) -> dict:
|
||||
return p
|
||||
|
||||
def test_decorator_returns_callable_unchanged(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""The decorator must not wrap or rename the function."""
|
||||
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {})
|
||||
|
||||
@migrations.register_video_predictions_migration(from_version=1)
|
||||
def _fn(p: dict) -> dict:
|
||||
return p
|
||||
|
||||
assert _fn.__name__ == "_fn"
|
||||
assert _fn({"x": 1}) == {"x": 1}
|
||||
|
||||
def test_analysis_report_duplicate_registration_raises(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(migrations, "_ANALYSIS_REPORT_MIGRATIONS", {})
|
||||
|
||||
@migrations.register_analysis_report_migration(from_version=2)
|
||||
def _first(p: dict) -> dict:
|
||||
return p
|
||||
|
||||
with pytest.raises(RuntimeError, match="already registered"):
|
||||
|
||||
@migrations.register_analysis_report_migration(from_version=2)
|
||||
def _second(p: dict) -> dict:
|
||||
return p
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: load_* functions run migrations before validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadIntegration:
|
||||
def test_load_video_predictions_accepts_legacy_payload(self, tmp_path: Path) -> None:
|
||||
"""A VideoPredictions JSON written before schema_version existed loads cleanly."""
|
||||
legacy = _minimal_video_predictions_payload()
|
||||
del legacy["schema_version"] # Pretend this file predates versioning.
|
||||
path = tmp_path / "legacy.json"
|
||||
path.write_text(json.dumps(legacy))
|
||||
|
||||
loaded = load_video_predictions(path)
|
||||
assert loaded.schema_version == CURRENT_VERSION
|
||||
|
||||
def test_load_video_predictions_rejects_future_version(self, tmp_path: Path) -> None:
|
||||
payload = _minimal_video_predictions_payload()
|
||||
payload["schema_version"] = CURRENT_VERSION + 42
|
||||
path = tmp_path / "future.json"
|
||||
path.write_text(json.dumps(payload))
|
||||
|
||||
with pytest.raises(FutureSchemaError):
|
||||
load_video_predictions(path)
|
||||
|
||||
def test_save_then_load_roundtrips(self, tmp_path: Path) -> None:
|
||||
obj = _minimal_video_predictions_object()
|
||||
path = tmp_path / "out.json"
|
||||
save_video_predictions(path, obj)
|
||||
loaded = load_video_predictions(path)
|
||||
assert loaded == obj
|
||||
assert loaded.schema_version == CURRENT_VERSION
|
||||
|
||||
def test_load_job_results_migrates_each_video(self, tmp_path: Path) -> None:
|
||||
video_a = _minimal_video_predictions_payload()
|
||||
video_b = _minimal_video_predictions_payload()
|
||||
# Strip schema_version from both to simulate legacy file.
|
||||
del video_a["schema_version"]
|
||||
del video_b["schema_version"]
|
||||
payload = {"a.mp4": video_a, "b.mp4": video_b}
|
||||
path = tmp_path / "job.json"
|
||||
path.write_text(json.dumps(payload))
|
||||
|
||||
loaded = load_job_results(path)
|
||||
assert len(loaded) == 2
|
||||
for video in ("a.mp4", "b.mp4"):
|
||||
assert loaded[video].schema_version == CURRENT_VERSION
|
||||
|
||||
def test_save_then_load_job_results_roundtrips(self, tmp_path: Path) -> None:
|
||||
obj = JobResults(root={"video_a.mp4": _minimal_video_predictions_object()})
|
||||
path = tmp_path / "job.json"
|
||||
save_job_results(path, obj)
|
||||
loaded = load_job_results(path)
|
||||
assert loaded == obj
|
||||
|
||||
def test_load_benchmark_result_roundtrips(self, tmp_path: Path) -> None:
|
||||
"""Save → load round-trip for a realistic benchmark result."""
|
||||
from neuropose.io import BenchmarkAggregate, PerformanceMetrics
|
||||
|
||||
metrics = PerformanceMetrics(
|
||||
total_seconds=1.0,
|
||||
per_frame_latencies_ms=[10.0, 11.0],
|
||||
peak_rss_mb=100.0,
|
||||
active_device="/CPU:0",
|
||||
tensorflow_version="2.18.0",
|
||||
)
|
||||
aggregate = BenchmarkAggregate(
|
||||
repeats_measured=1,
|
||||
warmup_frames_per_pass=0,
|
||||
mean_frame_latency_ms=10.5,
|
||||
p50_frame_latency_ms=10.5,
|
||||
p95_frame_latency_ms=11.0,
|
||||
p99_frame_latency_ms=11.0,
|
||||
stddev_frame_latency_ms=0.5,
|
||||
mean_throughput_fps=95.0,
|
||||
peak_rss_mb_max=100.0,
|
||||
active_device="/CPU:0",
|
||||
tensorflow_version="2.18.0",
|
||||
)
|
||||
result = BenchmarkResult(
|
||||
video_name="test.mp4",
|
||||
repeats=2,
|
||||
warmup_frames=0,
|
||||
warmup_pass=metrics,
|
||||
measured_passes=[metrics],
|
||||
aggregate=aggregate,
|
||||
)
|
||||
path = tmp_path / "bench.json"
|
||||
save_benchmark_result(path, result)
|
||||
loaded = load_benchmark_result(path)
|
||||
assert loaded == result
|
||||
assert loaded.schema_version == CURRENT_VERSION
|
||||
|
||||
def test_load_benchmark_result_rejects_future_version(self, tmp_path: Path) -> None:
|
||||
"""Future-versioned benchmark file should raise with a clear message."""
|
||||
from neuropose.io import BenchmarkAggregate, PerformanceMetrics
|
||||
|
||||
metrics = PerformanceMetrics(
|
||||
total_seconds=1.0,
|
||||
per_frame_latencies_ms=[10.0],
|
||||
peak_rss_mb=100.0,
|
||||
active_device="/CPU:0",
|
||||
tensorflow_version="2.18.0",
|
||||
)
|
||||
aggregate = BenchmarkAggregate(
|
||||
repeats_measured=1,
|
||||
warmup_frames_per_pass=0,
|
||||
mean_frame_latency_ms=10.0,
|
||||
p50_frame_latency_ms=10.0,
|
||||
p95_frame_latency_ms=10.0,
|
||||
p99_frame_latency_ms=10.0,
|
||||
stddev_frame_latency_ms=0.0,
|
||||
mean_throughput_fps=100.0,
|
||||
peak_rss_mb_max=100.0,
|
||||
active_device="/CPU:0",
|
||||
tensorflow_version="2.18.0",
|
||||
)
|
||||
result = BenchmarkResult(
|
||||
video_name="x.mp4",
|
||||
repeats=1,
|
||||
warmup_frames=0,
|
||||
warmup_pass=metrics,
|
||||
measured_passes=[metrics],
|
||||
aggregate=aggregate,
|
||||
)
|
||||
# Serialize then hand-edit to inject a future version.
|
||||
payload = result.model_dump(mode="json")
|
||||
payload["schema_version"] = CURRENT_VERSION + 1
|
||||
path = tmp_path / "bench_future.json"
|
||||
path.write_text(json.dumps(payload))
|
||||
|
||||
with pytest.raises(FutureSchemaError):
|
||||
load_benchmark_result(path)
|
||||
|
|
@ -1,451 +0,0 @@
|
|||
"""Tests for :mod:`neuropose.reset` and the ``neuropose reset`` CLI command.
|
||||
|
||||
Coverage:
|
||||
|
||||
- :func:`find_neuropose_processes` filters by cmdline marker, classifies
|
||||
daemon vs monitor, and excludes the calling process.
|
||||
- :func:`terminate_processes` sends SIGINT, escalates to SIGKILL when
|
||||
asked, and reports survivors.
|
||||
- :func:`wipe_state` removes contents of in/, out/, failed/, the lock
|
||||
file, and ``.ingest_*`` staging dirs; honors ``keep_failed`` and
|
||||
``dry_run``.
|
||||
- :func:`reset_pipeline` skips the wipe phase when termination leaves
|
||||
survivors.
|
||||
- The ``neuropose reset`` CLI command renders previews, honors
|
||||
``--dry-run`` and ``--yes``, and exits non-zero on survivors.
|
||||
|
||||
The process-killing tests use monkeypatched ``psutil.process_iter``
|
||||
and ``os.kill`` so the suite never touches the real process table or
|
||||
sends real signals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from neuropose.cli import EXIT_OK, EXIT_USAGE, app
|
||||
from neuropose.config import Settings
|
||||
from neuropose.interfacer import LOCK_FILENAME
|
||||
from neuropose.reset import (
|
||||
DEFAULT_GRACE_SECONDS,
|
||||
RunningProcess,
|
||||
TerminationReport,
|
||||
WipeReport,
|
||||
find_neuropose_processes,
|
||||
reset_pipeline,
|
||||
terminate_processes,
|
||||
wipe_state,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeProc:
|
||||
"""Minimal stand-in for ``psutil.Process`` for the discovery tests."""
|
||||
|
||||
def __init__(self, pid: int, cmdline: list[str]) -> None:
|
||||
self.info = {"pid": pid, "cmdline": cmdline}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings(tmp_path: Path) -> Settings:
|
||||
"""A Settings pointing at an isolated temp data dir, with subdirs created."""
|
||||
s = Settings(data_dir=tmp_path / "jobs", model_cache_dir=tmp_path / "models")
|
||||
s.ensure_dirs()
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_neuropose_processes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindNeuroposeProcesses:
|
||||
def test_classifies_watch_as_daemon(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
procs = [_FakeProc(1234, ["python", "-m", "neuropose", "watch"])]
|
||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
||||
found = find_neuropose_processes()
|
||||
assert len(found) == 1
|
||||
assert found[0].pid == 1234
|
||||
assert found[0].role == "daemon"
|
||||
|
||||
def test_classifies_serve_as_monitor(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
procs = [_FakeProc(5678, ["uv", "run", "neuropose", "serve", "--port", "8765"])]
|
||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
||||
found = find_neuropose_processes()
|
||||
assert len(found) == 1
|
||||
assert found[0].role == "monitor"
|
||||
|
||||
def test_ignores_unrelated_processes(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
procs = [
|
||||
_FakeProc(1, ["bash"]),
|
||||
_FakeProc(2, ["python", "-m", "pip", "install", "neuropose"]),
|
||||
_FakeProc(3, ["neuropose", "--help"]),
|
||||
]
|
||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
||||
assert find_neuropose_processes() == []
|
||||
|
||||
def test_excludes_self(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
import os
|
||||
|
||||
self_pid = os.getpid()
|
||||
procs = [
|
||||
_FakeProc(self_pid, ["python", "-m", "neuropose", "watch"]),
|
||||
_FakeProc(9999, ["python", "-m", "neuropose", "watch"]),
|
||||
]
|
||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
||||
found = find_neuropose_processes()
|
||||
assert [rp.pid for rp in found] == [9999]
|
||||
|
||||
def test_includes_self_when_disabled(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
import os
|
||||
|
||||
self_pid = os.getpid()
|
||||
procs = [_FakeProc(self_pid, ["python", "-m", "neuropose", "watch"])]
|
||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
||||
found = find_neuropose_processes(exclude_self=False)
|
||||
assert [rp.pid for rp in found] == [self_pid]
|
||||
|
||||
def test_handles_dead_processes(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A NoSuchProcess raised mid-iteration must not crash the scan."""
|
||||
import psutil
|
||||
|
||||
class _RaisingProc:
|
||||
@property
|
||||
def info(self) -> dict[str, Any]:
|
||||
raise psutil.NoSuchProcess(pid=1)
|
||||
|
||||
procs = [
|
||||
_RaisingProc(),
|
||||
_FakeProc(2, ["python", "-m", "neuropose", "watch"]),
|
||||
]
|
||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
||||
found = find_neuropose_processes()
|
||||
assert [rp.pid for rp in found] == [2]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# terminate_processes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTerminateProcesses:
|
||||
def test_empty_list_is_noop(self) -> None:
|
||||
report = terminate_processes([])
|
||||
assert report.stopped == []
|
||||
assert report.survivors == []
|
||||
assert report.force_killed == []
|
||||
|
||||
def test_sigint_to_each_process(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
sent: list[tuple[int, int]] = []
|
||||
monkeypatch.setattr("os.kill", lambda pid, sig: sent.append((pid, sig)))
|
||||
# Mark all processes immediately dead so the wait loop exits fast.
|
||||
monkeypatch.setattr("neuropose.reset._is_alive", lambda pid: False)
|
||||
|
||||
rps = [
|
||||
RunningProcess(pid=10, role="daemon", cmdline="x"),
|
||||
RunningProcess(pid=20, role="monitor", cmdline="y"),
|
||||
]
|
||||
report = terminate_processes(rps, grace_seconds=0.0)
|
||||
assert sent == [(10, signal.SIGINT), (20, signal.SIGINT)]
|
||||
assert {p.pid for p in report.stopped} == {10, 20}
|
||||
assert report.survivors == []
|
||||
|
||||
def test_survivors_reported_when_force_kill_off(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("os.kill", lambda pid, sig: None)
|
||||
# Process 10 always alive; process 20 dies after SIGINT.
|
||||
alive = {10}
|
||||
monkeypatch.setattr("neuropose.reset._is_alive", lambda pid: pid in alive)
|
||||
|
||||
rps = [
|
||||
RunningProcess(pid=10, role="daemon", cmdline="x"),
|
||||
RunningProcess(pid=20, role="monitor", cmdline="y"),
|
||||
]
|
||||
# Drop pid 20 so it appears dead immediately.
|
||||
alive.discard(20)
|
||||
report = terminate_processes(rps, grace_seconds=0.0, force_kill=False)
|
||||
assert {p.pid for p in report.stopped} == {20}
|
||||
assert {p.pid for p in report.survivors} == {10}
|
||||
assert report.force_killed == []
|
||||
|
||||
def test_force_kill_escalates_to_sigkill(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
sent: list[tuple[int, int]] = []
|
||||
monkeypatch.setattr("os.kill", lambda pid, sig: sent.append((pid, sig)))
|
||||
|
||||
# Start alive; SIGKILL "kills" by toggling the flag from inside _is_alive.
|
||||
alive = {10}
|
||||
|
||||
def fake_is_alive(pid: int) -> bool:
|
||||
if (pid, signal.SIGKILL) in sent:
|
||||
return False
|
||||
return pid in alive
|
||||
|
||||
monkeypatch.setattr("neuropose.reset._is_alive", fake_is_alive)
|
||||
|
||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
||||
report = terminate_processes([rp], grace_seconds=0.0, force_kill=True)
|
||||
assert (10, signal.SIGINT) in sent
|
||||
assert (10, signal.SIGKILL) in sent
|
||||
assert {p.pid for p in report.force_killed} == {10}
|
||||
assert report.survivors == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wipe_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWipeState:
|
||||
def test_no_op_on_empty_dirs(self, settings: Settings) -> None:
|
||||
report = wipe_state(settings)
|
||||
assert report.removed_paths == []
|
||||
assert report.bytes_freed == 0
|
||||
|
||||
def test_removes_in_out_failed_contents(self, settings: Settings) -> None:
|
||||
(settings.input_dir / "job_a").mkdir()
|
||||
(settings.input_dir / "job_a" / "video.mp4").write_bytes(b"x" * 100)
|
||||
(settings.output_dir / "status.json").write_text("{}")
|
||||
(settings.failed_dir / "job_b").mkdir()
|
||||
|
||||
report = wipe_state(settings)
|
||||
names = {p.name for p in report.removed_paths}
|
||||
assert names == {"job_a", "status.json", "job_b"}
|
||||
# Containers themselves preserved.
|
||||
assert settings.input_dir.exists()
|
||||
assert settings.output_dir.exists()
|
||||
assert settings.failed_dir.exists()
|
||||
|
||||
def test_keep_failed_preserves_failed_contents(self, settings: Settings) -> None:
|
||||
(settings.input_dir / "job_a").mkdir()
|
||||
(settings.failed_dir / "job_b").mkdir()
|
||||
(settings.failed_dir / "job_b" / "evidence.log").write_text("crash")
|
||||
|
||||
report = wipe_state(settings, keep_failed=True)
|
||||
names = {p.name for p in report.removed_paths}
|
||||
assert "job_a" in names
|
||||
assert "job_b" not in names
|
||||
assert (settings.failed_dir / "job_b" / "evidence.log").exists()
|
||||
|
||||
def test_removes_lock_file(self, settings: Settings) -> None:
|
||||
(settings.data_dir / LOCK_FILENAME).write_text("12345\n")
|
||||
report = wipe_state(settings)
|
||||
assert (settings.data_dir / LOCK_FILENAME) in report.removed_paths
|
||||
assert not (settings.data_dir / LOCK_FILENAME).exists()
|
||||
|
||||
def test_removes_ingest_staging_dirs(self, settings: Settings) -> None:
|
||||
staging_a = settings.data_dir / ".ingest_abc123"
|
||||
staging_b = settings.data_dir / ".ingest_def456"
|
||||
staging_a.mkdir()
|
||||
staging_b.mkdir()
|
||||
(staging_a / "leftover.mp4").write_bytes(b"y" * 50)
|
||||
|
||||
report = wipe_state(settings)
|
||||
assert staging_a in report.removed_paths
|
||||
assert staging_b in report.removed_paths
|
||||
assert not staging_a.exists()
|
||||
assert not staging_b.exists()
|
||||
|
||||
def test_dry_run_reports_without_removing(self, settings: Settings) -> None:
|
||||
(settings.input_dir / "job_a").mkdir()
|
||||
(settings.input_dir / "job_a" / "video.mp4").write_bytes(b"z" * 200)
|
||||
|
||||
report = wipe_state(settings, dry_run=True)
|
||||
assert len(report.removed_paths) == 1
|
||||
assert report.bytes_freed == 200
|
||||
# Nothing actually deleted.
|
||||
assert (settings.input_dir / "job_a" / "video.mp4").exists()
|
||||
|
||||
def test_bytes_freed_recurses_into_subdirs(self, settings: Settings) -> None:
|
||||
job = settings.input_dir / "job_a"
|
||||
job.mkdir()
|
||||
(job / "a.mp4").write_bytes(b"a" * 100)
|
||||
(job / "nested").mkdir()
|
||||
(job / "nested" / "b.mp4").write_bytes(b"b" * 250)
|
||||
|
||||
report = wipe_state(settings, dry_run=True)
|
||||
assert report.bytes_freed == 350
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetPipeline:
|
||||
def test_dry_run_skips_termination(
|
||||
self,
|
||||
settings: Settings,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
||||
|
||||
# Sentinel to detect termination calls.
|
||||
def _should_not_be_called(*args: Any, **kwargs: Any) -> TerminationReport:
|
||||
raise AssertionError("dry_run must not invoke terminate_processes")
|
||||
|
||||
monkeypatch.setattr("neuropose.reset.terminate_processes", _should_not_be_called)
|
||||
|
||||
report = reset_pipeline(settings, dry_run=True)
|
||||
assert report.dry_run is True
|
||||
assert report.discovered == [rp]
|
||||
assert report.termination.stopped == []
|
||||
|
||||
def test_skips_wipe_when_survivors_remain(
|
||||
self,
|
||||
settings: Settings,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
||||
monkeypatch.setattr(
|
||||
"neuropose.reset.terminate_processes",
|
||||
lambda procs, **_: TerminationReport(survivors=list(procs)),
|
||||
)
|
||||
|
||||
# Seed something to wipe so we can confirm it's untouched.
|
||||
(settings.input_dir / "job_a").mkdir()
|
||||
|
||||
report = reset_pipeline(settings, dry_run=False)
|
||||
assert report.wipe_skipped_due_to_survivors is True
|
||||
assert report.wipe.removed_paths == []
|
||||
assert (settings.input_dir / "job_a").exists()
|
||||
|
||||
def test_wipes_when_all_processes_stopped(
|
||||
self,
|
||||
settings: Settings,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
||||
monkeypatch.setattr(
|
||||
"neuropose.reset.terminate_processes",
|
||||
lambda procs, **_: TerminationReport(stopped=list(procs)),
|
||||
)
|
||||
|
||||
(settings.input_dir / "job_a").mkdir()
|
||||
|
||||
report = reset_pipeline(settings, dry_run=False)
|
||||
assert report.wipe_skipped_due_to_survivors is False
|
||||
assert any(p.name == "job_a" for p in report.wipe.removed_paths)
|
||||
assert not (settings.input_dir / "job_a").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI: neuropose reset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner() -> CliRunner:
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env_data_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
|
||||
"""Point NEUROPOSE_DATA_DIR at an isolated temp dir for CLI tests."""
|
||||
data_dir = tmp_path / "jobs"
|
||||
data_dir.mkdir()
|
||||
(data_dir / "in").mkdir()
|
||||
(data_dir / "out").mkdir()
|
||||
(data_dir / "failed").mkdir()
|
||||
monkeypatch.setenv("NEUROPOSE_DATA_DIR", str(data_dir))
|
||||
return data_dir
|
||||
|
||||
|
||||
class TestResetCli:
|
||||
def test_reset_dry_run_does_not_modify(
|
||||
self,
|
||||
runner: CliRunner,
|
||||
env_data_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
(env_data_dir / "in" / "job_a").mkdir()
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
||||
|
||||
result = runner.invoke(app, ["reset", "--dry-run"])
|
||||
assert result.exit_code == EXIT_OK, result.output
|
||||
assert "would remove" in result.output
|
||||
assert "(dry-run; no changes made)" in result.output
|
||||
assert (env_data_dir / "in" / "job_a").exists()
|
||||
|
||||
def test_reset_yes_skips_confirmation_and_wipes(
|
||||
self,
|
||||
runner: CliRunner,
|
||||
env_data_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
(env_data_dir / "in" / "job_a").mkdir()
|
||||
(env_data_dir / "in" / "job_a" / "video.mp4").write_bytes(b"x" * 100)
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
||||
|
||||
result = runner.invoke(app, ["reset", "--yes"])
|
||||
assert result.exit_code == EXIT_OK, result.output
|
||||
assert "removed" in result.output
|
||||
assert not (env_data_dir / "in" / "job_a").exists()
|
||||
|
||||
def test_reset_aborts_on_no_confirmation(
|
||||
self,
|
||||
runner: CliRunner,
|
||||
env_data_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
(env_data_dir / "in" / "job_a").mkdir()
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
||||
|
||||
# typer.confirm reads from stdin; "n\n" declines.
|
||||
result = runner.invoke(app, ["reset"], input="n\n")
|
||||
assert result.exit_code == EXIT_USAGE, result.output
|
||||
assert "aborted" in result.output
|
||||
assert (env_data_dir / "in" / "job_a").exists()
|
||||
|
||||
def test_reset_clean_state_is_noop(
|
||||
self,
|
||||
runner: CliRunner,
|
||||
env_data_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
del env_data_dir
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
||||
|
||||
result = runner.invoke(app, ["reset"])
|
||||
assert result.exit_code == EXIT_OK, result.output
|
||||
assert "nothing to do" in result.output
|
||||
|
||||
def test_reset_reports_survivors_with_nonzero_exit(
|
||||
self,
|
||||
runner: CliRunner,
|
||||
env_data_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
del env_data_dir
|
||||
rp = RunningProcess(pid=42, role="daemon", cmdline="neuropose watch")
|
||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
||||
monkeypatch.setattr(
|
||||
"neuropose.reset.terminate_processes",
|
||||
lambda procs, **_: TerminationReport(survivors=list(procs)),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["reset", "--yes"])
|
||||
assert result.exit_code == EXIT_USAGE, result.output
|
||||
assert "did not exit" in result.output
|
||||
assert "pid 42" in result.output
|
||||
assert "--force-kill" in result.output
|
||||
|
||||
|
||||
def test_default_grace_seconds_constant_is_reasonable() -> None:
|
||||
"""Lock the default grace period so a refactor cannot silently lower it."""
|
||||
assert 5.0 <= DEFAULT_GRACE_SECONDS <= 60.0
|
||||
|
||||
|
||||
def test_wipe_report_default_construction() -> None:
|
||||
r = WipeReport()
|
||||
assert r.removed_paths == []
|
||||
assert r.bytes_freed == 0
|
||||
Loading…
Reference in New Issue