Compare commits


12 Commits

Author SHA1 Message Date
Levi Neuwirth 4f3a6241fb add example analysis configs and integration suite
Three reference configs under examples/analysis/:
  - minimal.yaml: full-trial DTW on raw coordinates, no alignment or
    segmentation. Smallest working example; a starting template.
  - paper_c_headline.yaml: the representative Paper C pipeline.
    Bilateral gait-cycle segmentation, per-sequence Procrustes, and
    joint-angle DTW on knee and hip flexion triplets.
  - per_joint_debug.yaml: per-joint DTW breakdown for diagnosing
    which joint drives an unexpected distance.

tests/integration/test_analyze_examples.py exercises each example
twice: load_config must accept the YAML (catches drift between the
examples and the current schema), and run_analysis must execute the
config end-to-end against synthetic predictions (catches drift
between the examples and the executor). The Paper C example has an
extra guard verifying the knee-flexion triplets haven't been edited
to something unexpected.

Also wires docs/api/pipeline.md into the mkdocs nav so mkdocstrings
surfaces the full schema and executor API.
2026-04-22 11:49:47 -04:00
Levi Neuwirth 01b374451f wire neuropose analyze to run_analysis
Replaces the placeholder stub that returned EXIT_PENDING with a real
analyze --config <yaml> [--output <json>] command. Loads the YAML,
validates through AnalysisConfig (so typos fail with a clear
ValidationError before any predictions load), runs the pipeline,
and writes the AnalysisReport atomically.

Surfaces YAML parse errors and schema violations as EXIT_USAGE=2
with messages pointing at the offending file. Missing predictions
files during execution also surface as EXIT_USAGE rather than a
bare traceback.

Prints a one-line summary after the run: segmentation counts, the
analysis kind, and — for DTW — the per-segment distance count and
mean. --output / -o overrides the report path declared in the
config, useful when sweeping a single config over multiple input
pairs from a shell loop.
2026-04-22 11:40:24 -04:00
Levi Neuwirth dc48988450 add run_analysis pipeline executor
run_analysis(config) loads the predictions files named in the config,
applies the configured segmentation, dispatches to the selected
analysis kind (DTW, stats, or none), and emits a fully populated
AnalysisReport. The report's Provenance inherits the inference-time
envelope from the primary input with analysis_config stamped in, so
the output is self-describing even if the source YAML is later lost.

For DTW runs with segmentation, segments are paired one-to-one by
index across primary and reference, truncating to min of the two
counts. Bilateral segmentations emit per-side distances under
"left_heel_strikes[i]" / "right_heel_strikes[i]" labels. dtw_per_joint
stores its full per-unit breakdown in the per_joint_distances field
and reports the sum as the representative scalar distance.

Also ships load_config (YAML), save_report (atomic JSON write), and
load_report (rehydrate via the migration chain) so the executor can
be driven end-to-end from Python without the CLI. The CLI wiring
lands in the next commit.
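The atomic write behind save_report is conventionally a temp-file-plus-rename so a reader never observes a half-written report; a minimal stdlib sketch of the pattern (save_json_atomic and its signature are illustrative, not the project's actual API):

```python
import json
import os
import tempfile
from pathlib import Path

def save_json_atomic(path: Path, payload: dict) -> None:
    """Write JSON via temp file + rename so readers never see a torn file."""
    path = Path(path)
    # Stage in the destination directory so the final rename stays on one filesystem.
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as handle:
            json.dump(payload, handle, indent=2)
            handle.flush()
            os.fsync(handle.fileno())  # ensure bytes hit disk before the rename
        os.replace(tmp, path)  # atomic on POSIX within a single filesystem
    except BaseException:
        os.unlink(tmp)  # never leave a stray temp file behind
        raise
```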
2026-04-22 11:25:02 -04:00
Levi Neuwirth 979beb1078 add AnalysisConfig and AnalysisReport schemas
neuropose.analyzer.pipeline ships two top-level pydantic schemas:

AnalysisConfig — what a user writes in YAML. Inputs (primary plus
optional reference), preprocessing (person_index, room to grow),
optional segmentation as a discriminated union of gait_cycles,
gait_cycles_bilateral, and extractor, and a required analysis stage
as a discriminated union of dtw, stats, none.

AnalysisReport — runtime output with config, Provenance envelope,
per-input summaries, produced segmentations, and a results payload
whose shape mirrors the stage (DtwResults, StatsResults, NoResults).
schema_version defaults to CURRENT_VERSION.

Cross-field invariants enforced at parse time via model_validator:
method='dtw_relation' requires joint_i/joint_j and refuses
representation='angles'; representation='angles' requires non-empty
angle_triplets; analysis.kind='dtw' requires inputs.reference;
analysis.kind='stats' refuses a reference. Typos fail in
milliseconds instead of after a multi-minute predictions load.

neuropose.migrations gains a third registry for AnalysisReport
(_ANALYSIS_REPORT_MIGRATIONS + register_analysis_report_migration +
migrate_analysis_report), ready for future schema changes. No v1→v2
migration is registered because AnalysisReport first shipped at v2.

Execution, CLI wiring, and example configs land in follow-up commits.
2026-04-22 11:13:36 -04:00
Levi Neuwirth 87461a17d0 add joint-angle representation and nan_policy to DTW
representation: Literal["coords", "angles"] on dtw_all and
dtw_per_joint. "coords" preserves the 0.1 behaviour; "angles" runs
extract_joint_angles on caller-supplied angle_triplets before DTW,
giving translation-, rotation-, and scale-invariant distances that
are directly interpretable as clinical joint-range comparisons. Under
dtw_per_joint the "unit" becomes one angle column per triplet.
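Extracting the angle for one (parent, joint, child) triplet is an arccos of normalized bone vectors; a numpy sketch of the idea (illustrative, not the project's extract_joint_angles), including the documented NaN-on-degenerate-vectors behaviour:

```python
import numpy as np

def triplet_angle(frames: np.ndarray, triplet: tuple[int, int, int]) -> np.ndarray:
    """Per-frame angle (radians) at the middle joint of a triplet.

    frames has shape (T, J, 3): T frames, J joints, xyz coordinates.
    """
    a, b, c = triplet
    u = frames[:, a] - frames[:, b]  # bone vector: joint -> parent
    v = frames[:, c] - frames[:, b]  # bone vector: joint -> child
    norms = np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1)
    with np.errstate(invalid="ignore", divide="ignore"):
        cos = np.einsum("ij,ij->i", u, v) / norms  # NaN on zero-length bones
    return np.arccos(np.clip(cos, -1.0, 1.0))
```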

nan_policy: Literal["propagate", "interpolate", "drop"] on all three
entry points. "propagate" (default) lets NaN hit fastdtw, which raises
ValueError via numpy.asarray_chkfinite — the safest default because it
surfaces degenerate-vector problems rather than silently corrupting a
distance. "interpolate" runs 1D linear interpolation per feature
column; "drop" removes NaN frames before DTW.

dtw_relation stays a standalone convenience entry point. Paper C's
typical call becomes dtw_all(representation="angles",
align="procrustes_per_sequence"); see TECHNICAL.md Phase 0.
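A per-feature-column linear fill of the kind the "interpolate" policy describes can be sketched in a few lines of numpy (an assumption about the approach, not the project's actual implementation):

```python
import numpy as np

def interpolate_nan_columns(x: np.ndarray) -> np.ndarray:
    """Linearly fill NaN frames per feature column; x has shape (frames, features)."""
    out = x.astype(float).copy()
    frames = np.arange(out.shape[0])
    for col in range(out.shape[1]):
        nan = np.isnan(out[:, col])
        # Skip columns with nothing to fill or nothing to interpolate from.
        if nan.any() and not nan.all():
            out[nan, col] = np.interp(frames[nan], frames[~nan], out[~nan, col])
    return out
```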
2026-04-18 18:02:13 -04:00
Levi Neuwirth a1c495b2fd add gait-cycle segmentation to analyzer
segment_gait_cycles wraps segment_predictions with a joint_axis
extractor and gait-appropriate defaults (joint="rhee", axis="y",
min_cycle_seconds=0.4). The joint name resolves through
joint_index; axis is a "x"/"y"/"z" string literal converted to the
numeric index internally. An invert flag flips peaks and valleys
for recording conventions where a heel-strike appears as a local
minimum.

segment_gait_cycles_bilateral composes the single-side function
twice and returns a {"left_heel_strikes", "right_heel_strikes"}
dict shape-compatible with VideoPredictions.segmentations, so the
caller can merge it directly into a predictions object.

Pathological gaits (shuffling, walker-assisted) degrade to an
empty segments list rather than raising, inherited from
segment_by_peaks' peak-not-found behaviour.

Closes the gait-cycle segmentation item in TECHNICAL.md Phase 0.
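The single-side detection can be sketched as peak-finding on one joint's axis signal, with min_cycle_seconds converted to a frame distance via fps (a simplified stand-in for segment_gait_cycles; the function name and return shape here are illustrative):

```python
import numpy as np
from scipy.signal import find_peaks

def heel_strike_segments(signal, fps, min_cycle_seconds=0.4, invert=False):
    """Return (start, end) frame pairs between successive heel strikes."""
    y = np.asarray(signal, dtype=float)
    if invert:
        y = -y  # conventions where a heel strike is a local minimum
    # Enforce a physiologically plausible minimum stride duration.
    peaks, _ = find_peaks(y, distance=max(1, int(min_cycle_seconds * fps)))
    # Fewer than two strikes degrades to an empty list rather than raising.
    return [(int(a), int(b)) for a, b in zip(peaks[:-1], peaks[1:])]
```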
2026-04-18 17:50:21 -04:00
Levi Neuwirth f8368ca861 apply ruff format to recent commits
Pure formatter drift picked up after running ruff format across the
tree. No behavioural changes: line-length unwrapping where strings
fit on one line, one blank-line separator added to _model.py.
2026-04-18 17:49:44 -04:00
Levi Neuwirth bcce5315be add neuropose reset subcommand for pipeline-wide state wipe
Three-layer module: find_neuropose_processes() scans the process
table via psutil for running watch/serve instances; terminate_processes()
SIGINTs with a configurable grace period before optional SIGKILL
escalation; wipe_state() clears $data_dir/in/, out/, failed/,
the .neuropose.lock file, and leftover .ingest_<uuid>/ staging dirs
while preserving the container directories themselves. reset_pipeline()
composes the three and refuses to wipe while any process survives
termination.

CLI wraps it with --yes/-y, --keep-failed, --force-kill,
--grace-seconds, and --dry-run/-n. Always prints a preview before
prompting; returns EXIT_USAGE=2 when survivors block the wipe.

Unblocks the Mac benchmark iteration loop where partially-complete
runs need to be cleared between experiments.
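The "clear contents but preserve the container directories" behaviour wipe_state describes is a small pathlib pattern; a sketch under that assumption (the helper name is illustrative):

```python
import shutil
from pathlib import Path

def wipe_dir_contents(container: Path) -> int:
    """Remove everything inside `container` but keep the directory itself."""
    removed = 0
    for entry in container.iterdir():
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)  # e.g. leftover .ingest_<uuid>/ staging dirs
        else:
            entry.unlink()  # files, symlinks, lock files
        removed += 1
    return removed
```

Because the container survives, a daemon restarted afterwards does not need to recreate its in/, out/, and failed/ directories.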
2026-04-18 17:15:24 -04:00
Levi Neuwirth cc9fcb4adb add Procrustes alignment to analyzer
procrustes_align in neuropose.analyzer.features — Kabsch closed-form
rigid alignment between two pose sequences, with per_frame and
per_sequence modes and an optional scale flag for cross-subject
comparisons. Returns aligned arrays plus an AlignmentDiagnostics
dataclass reporting rotation magnitude (mean and max), translation
magnitude (mean and max), and scale factor, so downstream code can
flag suspiciously large transforms.

Wired into every DTW entry point via a new keyword-only align
parameter — "none" (the default) preserves the 0.1 raw-coordinate
behaviour, while "procrustes_per_frame" and "procrustes_per_sequence"
route inputs through procrustes_align before DTW runs. Rejects
mismatched frame counts when alignment is requested (Procrustes
requires a 1:1 correspondence).

Phase 0 of TECHNICAL.md: closes one of the three methodological
gaps Paper C's pipeline is waiting on.
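The Kabsch closed form itself is compact; a per-sequence numpy sketch (one rigid transform for the whole point set, illustrative rather than the project's procrustes_align, and without the scale option):

```python
import numpy as np

def kabsch_align(moving: np.ndarray, fixed: np.ndarray) -> np.ndarray:
    """Rigidly align an (N, 3) point set `moving` onto `fixed`."""
    mu_m, mu_f = moving.mean(axis=0), fixed.mean(axis=0)
    a, b = moving - mu_m, fixed - mu_f   # centre both clouds
    u, _, vt = np.linalg.svd(a.T @ b)    # SVD of the cross-covariance
    d = np.sign(np.linalg.det(u @ vt))   # guard against reflections
    rot = u @ np.diag([1.0, 1.0, d]) @ vt  # optimal rotation (Kabsch)
    return a @ rot + mu_f                # rotate, then translate onto fixed
```

Requiring matched frame counts, as the commit does, falls out of the math: the cross-covariance needs a 1:1 point correspondence.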
2026-04-18 17:11:53 -04:00
Levi Neuwirth fe8e417aa0 add Provenance subobject and LoadedModel
Captures the MeTRAbs SHA-256 and filename plus tensorflow /
tensorflow-metal / numpy / neuropose / python versions, and reserves
slots for seed, deterministic, and analysis_config. Populated
automatically by Estimator.process_video when the model was loaded via
load_model; propagates into JobResults and BenchmarkResult via the
existing output path. None on the injected-model test path where no
SHA is known.

_model.load_metrabs_model now returns a LoadedModel dataclass so the
estimator can bundle the TF handle with the pinned SHA without
re-hashing the tarball on every daemon startup. All test fakes and
the integration smoke tests updated to unwrap .model.

Bumps the optional schema_version field on VideoPredictions and
BenchmarkResult to default=CURRENT_VERSION so fresh writes stamp the
latest version; legacy payloads without it are migrated on load via
the chain registered in the previous commit.
2026-04-18 17:10:52 -04:00
Levi Neuwirth 9c549fd9e2 add neuropose.migrations for schema versioning
One shared CURRENT_VERSION across the three top-level serialised
payloads (VideoPredictions, JobResults, BenchmarkResult), with
per-schema registries populated via register_*_migration(from_version)
decorators. FutureSchemaError and MigrationNotFoundError surface bad
chains clearly. CURRENT_VERSION=2 with v1→v2 migrations registered
that add an optional provenance field to the payload dicts.

Tested standalone; io.py is wired through the migrator in a follow-up
commit that introduces the Provenance schema those migrations target.
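The decorator-registry pattern described above fits in a few lines of stdlib Python; a single-registry sketch (the real module keeps one registry per schema, and all names here are illustrative):

```python
CURRENT_VERSION = 2
_MIGRATIONS = {}  # from_version -> one-step migration function

class FutureSchemaError(Exception):
    pass

class MigrationNotFoundError(Exception):
    pass

def register_migration(from_version):
    """Decorator registering a one-step migration starting at from_version."""
    def decorator(fn):
        _MIGRATIONS[from_version] = fn
        return fn
    return decorator

def migrate(payload: dict) -> dict:
    """Walk the chain until the payload reaches CURRENT_VERSION."""
    version = payload.get("schema_version", 1)
    if version > CURRENT_VERSION:
        raise FutureSchemaError(f"payload is v{version}; upgrade NeuroPose")
    while version < CURRENT_VERSION:
        step = _MIGRATIONS.get(version)
        if step is None:
            raise MigrationNotFoundError(f"no migration registered from v{version}")
        payload = step(payload)
        version = payload["schema_version"]
    return payload

@register_migration(1)
def _v1_to_v2(payload: dict) -> dict:
    # v1 -> v2: stamp the new optional provenance field.
    return {**payload, "schema_version": 2, "provenance": None}
```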
2026-04-18 17:02:50 -04:00
Levi Neuwirth 2469c34676 add TECHNICAL.md engineering roadmap
Phase 0 (C-enabling pipeline work) → Phase 1 (Paper C clinical
validation) → Phase 2 (open-source release + Paper A), with Track 2
(clinical platform) as a contingent side track. Mirrors RESEARCH.md but
for engineering scope rather than methodology.
2026-04-18 17:00:36 -04:00
33 changed files with 7391 additions and 95 deletions


@ -202,6 +202,170 @@ be split into per-release sections once tagging begins.
(with a `.collisions` list of offending names). The running
daemon needs no changes — ingested job dirs are picked up on the
next poll.
- **`neuropose.migrations`** — schema-migration infrastructure for
the three top-level serialised payloads (`VideoPredictions`,
`JobResults`, `BenchmarkResult`). Every payload carries a
`schema_version` field defaulting to `CURRENT_VERSION`; on load,
the raw JSON dict is passed through `migrate_video_predictions` /
`migrate_job_results` / `migrate_benchmark_result` *before*
pydantic validation so files written by older NeuroPose versions
upgrade transparently. One shared `CURRENT_VERSION` counter;
per-schema migration registries populated via
`register_video_predictions_migration(from_version)` and
`register_benchmark_result_migration(from_version)` decorators.
`JobResults` is a `RootModel` with no envelope of its own, so its
migration runs per-entry across the root mapping. The driver raises
`FutureSchemaError` for payloads newer than the current build
(clear upgrade-NeuroPose message), `MigrationNotFoundError` for
missing chain links (indicates a `CURRENT_VERSION` bump that forgot
its migration), and logs at INFO on each version advance. Currently
at `CURRENT_VERSION = 2`, with registered v1 → v2 migrations for
`VideoPredictions` and `BenchmarkResult` that add the optional
`provenance` field.
- **`neuropose.analyzer.features.procrustes_align`** — Kabsch
rigid-alignment helper for pose sequences, plus a
`ProcrustesMode` literal (`"per_frame"` | `"per_sequence"`) and a
frozen `AlignmentDiagnostics` dataclass (`rotation_deg`,
`rotation_deg_max`, `translation`, `translation_max`, `scale`,
plus the mode that produced them). Per-sequence mode fits one
rigid transform across the whole trial; per-frame fits an
independent transform per frame. Optional `scale=True` fits a
uniform scale factor for cross-subject comparisons. Wired into
every DTW entry point in `neuropose.analyzer.dtw` via a new
keyword-only `align: AlignMode = "none"` parameter — `"none"`
preserves the 0.1 raw-coordinate behaviour, while
`"procrustes_per_frame"` and `"procrustes_per_sequence"` route
inputs through `procrustes_align` before DTW runs so the returned
distance is rotation- and translation-invariant. Paper C's
pipeline is expected to set `align="procrustes_per_sequence"`;
see `TECHNICAL.md` Phase 0.
- **`neuropose.analyzer.dtw.Representation`** and
**`neuropose.analyzer.dtw.NanPolicy`** — two new Literal types
exposing orthogonal DTW preprocessing knobs on every entry point.
`representation` (on `dtw_all` and `dtw_per_joint`) switches the
per-frame feature vector between `"coords"` (the 0.1 default) and
`"angles"`, which runs `extract_joint_angles` on the supplied
`angle_triplets` first — yielding distances that are translation-,
rotation-, and scale-invariant by construction, and directly
interpretable in clinical terms. `nan_policy` (on all three entry
points) selects `"propagate"` (surface fastdtw's ValueError on
NaN — the default), `"interpolate"` (linear fill per feature
column), or `"drop"` (remove NaN frames before DTW); the
policy is applied consistently whether NaN originated from the
angles pipeline or from corrupted upstream coordinates.
`dtw_relation` stays a standalone convenience entry point for
two-joint displacement DTW; users who prefer a unified API can
express the same computation via `dtw_all` with an appropriate
pair of angle triplets or run `dtw_relation` directly.
- **`neuropose.analyzer.pipeline`** (schemas) — declarative
analysis-pipeline configuration and output artifact, parseable from
YAML or JSON via pydantic. `AnalysisConfig` captures a full
experiment: inputs (primary + optional reference predictions
files), preprocessing (person index, with room to grow),
optional segmentation (`gait_cycles` / `gait_cycles_bilateral` /
`extractor` discriminated union), and a required analysis stage
(`dtw` / `stats` / `none` discriminated union). `AnalysisReport`
is the runtime output: carries the originating config, a
`Provenance` envelope with `analysis_config` populated, per-input
summaries, produced segmentations, and an analysis-result payload
that mirrors the stage choice (`DtwResults`, `StatsResults`, or
`NoResults`). Cross-field invariants — `method="dtw_relation"`
requires `joint_i`/`joint_j`, `representation="angles"` requires
non-empty `angle_triplets`, `analysis.kind="dtw"` requires
`inputs.reference`, `analysis.kind="stats"` refuses a reference —
are enforced at parse time via `model_validator` so typos fail in
milliseconds instead of after a multi-minute predictions load.
`AnalysisReport` carries a `schema_version` field defaulting to
`CURRENT_VERSION = 2`, with a new
`register_analysis_report_migration` decorator and
`migrate_analysis_report` driver in `neuropose.migrations` ready
for future schema changes. `run_analysis(config)` loads the named
predictions files, applies the configured segmentation, dispatches
to the selected analysis kind (DTW, stats, or none), and emits a
fully populated `AnalysisReport` whose `Provenance` inherits the
inference-time envelope from the primary input with
`analysis_config` stamped in, so the report is self-describing
even if the source YAML is lost. For DTW runs with segmentation,
segments are paired one-to-one by index across primary and
reference, truncating to `min(len_primary, len_reference)`;
bilateral segmentations emit per-side distances under
`"left_heel_strikes[i]"` / `"right_heel_strikes[i]"` labels.
`load_config(path)` parses YAML, `save_report(path, report)`
writes atomically, and `load_report(path)` rehydrates via the
migration chain. Wired to the CLI as `neuropose analyze --config
<yaml> [--output <json>]` — replaces the placeholder stub that
previously returned `EXIT_PENDING`. The CLI surfaces schema
violations and YAML parse errors as `EXIT_USAGE=2` with a clear
message pointing at the offending file, prints a one-line summary
of the run (segmentation counts, analysis kind, per-segment
distance count + mean for DTW), and supports `--output`/`-o` to
override the report path declared in the config (useful for
sweeping a single config over multiple input pairs from a shell
loop). Ships three example configs under `examples/analysis/`:
`minimal.yaml` (smallest working DTW pipeline), `paper_c_headline.yaml`
(representative Paper C config with bilateral gait-cycle
segmentation, per-sequence Procrustes, and joint-angle DTW on
knee/hip triplets), and `per_joint_debug.yaml` (per-joint DTW
breakdown for diagnosing which joint drives an unexpected
distance). An integration suite exercises each example against
synthetic predictions so schema drift between the YAMLs and the
executor fails CI, not silently at run time. Documented in
`docs/api/pipeline.md`.
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
**`segment_gait_cycles_bilateral`** — clinical convenience
wrappers over `segment_predictions` that pre-fill a `joint_axis`
extractor with gait-appropriate defaults (`joint="rhee"`,
`axis="y"`, `min_cycle_seconds=0.4`). The single-side entry point
accepts any berkeley_mhad_43 joint name and any spatial axis as a
string literal `"x" | "y" | "z"`, plus an `invert` flag for
recordings whose vertical axis runs opposite to MeTRAbs's
Y-down world-coordinate convention. The bilateral wrapper runs
the detection on both `lhee` and `rhee` and returns the two
results under `"left_heel_strikes"` / `"right_heel_strikes"`
keys — shape-compatible with `VideoPredictions.segmentations` so
the dict can be merged in directly. Degrades gracefully on
pathological gaits (shuffling, walker-assisted) by returning an
empty segments list rather than raising. Closes the gait-cycle
segmentation item in `TECHNICAL.md` Phase 0.
- **`neuropose.io.Provenance`** — reproducibility envelope for every
inference run. Populated automatically by `Estimator.process_video`
when the model was loaded via `load_model` (the production path)
and attached to the output `VideoPredictions`; propagates from
there into `JobResults` (per-video) and `BenchmarkResult` (via the
benchmark loop). Captures the MeTRAbs artifact SHA-256 and
filename, `tensorflow` / `tensorflow-metal` / `numpy` /
`neuropose` / Python versions, and reserved slots for a `seed`,
`deterministic` flag (Track 2), and `analysis_config` (Phase 0
YAML pipeline). `None` on the injected-model test path where
NeuroPose has no way to fingerprint the supplied artifact. Frozen
pydantic model with `extra="forbid"` and
`protected_namespaces=()` so the `model_*` field names do not
collide with pydantic v2's internal namespace. `_model.load_metrabs_model`
now returns a `LoadedModel` dataclass bundling the TF handle with
the pinned SHA and filename so the estimator can build the
`Provenance` without re-hashing the tarball.
- **`neuropose.reset`** — pipeline-wide reset utility for the
benchmark / iteration loop. `find_neuropose_processes()` scans the
OS process table (via `psutil`) for running `neuropose watch` and
`neuropose serve` instances and classifies each as `daemon` or
`monitor`. `terminate_processes()` SIGINTs them, polls for graceful
exit up to a configurable grace period, and optionally escalates
to SIGKILL with `force_kill=True`. `wipe_state()` removes the
contents of `$data_dir/in/`, `$data_dir/out/` (including
`status.json`), `$data_dir/failed/` (unless `keep_failed=True`),
the `.neuropose.lock` file, and any leftover `.ingest_<uuid>/`
staging dirs from interrupted ingests; container directories
themselves are preserved so the daemon does not need to recreate
them on next startup. `reset_pipeline()` composes the three with
one safety guard: if any process survives termination, the wipe
phase is skipped and the returned `ResetReport` flags
`wipe_skipped_due_to_survivors`, because removing `$data_dir`
out from under an active daemon would corrupt its in-flight
writes. Surfaced as `neuropose reset` in the CLI with
`--yes/-y`, `--keep-failed`, `--force-kill`, `--grace-seconds`,
and `--dry-run/-n` flags; the command always prints a preview
before prompting (skipped under `--yes`) and returns
`EXIT_USAGE=2` when survivors block the wipe.
- **`neuropose.benchmark`** — multi-pass inference benchmarking for
a single video. `run_benchmark()` runs `process_video` N times
(default 5), always discards the first pass as warmup (graph
@ -221,14 +385,18 @@ be split into per-release sections once tagging begins.
imports for the heavy dependencies:
- `analyzer.dtw` — three DTW entry points (`dtw_all`,
`dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
-   `DTWResult` dataclass. See `RESEARCH.md` for the ongoing
+   `DTWResult` dataclass and three orthogonal preprocessing knobs
+   (`align`, `representation`, `nan_policy`). See `RESEARCH.md`
+   for the ongoing
methodology investigation.
  - `analyzer.features` — `predictions_to_numpy`,
    `normalize_pose_sequence` (uniform and axis-wise),
-   `pad_sequences` (edge-padding), `extract_joint_angles` (NaN on
-   degenerate vectors), `extract_feature_statistics`
-   (`FeatureStatistics` frozen dataclass), and a `find_peaks` thin
-   wrapper around `scipy.signal.find_peaks`.
+   `pad_sequences` (edge-padding), `procrustes_align` (Kabsch
+   rigid alignment, per-frame or per-sequence, optional uniform
+   scaling), `extract_joint_angles` (NaN on degenerate vectors),
+   `extract_feature_statistics` (`FeatureStatistics` frozen
+   dataclass), and a `find_peaks` thin wrapper around
+   `scipy.signal.find_peaks`.
- `analyzer.segment` — repetition segmentation for trials in
which a subject performs the same movement several times. A
three-layer API: `segment_by_peaks` (pure 1D
@ -238,7 +406,11 @@ be split into per-release sections once tagging begins.
time-based parameters to frame counts via `metadata.fps`), and
`slice_predictions` (split a `VideoPredictions` into one per
detected repetition with re-keyed frame names and a rewritten
-   `frame_count`). Ships four extractor factories —
+   `frame_count`). Gait-specific convenience wrappers
+   `segment_gait_cycles` (single heel) and
+   `segment_gait_cycles_bilateral` (both heels, returning a dict
+   keyed by `"left_heel_strikes"` / `"right_heel_strikes"`) sit
+   above `segment_predictions` with clinical defaults. Ships four
+   extractor factories —
`joint_axis`, `joint_pair_distance`, `joint_speed`, and
`joint_angle` — plus a `JOINT_NAMES` constant for the
berkeley_mhad_43 skeleton with a `joint_index(name)` lookup,
@ -248,7 +420,7 @@ be split into per-release sections once tagging begins.
`slow`) loads the real model and asserts the constant still
matches, so any upstream skeleton drift fails CI.
- **`neuropose.cli`** — Typer-based command-line interface with
-   seven subcommands: `watch` (run the daemon), `process <video>`
+   eight subcommands: `watch` (run the daemon), `process <video>`
(run the estimator on a single video), `ingest <archive>` (unzip
a video archive into per-video job directories under
`$data_dir/in/` with validation-before-write and atomic
@ -259,6 +431,13 @@ be split into per-release sections once tagging begins.
KeyboardInterrupt exits with the standard shell-interruption
code and an `OSError` at bind time is translated to a clean
usage error with the bind target in the message),
`reset` (stop the daemon and monitor, then wipe pipeline state
for a clean restart — wraps `neuropose.reset` with a confirmation
prompt, `--dry-run` preview, `--keep-failed` to preserve the
forensic quarantine, `--force-kill` to escalate to SIGKILL after
the SIGINT grace period, and `--grace-seconds` to tune the wait;
refuses to wipe state while any process survives termination so
active writes cannot be corrupted),
`segment <results>` (post-hoc repetition segmentation — loads a
JobResults or a single VideoPredictions, runs
`neuropose.analyzer.segment.segment_predictions` with the chosen
@ -273,7 +452,9 @@ be split into per-release sections once tagging begins.
the resulting `poses3d` arrays, and reports throughput speedup
and max divergence in mm — the missing Apple Silicon numerical
verification answer from `RESEARCH.md`), and
-   `analyze <results>` (stub). The `segment` subcommand accepts
+   `analyze --config <yaml>` (run the declarative analysis
+   pipeline — see the dedicated entry above for scope). The
+   `segment` subcommand accepts
joint specifiers as either berkeley_mhad_43 names (`lwri`,
`rwri`, …) or integer indices, and refuses to overwrite an
existing segmentation of the same name without `--force`.

TECHNICAL.md (new file, 1191 lines)

File diff suppressed because it is too large

docs/api/migrations.md (new file, 3 lines)

@ -0,0 +1,3 @@
# `neuropose.migrations`
::: neuropose.migrations

docs/api/pipeline.md (new file, 3 lines)

@ -0,0 +1,3 @@
# `neuropose.analyzer.pipeline`
::: neuropose.analyzer.pipeline

docs/api/reset.md (new file, 3 lines)

@ -0,0 +1,3 @@
# `neuropose.reset`
::: neuropose.reset

examples/analysis/minimal.yaml (new file)

@ -0,0 +1,26 @@
# Minimal DTW config: full-trial comparison on raw 3D coordinates,
# no Procrustes alignment, no segmentation. The simplest working
# example; use this as a starting template and add stages as needed.
#
# Run:
#   neuropose analyze --config examples/analysis/minimal.yaml \
#     --output out/minimal_report.json
#
# (Substitute real paths for `inputs.primary` and `inputs.reference`
# before running.)
config_version: 1
inputs:
  primary: data/trial_a.json
  reference: data/trial_b.json
analysis:
  kind: dtw
  method: dtw_all
  align: none
  representation: coords
  nan_policy: propagate
output:
  report: out/minimal_report.json

examples/analysis/paper_c_headline.yaml (new file)

@ -0,0 +1,48 @@
# Paper C headline config: cycle-segmented joint-angle DTW with
# per-sequence Procrustes alignment. This is the representative
# Paper C pipeline — bilateral gait cycles drive the segmentation
# so distances are reported per-stride per-side, and the angle-space
# representation makes the distance clinically interpretable
# (knee flexion angle, hip extension angle, etc.).
#
# Joint-triplet indices below target the berkeley_mhad_43 skeleton:
# - (27, 31, 32): left hip → left knee → left ankle (left knee flex)
# - (35, 39, 40): right hip → right knee → right ankle (right knee flex)
# - (34, 27, 31): back hip ← left hip → left knee (left hip flex)
# - (34, 35, 39): back hip ← right hip → right knee (right hip flex)
#
# See neuropose.analyzer.JOINT_NAMES for the full 43-joint table.
#
# Run:
#   neuropose analyze --config examples/analysis/paper_c_headline.yaml \
#     --output out/paper_c_report.json
config_version: 1
inputs:
  primary: data/subject_trial.json
  reference: data/mocap_reference.json
preprocessing:
  person_index: 0
segmentation:
  kind: gait_cycles_bilateral
  axis: y
  invert: false
  min_cycle_seconds: 0.4
analysis:
  kind: dtw
  method: dtw_all
  align: procrustes_per_sequence
  representation: angles
  angle_triplets:
    - [27, 31, 32]  # left knee flexion
    - [35, 39, 40]  # right knee flexion
    - [34, 27, 31]  # left hip flexion
    - [34, 35, 39]  # right hip flexion
  nan_policy: interpolate
output:
  report: out/paper_c_report.json

examples/analysis/per_joint_debug.yaml (new file)

@ -0,0 +1,36 @@
# Per-joint debug config: runs dtw_per_joint on raw coordinates so
# the resulting report carries a full (segments × joints) distance
# breakdown. Useful when one joint is suspected of driving an
# otherwise-unexpected DTW distance — the per-joint numbers make it
# obvious which joint's trajectory diverges most.
#
# Raw coordinates (representation: coords) are used rather than
# angles because joint-level debugging is most interpretable in the
# native measurement space. Procrustes alignment is on so
# translation and rotation between trials do not inflate the numbers.
#
# Run:
#   neuropose analyze --config examples/analysis/per_joint_debug.yaml \
#     --output out/per_joint_report.json
config_version: 1
inputs:
  primary: data/trial_a.json
  reference: data/trial_b.json
segmentation:
  kind: gait_cycles
  joint: rhee
  axis: y
  min_cycle_seconds: 0.4
analysis:
  kind: dtw
  method: dtw_per_joint
  align: procrustes_per_sequence
  representation: coords
  nan_policy: propagate
output:
  report: out/per_joint_report.json


@ -95,9 +95,12 @@ nav:
- neuropose.interfacer: api/interfacer.md
- neuropose.ingest: api/ingest.md
- neuropose.monitor: api/monitor.md
- neuropose.reset: api/reset.md
- neuropose.io: api/io.md
- neuropose.migrations: api/migrations.md
- neuropose.benchmark: api/benchmark.md
- neuropose.analyzer.segment: api/segment.md
- neuropose.analyzer.pipeline: api/pipeline.md
- neuropose.visualize: api/visualize.md
- Development: development.md
- Deployment: deployment.md


@ -41,11 +41,34 @@ import os
import shutil
import tarfile
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class LoadedModel:
    """Result of :func:`load_metrabs_model`.

    Bundles the loaded TensorFlow model with the provenance metadata
    that identifies which artifact it came from. Callers that only want
    the model reach for :attr:`model`; callers that build a
    :class:`~neuropose.io.Provenance` (primarily
    :class:`~neuropose.estimator.Estimator`) pull :attr:`sha256` and
    :attr:`filename` too.

    Frozen: once :func:`load_metrabs_model` has produced a
    ``LoadedModel``, nothing downstream should edit the identity of
    the artifact it describes.
    """

    model: Any
    sha256: str
    filename: str
# ---------------------------------------------------------------------------
# Model artifact: pinned URL and checksum.
# ---------------------------------------------------------------------------
@ -74,7 +97,7 @@ _REQUIRED_MODEL_ATTRS = (
# ---------------------------------------------------------------------------
-def load_metrabs_model(cache_dir: Path | None = None) -> Any:
+def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
"""Load the MeTRAbs model, downloading and caching on first use.
Parameters
@ -87,9 +110,11 @@ def load_metrabs_model(cache_dir: Path | None = None) -> Any:
Returns
-------
-    object
-        A TensorFlow SavedModel handle exposing ``detect_poses`` and
-        the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
+    LoadedModel
+        Bundle containing the TensorFlow SavedModel handle alongside
+        the pinned artifact SHA-256 and filename that identify which
+        model the handle came from. The handle exposes ``detect_poses``
+        and the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
attributes used by :class:`neuropose.estimator.Estimator`.
Raises
@ -99,6 +124,18 @@ def load_metrabs_model(cache_dir: Path | None = None) -> Any:
automatic retry), extraction fails, TensorFlow is not
installed, or the loaded model does not expose the expected
interface.
    Notes
    -----
    The returned ``sha256`` is the module-pinned :data:`_MODEL_SHA256`,
    not a re-hash of the on-disk tarball. On the cold-cache path this
    is exactly the hash we verified against before loading. On the
    warm-cache path the tarball is not re-verified (that would cost a
    2 GB I/O pass on every daemon startup), so the reported SHA is an
    attestation of "this is the pinned artifact NeuroPose loads" rather
    than a direct fingerprint of the on-disk bytes. For the threat
    model this supports (reproducibility, not tamper-evidence), that
    is the correct semantics.
"""
resolved_cache = Path(cache_dir) if cache_dir is not None else _default_cache_dir()
resolved_cache.mkdir(parents=True, exist_ok=True)
@@ -115,7 +152,11 @@ def load_metrabs_model(cache_dir: Path | None = None) -> Any:
)
shutil.rmtree(model_dir, ignore_errors=True)
else:
return _tf_load(saved_model_dir)
return LoadedModel(
model=_tf_load(saved_model_dir),
sha256=_MODEL_SHA256,
filename=_MODEL_ARCHIVE_NAME,
)
tarball = resolved_cache / _MODEL_ARCHIVE_NAME
@@ -135,7 +176,11 @@ def load_metrabs_model(cache_dir: Path | None = None) -> Any:
_extract_tarball(tarball, model_dir)
saved_model_dir = _find_saved_model(model_dir)
return _tf_load(saved_model_dir)
return LoadedModel(
model=_tf_load(saved_model_dir),
sha256=_MODEL_SHA256,
filename=_MODEL_ARCHIVE_NAME,
)
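The distinction the Notes draw (a pinned attestation on the warm path versus a hash actually verified on the cold path) is easy to sketch. The module's verification code is outside this hunk, so the `sha256_of_file` helper below is hypothetical and only mirrors what a cold-cache check would need; `_MODEL_SHA256` is the pinned constant named above.

```python
import hashlib
from pathlib import Path

def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 in 1 MiB chunks, so a 2 GB
    tarball never has to fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Cold-cache sketch: verify before extracting, e.g.
#   if sha256_of_file(tarball) != _MODEL_SHA256:
#       raise RuntimeError("checksum mismatch; re-download")
```

On the warm path this pass is skipped by design, which is exactly why the returned ``sha256`` is an attestation rather than a fingerprint of the on-disk bytes.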
# ---------------------------------------------------------------------------


@@ -28,23 +28,56 @@ here for ergonomic access.
from __future__ import annotations
from neuropose.analyzer.dtw import (
AlignMode,
DTWResult,
NanPolicy,
Representation,
dtw_all,
dtw_per_joint,
dtw_relation,
)
from neuropose.analyzer.features import (
AlignmentDiagnostics,
FeatureStatistics,
ProcrustesMode,
extract_feature_statistics,
extract_joint_angles,
find_peaks,
normalize_pose_sequence,
pad_sequences,
predictions_to_numpy,
procrustes_align,
)
from neuropose.analyzer.pipeline import (
AnalysisConfig,
AnalysisReport,
AnalysisResults,
AnalysisStage,
DtwAnalysis,
DtwResults,
ExtractorSegmentation,
FeatureSummary,
GaitCyclesBilateralSegmentation,
GaitCyclesSegmentation,
InputsConfig,
InputSummary,
NoAnalysis,
NoResults,
OutputConfig,
PreprocessingConfig,
SegmentationStage,
StatsAnalysis,
StatsResults,
analysis_config_to_dict,
load_config,
load_report,
run_analysis,
save_report,
)
from neuropose.analyzer.segment import (
JOINT_INDEX,
JOINT_NAMES,
AxisLetter,
extract_signal,
joint_angle,
joint_axis,
@@ -52,6 +85,8 @@ from neuropose.analyzer.segment import (
joint_pair_distance,
joint_speed,
segment_by_peaks,
segment_gait_cycles,
segment_gait_cycles_bilateral,
segment_predictions,
slice_predictions,
)
@@ -59,8 +94,34 @@ from neuropose.analyzer.segment import (
__all__ = [
"JOINT_INDEX",
"JOINT_NAMES",
"AlignMode",
"AlignmentDiagnostics",
"AnalysisConfig",
"AnalysisReport",
"AnalysisResults",
"AnalysisStage",
"AxisLetter",
"DTWResult",
"DtwAnalysis",
"DtwResults",
"ExtractorSegmentation",
"FeatureStatistics",
"FeatureSummary",
"GaitCyclesBilateralSegmentation",
"GaitCyclesSegmentation",
"InputSummary",
"InputsConfig",
"NanPolicy",
"NoAnalysis",
"NoResults",
"OutputConfig",
"PreprocessingConfig",
"ProcrustesMode",
"Representation",
"SegmentationStage",
"StatsAnalysis",
"StatsResults",
"analysis_config_to_dict",
"dtw_all",
"dtw_per_joint",
"dtw_relation",
@@ -73,10 +134,17 @@ __all__ = [
"joint_index",
"joint_pair_distance",
"joint_speed",
"load_config",
"load_report",
"normalize_pose_sequence",
"pad_sequences",
"predictions_to_numpy",
"procrustes_align",
"run_analysis",
"save_report",
"segment_by_peaks",
"segment_gait_cycles",
"segment_gait_cycles_bilateral",
"segment_predictions",
"slice_predictions",
]


@@ -2,10 +2,12 @@
Three entry points, ordered by increasing precision (and increasing cost):
- :func:`dtw_all`: DTW on the flattened per-frame joint vector. Fast but
coarse; collapses every joint axis into a single per-frame vector.
- :func:`dtw_per_joint`: DTW on each joint independently. Preserves
per-joint temporal alignment at the cost of one DTW call per joint.
- :func:`dtw_all`: DTW on the flattened per-frame feature vector. Fast
but coarse; collapses every joint axis (or every angle triplet) into
a single per-frame vector.
- :func:`dtw_per_joint`: DTW on each joint (or angle triplet)
independently. Preserves per-unit temporal alignment at the cost of
one DTW call per unit.
- :func:`dtw_relation`: DTW on the displacement vector between two
specific joints. This is the right tool when the research question is
about the *relative* motion of a specific pair of joints (e.g. the
@@ -16,6 +18,24 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)``
numpy arrays, the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
produces.
Three orthogonal preprocessing knobs are available on the entry points:
- **``align``** routes the inputs through
:func:`~neuropose.analyzer.features.procrustes_align` before DTW runs,
yielding translation- and rotation-invariant distances.
``align="none"`` (the default) preserves the raw-coordinate behaviour
shipped in 0.1.
- **``representation``** (on :func:`dtw_all` and :func:`dtw_per_joint`)
selects what each frame is reduced to before DTW. ``"coords"`` uses
the raw joint coordinates; ``"angles"`` replaces them with joint
angles computed at caller-supplied triplets via
:func:`~neuropose.analyzer.features.extract_joint_angles`, giving
DTW distances that are directly interpretable as clinical joint-range
comparisons.
- **``nan_policy``** decides how the DTW path handles non-finite values
in its input; typically a concern only for the angle representation,
where degenerate (zero-length) vectors produce NaN. See :data:`NanPolicy`.
Dependency note
---------------
This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of
@@ -28,11 +48,58 @@ called.
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Literal
import numpy as np
from neuropose.analyzer.features import extract_joint_angles, procrustes_align
AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
"""Alignment selector for DTW entry points.
- ``"none"``: feed raw coordinates directly to DTW.
- ``"procrustes_per_frame"``: per-frame Kabsch alignment before DTW.
- ``"procrustes_per_sequence"``: single sequence-wide Kabsch
alignment before DTW.
"""
Representation = Literal["coords", "angles"]
"""Per-frame feature representation for :func:`dtw_all` and :func:`dtw_per_joint`.
- ``"coords"``: use the raw joint coordinates (the input's last two
axes). Preserves the 0.1 behaviour.
- ``"angles"``: replace joints with joint angles at caller-supplied
triplets. Translation- and rotation-invariant by construction,
scale-invariant modulo the upstream normalization, and directly
interpretable in clinical terms ("knee flexion during swing phase").
The ``angle_triplets`` keyword becomes mandatory in this mode.
"""
NanPolicy = Literal["propagate", "interpolate", "drop"]
"""Per-feature NaN handling for the DTW input.
NaN typically appears when ``representation="angles"`` encounters a
degenerate (zero-length) vector: the angle is undefined, and
:func:`extract_joint_angles` propagates NaN rather than quietly returning
a stand-in value.
- ``"propagate"`` (default): pass NaN straight through to the DTW
engine. fastdtw validates its input via
:func:`numpy.asarray_chkfinite` and raises :class:`ValueError`
the moment a NaN appears, which is the safest default because it
makes the problem visible instead of quietly corrupting a
distance.
- ``"interpolate"``: linearly interpolate NaN frames along each
feature column using neighbouring finite values. Reasonable when a
small number of frames are corrupted and the surrounding motion is
smooth; inappropriate when long stretches are missing.
- ``"drop"``: remove any frame where *any* feature is NaN before DTW
runs. Simple, but compresses the time axis, so warping-path indices
refer to the *compacted* sequence rather than the original.
"""
@dataclass(frozen=True)
class DTWResult:
@@ -77,20 +144,44 @@ def _require_fastdtw() -> tuple[Callable, Callable]:
return fastdtw, euclidean
def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult:
"""DTW on the flattened per-frame joint vector.
def dtw_all(
a: np.ndarray,
b: np.ndarray,
*,
align: AlignMode = "none",
representation: Representation = "coords",
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
nan_policy: NanPolicy = "propagate",
) -> DTWResult:
"""DTW on the flattened per-frame feature vector.
Each frame's joints are collapsed into a single vector before DTW
is applied. This is fast (one DTW call regardless of the joint
count) but loses per-joint temporal structure, so a small
Under the default ``representation="coords"`` each frame's joints
are collapsed into a single vector before DTW is applied: fast
(one DTW call regardless of joint count) but coarse, since a small
timing mismatch on one joint can dominate the distance metric.
Switching to ``representation="angles"`` computes joint angles at
the supplied triplets first and flattens those instead.
Parameters
----------
a, b
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
sequences do not need to have the same number of frames, but
they must have the same number of joints.
they must have the same number of joints. When ``align`` is not
``"none"``, the two sequences must additionally share a frame
count (Procrustes requires a 1:1 correspondence).
align
Procrustes alignment mode applied before DTW. See
:data:`AlignMode`.
representation
Per-frame feature representation. See :data:`Representation`.
angle_triplets
Required when ``representation="angles"``. Sequence of
``(a, b, c)`` joint-index triplets passed through to
:func:`~neuropose.analyzer.features.extract_joint_angles`.
Ignored otherwise.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns
-------
@@ -101,48 +192,102 @@ def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult:
Raises
------
ValueError
If ``a`` and ``b`` do not have the same joint count.
If ``a`` and ``b`` do not have the same joint count, if
``align`` requires a matching frame count that is not present,
if ``representation="angles"`` is requested without
``angle_triplets``, or if ``nan_policy="interpolate"``
encounters an all-NaN column.
"""
_validate_same_joint_count(a, b)
a, b = _maybe_align(a, b, align=align)
feat_a = _apply_representation(a, representation, angle_triplets=angle_triplets)
feat_b = _apply_representation(b, representation, angle_triplets=angle_triplets)
feat_a = _apply_nan_policy(feat_a, nan_policy)
feat_b = _apply_nan_policy(feat_b, nan_policy)
fastdtw, euclidean = _require_fastdtw()
a_flat = a.reshape(a.shape[0], -1)
b_flat = b.reshape(b.shape[0], -1)
distance, path = fastdtw(a_flat, b_flat, dist=euclidean)
distance, path = fastdtw(feat_a, feat_b, dist=euclidean)
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
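The shape bookkeeping behind ``representation="coords"`` is just a reshape. A small sketch of what gets handed to the DTW engine (fastdtw is an optional dependency, so it is not invoked here):

```python
import numpy as np

rng = np.random.default_rng(0)
seq = rng.normal(size=(120, 17, 3))  # (frames, joints, 3) pose sequence

# "coords" collapses each frame's joints into one feature vector, so
# DTW compares two (frames, joints * 3) sequences of per-frame vectors.
flat = seq.reshape(seq.shape[0], -1)  # (120, 51)
```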
def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]:
"""DTW on each joint independently.
def dtw_per_joint(
a: np.ndarray,
b: np.ndarray,
*,
align: AlignMode = "none",
representation: Representation = "coords",
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
nan_policy: NanPolicy = "propagate",
) -> list[DTWResult]:
"""DTW on each joint (or angle triplet) independently.
Performs one DTW computation per joint, yielding a list of
:class:`DTWResult` objects in joint-index order. More precise than
:func:`dtw_all` because each joint's temporal alignment is optimised
separately, at the cost of J times more DTW calls for J joints.
Performs one DTW computation per unit, yielding a list of
:class:`DTWResult` objects in input order. More precise than
:func:`dtw_all` because each unit's temporal alignment is optimised
separately, at the cost of J times more DTW calls for J units.
Under the default ``representation="coords"`` a "unit" is one of
the input's joints (xyz treated jointly). Under
``representation="angles"`` a "unit" is one scalar angle column
computed from one ``angle_triplets`` entry.
Parameters
----------
a, b
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
sequences do not need to have the same number of frames but
must have the same number of joints.
must have the same number of joints. When ``align`` is not
``"none"``, they must additionally share a frame count.
align
Procrustes alignment mode applied before DTW. See
:data:`AlignMode`.
representation
Per-frame feature representation. See :data:`Representation`.
angle_triplets
Required when ``representation="angles"``; see
:func:`dtw_all` for details.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns
-------
list[DTWResult]
One DTW result per joint, in index order.
One DTW result per joint or per angle triplet, in input order.
Raises
------
ValueError
If ``a`` and ``b`` do not have the same joint count.
Same conditions as :func:`dtw_all`.
"""
_validate_same_joint_count(a, b)
a, b = _maybe_align(a, b, align=align)
if representation == "coords":
feat_a = a
feat_b = b
# (frames, joints, 3) — one DTW per joint over its (frames, 3) slice.
num_units = feat_a.shape[1]
slicers: list[Callable[[np.ndarray], np.ndarray]] = [
(lambda arr, idx=i: arr[:, idx, :]) for i in range(num_units)
]
else: # "angles"
if angle_triplets is None:
raise ValueError("representation='angles' requires angle_triplets")
feat_a = extract_joint_angles(a, angle_triplets) # (frames, num_triplets)
feat_b = extract_joint_angles(b, angle_triplets)
num_units = feat_a.shape[1]
slicers = [
# Scalar columns become 2D for DTW (fastdtw expects a
# sequence of vectors, not a sequence of scalars).
(lambda arr, idx=i: arr[:, idx : idx + 1])
for i in range(num_units)
]
fastdtw, euclidean = _require_fastdtw()
results: list[DTWResult] = []
for joint_idx in range(a.shape[1]):
a_joint = a[:, joint_idx, :]
b_joint = b[:, joint_idx, :]
distance, path = fastdtw(a_joint, b_joint, dist=euclidean)
for slicer in slicers:
unit_a = _apply_nan_policy(slicer(feat_a), nan_policy)
unit_b = _apply_nan_policy(slicer(feat_b), nan_policy)
distance, path = fastdtw(unit_a, unit_b, dist=euclidean)
results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path]))
return results
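The ``idx=i`` default in the slicer lambdas above is doing real work: closures created in a loop all share the loop variable, so without the default every slicer would read the final index. A standalone illustration:

```python
import numpy as np

arr = np.arange(24).reshape(4, 2, 3)  # (frames, joints, 3)

# By the time these run, i has its final value (1), so every
# late-bound slicer returns the last joint.
late_bound = [lambda a: a[:, i, :] for i in range(2)]

# The `idx=i` default freezes the current value at definition time,
# giving one pinned joint index per slicer.
pinned = [lambda a, idx=i: a[:, idx, :] for i in range(2)]
```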
@@ -152,6 +297,9 @@ def dtw_relation(
b: np.ndarray,
joint_i: int,
joint_j: int,
*,
align: AlignMode = "none",
nan_policy: NanPolicy = "propagate",
) -> DTWResult:
"""DTW on the displacement vector between two specific joints.
@@ -170,6 +318,14 @@
Indices of the two joints whose relative position should be
compared. Must be valid indices into ``a`` and ``b``'s joint
axis.
align
Procrustes alignment mode applied to the full sequences
before the displacement vectors are extracted. See
:data:`AlignMode`. Note that displacement vectors are already
translation-invariant; alignment is still useful for cancelling
camera rotation between trials.
nan_policy
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
Returns
-------
@@ -179,8 +335,9 @@
Raises
------
ValueError
If the sequences have different joint counts or if either joint
index is out of range.
If the sequences have different joint counts, either joint
index is out of range, or ``align`` requires a matching frame
count that is not present.
"""
_validate_same_joint_count(a, b)
num_joints = a.shape[1]
@@ -188,9 +345,10 @@
raise ValueError(
f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}"
)
a, b = _maybe_align(a, b, align=align)
disp_a = _apply_nan_policy(a[:, joint_j, :] - a[:, joint_i, :], nan_policy)
disp_b = _apply_nan_policy(b[:, joint_j, :] - b[:, joint_i, :], nan_policy)
fastdtw, euclidean = _require_fastdtw()
disp_a = a[:, joint_j, :] - a[:, joint_i, :]
disp_b = b[:, joint_j, :] - b[:, joint_i, :]
distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
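The docstring's claim that displacement vectors are already translation-invariant is easy to confirm numerically (the joint indices below are arbitrary, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)
seq = rng.normal(size=(50, 17, 3))
shifted = seq + np.array([100.0, -20.0, 5.0])  # same shift every frame

joint_i, joint_j = 2, 9  # hypothetical joint indices
disp = seq[:, joint_j, :] - seq[:, joint_i, :]
disp_shifted = shifted[:, joint_j, :] - shifted[:, joint_i, :]
# disp == disp_shifted: the rigid translation cancels in the difference.
```

This is why ``align`` is only worth paying for here when camera rotation differs between trials.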
@@ -206,3 +364,95 @@ def _validate_same_joint_count(a: np.ndarray, b: np.ndarray) -> None:
f"input arrays disagree on joint count: "
f"a has {a.shape[1]} joints, b has {b.shape[1]} joints"
)
def _maybe_align(
a: np.ndarray,
b: np.ndarray,
*,
align: AlignMode,
) -> tuple[np.ndarray, np.ndarray]:
"""Apply Procrustes alignment if ``align`` requests it.
Procrustes requires a frame-by-frame correspondence, so this
helper rejects calls where the two sequences disagree on frame
count and ``align`` is not ``"none"``. Pad upstream with
:func:`~neuropose.analyzer.features.pad_sequences` if the lengths
differ.
"""
if align == "none":
return a, b
if a.shape[0] != b.shape[0]:
raise ValueError(
f"align={align!r} requires matching frame counts; "
f"got a with {a.shape[0]} frames and b with {b.shape[0]} frames"
)
mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence"
aligned_a, _target, _diag = procrustes_align(a, b, mode=mode)
return aligned_a, b
def _apply_representation(
sequence: np.ndarray,
representation: Representation,
*,
angle_triplets: Sequence[tuple[int, int, int]] | None,
) -> np.ndarray:
"""Reduce a ``(frames, joints, 3)`` sequence to DTW-ready 2D features.
``"coords"`` reshapes to ``(frames, joints * 3)``; ``"angles"``
runs :func:`extract_joint_angles` to produce
``(frames, len(angle_triplets))``.
"""
if representation == "coords":
return sequence.reshape(sequence.shape[0], -1)
if representation == "angles":
if angle_triplets is None:
raise ValueError("representation='angles' requires angle_triplets")
return extract_joint_angles(sequence, angle_triplets)
raise ValueError(f"unknown representation {representation!r}")
def _apply_nan_policy(features: np.ndarray, policy: NanPolicy) -> np.ndarray:
"""Handle NaN values in a ``(frames, features)`` array per ``policy``.
``"propagate"`` is a no-op. ``"interpolate"`` runs 1D linear
interpolation along the frame axis within each feature column,
leaving finite data untouched. ``"drop"`` removes any frame where
*any* feature is NaN.
Raises
------
ValueError
If ``"interpolate"`` encounters a column that is entirely NaN
(no finite anchors to interpolate between), or if ``"drop"``
leaves an empty sequence.
"""
if policy == "propagate":
return features
if features.ndim == 1:
features = features.reshape(-1, 1)
if policy == "drop":
keep = np.isfinite(features).all(axis=1)
dropped = features[keep]
if dropped.shape[0] == 0:
raise ValueError(
"nan_policy='drop' removed every frame; DTW needs a non-empty sequence"
)
return dropped
if policy == "interpolate":
out = features.astype(float, copy=True)
num_frames = out.shape[0]
indices = np.arange(num_frames, dtype=float)
for col in range(out.shape[1]):
column = out[:, col]
finite = np.isfinite(column)
if finite.all():
continue
if not finite.any():
raise ValueError(
f"nan_policy='interpolate' cannot fill column {col}: all values are NaN"
)
out[:, col] = np.interp(indices, indices[finite], column[finite])
return out
raise ValueError(f"unknown nan_policy {policy!r}")


@@ -14,6 +14,8 @@ The following helpers are provided:
fit in the unit cube (either per-axis or uniform).
- :func:`pad_sequences`: edge-pad a batch of sequences to a common
length, suitable for downstream tensor-based analysis.
- :func:`procrustes_align`: rigid-align one pose sequence to another
via the Kabsch algorithm, with optional uniform scaling.
- :func:`extract_joint_angles`: compute joint angles at specified
triplet positions across a pose sequence.
- :func:`extract_feature_statistics`: summary statistics
@@ -26,12 +28,19 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from typing import Any, Literal
import numpy as np
from neuropose.io import VideoPredictions
ProcrustesMode = Literal["per_frame", "per_sequence"]
"""Mode selector for :func:`procrustes_align`.
``per_sequence`` computes a single rigid transform over the whole
sequence; ``per_frame`` aligns every frame independently.
"""
# ---------------------------------------------------------------------------
# VideoPredictions → numpy
# ---------------------------------------------------------------------------
@@ -211,6 +220,247 @@ def pad_sequences(
return padded
# ---------------------------------------------------------------------------
# Procrustes alignment (Kabsch)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class AlignmentDiagnostics:
"""Summary of the rigid transform fitted by :func:`procrustes_align`.
Attributes
----------
mode
Which alignment mode produced this result; mirrors the ``mode``
argument passed to :func:`procrustes_align`.
rotation_deg
Magnitude of the fitted rotation, in degrees, computed as
``arccos((trace(R) - 1) / 2)``. For ``per_frame`` mode this is
the mean magnitude across frames.
rotation_deg_max
Worst-case (maximum) rotation magnitude across frames.
Equal to :attr:`rotation_deg` in ``per_sequence`` mode.
translation
Magnitude of the fitted translation vector, in the same units
as the input (millimetres for MeTRAbs output). For ``per_frame``
mode this is the mean magnitude across frames.
translation_max
Worst-case (maximum) translation magnitude across frames.
Equal to :attr:`translation` in ``per_sequence`` mode.
scale
Applied uniform scale factor. Always ``1.0`` when
``procrustes_align`` was called with ``scale=False``. In
``per_frame`` mode this is the mean scale across frames.
"""
mode: ProcrustesMode
rotation_deg: float
rotation_deg_max: float
translation: float
translation_max: float
scale: float
def _kabsch_single(
source: np.ndarray,
target: np.ndarray,
*,
scale: bool,
) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]:
"""Fit the optimal rigid (+ optional uniform scale) transform.
Aligns ``source`` to ``target`` via the closed-form Kabsch
algorithm and returns ``(aligned_source, R, s, t)`` where
``aligned_source = s * (source - centroid_source) @ R.T + centroid_target + t_fine``
(with ``t_fine`` absorbed for convenience: aligned points match
the target's centroid to within floating-point error).
Parameters
----------
source
``(N, 3)`` point set to align.
target
``(N, 3)`` reference point set. Must have the same shape as
``source``.
scale
If ``True``, fit a uniform scale factor; otherwise lock to
``1.0``.
Returns
-------
aligned_source
``(N, 3)`` aligned copy of ``source``.
R
``(3, 3)`` rotation matrix.
s
Scalar scale factor (``1.0`` when ``scale=False``).
t
``(3,)`` translation vector in world coordinates such that
``aligned_source[i] = s * R @ source[i] + t``.
"""
centroid_source = source.mean(axis=0)
centroid_target = target.mean(axis=0)
source_centered = source - centroid_source
target_centered = target - centroid_target
covariance = source_centered.T @ target_centered
u_mat, sigma, vt_mat = np.linalg.svd(covariance)
reflection_sign = float(np.sign(np.linalg.det(vt_mat.T @ u_mat.T)))
# Guard against the degenerate det == 0 case (coplanar points).
if reflection_sign == 0.0:
reflection_sign = 1.0
diag = np.diag([1.0, 1.0, reflection_sign])
rotation = vt_mat.T @ diag @ u_mat.T
if scale:
source_var = float((source_centered**2).sum())
if source_var <= 0.0:
scale_factor = 1.0
else:
scale_factor = float((sigma * np.array([1.0, 1.0, reflection_sign])).sum() / source_var)
else:
scale_factor = 1.0
translation = centroid_target - scale_factor * rotation @ centroid_source
aligned = scale_factor * source @ rotation.T + translation
return aligned, rotation, scale_factor, translation
def _rotation_magnitude_deg(rotation: np.ndarray) -> float:
"""Return the rotation angle (degrees) represented by ``rotation``.
Uses the axis-angle relation ``cos(theta) = (trace(R) - 1) / 2``.
"""
cos_theta = (float(np.trace(rotation)) - 1.0) / 2.0
cos_theta = max(-1.0, min(1.0, cos_theta))
return float(np.degrees(np.arccos(cos_theta)))
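The axis-angle relation used here can be checked against a rotation of known magnitude. A self-contained copy of the helper (mirroring ``_rotation_magnitude_deg`` above) applied to a 90-degree rotation about z:

```python
import numpy as np

def rotation_magnitude_deg(rotation: np.ndarray) -> float:
    """Axis-angle magnitude via cos(theta) = (trace(R) - 1) / 2."""
    cos_theta = (float(np.trace(rotation)) - 1.0) / 2.0
    cos_theta = max(-1.0, min(1.0, cos_theta))  # clamp float noise
    return float(np.degrees(np.arccos(cos_theta)))

theta = np.radians(90.0)
rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])  # 90-degree rotation about z
```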
def procrustes_align(
source: np.ndarray,
target: np.ndarray,
*,
mode: ProcrustesMode = "per_sequence",
scale: bool = False,
) -> tuple[np.ndarray, np.ndarray, AlignmentDiagnostics]:
"""Rigid-align ``source`` to ``target`` via the Kabsch algorithm.
Fits the optimal rigid transform (optionally including uniform
scaling) that minimizes the sum of squared distances between
corresponding joints. The transform is always applied to
``source``; ``target`` is returned unchanged alongside it for
symmetry with downstream DTW callers, which typically consume both
aligned arrays as a pair.
Parameters
----------
source
Pose sequence to align, shape ``(frames, joints, 3)``.
target
Reference pose sequence, shape ``(frames, joints, 3)``. For
``per_frame`` mode the frame counts must match; for
``per_sequence`` mode they must also match (the correspondence
runs frame-by-frame and joint-by-joint). Use
:func:`pad_sequences` first if your sequences have different
lengths.
mode
``"per_sequence"`` (default) fits a single rigid transform over
the whole sequence good when the recording geometry is
stable across frames. ``"per_frame"`` fits an independent
transform per frame good for matching pose shape while
discarding global trajectory.
scale
If ``True``, also fit a uniform scale factor. Useful for
cross-subject comparisons where the reference skeleton has a
different overall size.
Returns
-------
aligned_source
``source`` transformed to align with ``target``, same shape as
the input.
target
The ``target`` array, unchanged.
diagnostics
:class:`AlignmentDiagnostics` summarising the fitted transform.
Raises
------
ValueError
If ``source`` and ``target`` have different shapes or the
trailing axis is not of size 3.
Notes
-----
The Kabsch algorithm (Kabsch 1976, "A solution for the best
rotation to relate two sets of vectors") is a closed-form SVD
solution and does not iterate. Reflection is explicitly prevented
via a sign correction on the smallest singular value; the fitted
matrix is always a proper rotation (det = +1).
In ``per_frame`` mode, rotation, translation, and scale
diagnostics are reported as means across frames, with
:attr:`AlignmentDiagnostics.rotation_deg_max` and
:attr:`AlignmentDiagnostics.translation_max` exposing the worst
frame for anomaly detection.
"""
if source.ndim != 3 or source.shape[-1] != 3:
raise ValueError(f"expected (frames, joints, 3); got source shape {source.shape}")
if source.shape != target.shape:
raise ValueError(
f"source and target must have the same shape; got {source.shape} and {target.shape}"
)
source = source.astype(float, copy=False)
target = target.astype(float, copy=False)
num_frames = source.shape[0]
if mode == "per_sequence":
flat_source = source.reshape(-1, 3)
flat_target = target.reshape(-1, 3)
aligned_flat, rotation, scale_factor, translation = _kabsch_single(
flat_source, flat_target, scale=scale
)
aligned = aligned_flat.reshape(source.shape)
rotation_deg = _rotation_magnitude_deg(rotation)
translation_mag = float(np.linalg.norm(translation))
diagnostics = AlignmentDiagnostics(
mode="per_sequence",
rotation_deg=rotation_deg,
rotation_deg_max=rotation_deg,
translation=translation_mag,
translation_max=translation_mag,
scale=scale_factor,
)
return aligned, target, diagnostics
if mode == "per_frame":
aligned = np.empty_like(source)
rotation_degs = np.empty(num_frames, dtype=float)
translations = np.empty(num_frames, dtype=float)
scales = np.empty(num_frames, dtype=float)
for frame_idx in range(num_frames):
aligned_frame, rotation, scale_factor, translation = _kabsch_single(
source[frame_idx], target[frame_idx], scale=scale
)
aligned[frame_idx] = aligned_frame
rotation_degs[frame_idx] = _rotation_magnitude_deg(rotation)
translations[frame_idx] = float(np.linalg.norm(translation))
scales[frame_idx] = scale_factor
diagnostics = AlignmentDiagnostics(
mode="per_frame",
rotation_deg=float(rotation_degs.mean()) if num_frames else 0.0,
rotation_deg_max=float(rotation_degs.max()) if num_frames else 0.0,
translation=float(translations.mean()) if num_frames else 0.0,
translation_max=float(translations.max()) if num_frames else 0.0,
scale=float(scales.mean()) if num_frames else 1.0,
)
return aligned, target, diagnostics
raise ValueError(f"unknown mode {mode!r}; expected 'per_frame' or 'per_sequence'")
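The closed-form fit is compact enough to demonstrate end to end. A minimal Kabsch sketch (no scaling, a single point set; the module's ``_kabsch_single`` adds the scale option and richer return values) recovering a known rotation-plus-translation:

```python
import numpy as np

def kabsch(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Minimal rigid Kabsch fit (rotation + translation, no scaling)."""
    cs, ct = source.mean(axis=0), target.mean(axis=0)
    h = (source - cs).T @ (target - ct)       # 3x3 cross-covariance
    u, _, vt = np.linalg.svd(h)
    d = float(np.sign(np.linalg.det(vt.T @ u.T)))
    if d == 0.0:                              # degenerate coplanar case
        d = 1.0
    r = vt.T @ np.diag([1.0, 1.0, d]) @ u.T   # proper rotation, det +1
    return (source - cs) @ r.T + ct

rng = np.random.default_rng(2)
target = rng.normal(size=(17, 3))
theta = np.radians(30.0)
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta),  np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
source = target @ rot.T + np.array([5.0, -3.0, 1.0])  # known transform
aligned = kabsch(source, target)  # should land back on target
```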
# ---------------------------------------------------------------------------
# Joint angles
# ---------------------------------------------------------------------------


@@ -0,0 +1,932 @@
"""YAML-configurable analysis pipeline.
This module unifies the analyzer's individual primitives (Procrustes
alignment, gait-cycle segmentation, DTW on coords or angles, feature
statistics) behind a single declarative configuration object so an
experiment can be reproduced from one file that lives in a git
repository, carries through to the :class:`~neuropose.io.Provenance`
envelope on the output artifact, and can be cited unambiguously in
accompanying papers.
Two top-level schemas live here:
- :class:`AnalysisConfig`: what the user writes in YAML. Describes
the full pipeline: inputs, preprocessing, optional segmentation,
required analysis stage, and output path.
- :class:`AnalysisReport`: what :func:`run_analysis` emits. Carries
the config, a :class:`~neuropose.io.Provenance` envelope (with the
config serialised into :attr:`~neuropose.io.Provenance.analysis_config`),
per-input summaries, segmentation results, and the analysis results
themselves.
Both schemas parse from (and serialise to) JSON via pydantic; the
config additionally parses from YAML via :func:`load_config`. Cross-field
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
and ``joint_j``) are enforced at parse time so typo-laden configs fail
fast rather than after an expensive multi-minute load.
Execution
---------
:func:`run_analysis` is the top-level executor: it loads the
predictions files named in the config, applies any configured
segmentation stage, dispatches to the configured analysis stage, and
returns a fully populated :class:`AnalysisReport`. The executor is
intended to be called from the ``neuropose analyze`` CLI but is
equally valid as a Python-level entry point for notebook-driven
exploration.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Annotated, Any, Literal
import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from neuropose.analyzer.dtw import (
AlignMode,
DTWResult,
NanPolicy,
Representation,
dtw_all,
dtw_per_joint,
dtw_relation,
)
from neuropose.analyzer.features import (
extract_feature_statistics,
predictions_to_numpy,
)
from neuropose.analyzer.segment import (
AxisLetter,
extract_signal,
segment_gait_cycles,
segment_gait_cycles_bilateral,
segment_predictions,
)
from neuropose.io import (
ExtractorSpec,
Provenance,
Segmentation,
VideoPredictions,
load_video_predictions,
)
from neuropose.migrations import CURRENT_VERSION, migrate_analysis_report
# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------
class InputsConfig(BaseModel):
"""Predictions files consumed by the pipeline.
Attributes
----------
primary
Path to a :class:`~neuropose.io.VideoPredictions` JSON file.
Always required.
reference
Optional second predictions file. When provided,
:class:`DtwAnalysis` runs comparative DTW between primary and
reference; when absent, analysis stages that require a
reference (i.e. DTW) raise a validation error at parse time.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
primary: Path
reference: Path | None = None
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
class PreprocessingConfig(BaseModel):
"""Per-input preprocessing applied before segmentation and analysis.
Minimal today just picks which detected person to extract from
each frame. Left as a named stage so future extensions (coordinate
normalisation, smoothing) can land here without reshuffling the
config shape.
Attributes
----------
person_index
Which detected person to extract per frame. Defaults to ``0``
(the first detected person), matching the single-subject
clinical case.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
person_index: int = Field(default=0, ge=0)
# ---------------------------------------------------------------------------
# Segmentation stage
# ---------------------------------------------------------------------------
class GaitCyclesSegmentation(BaseModel):
"""Single-heel gait-cycle segmentation via peak detection.
Produces one :class:`~neuropose.io.Segmentation` keyed under the
joint name (e.g. ``"rhee_cycles"``). See
:func:`~neuropose.analyzer.segment.segment_gait_cycles` for the
underlying implementation.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["gait_cycles"]
joint: str = "rhee"
axis: AxisLetter = "y"
invert: bool = False
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
min_prominence: float | None = None
class GaitCyclesBilateralSegmentation(BaseModel):
"""Bilateral (both heels) gait-cycle segmentation.
Produces two :class:`~neuropose.io.Segmentation` objects keyed as
``"left_heel_strikes"`` and ``"right_heel_strikes"``.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["gait_cycles_bilateral"]
axis: AxisLetter = "y"
invert: bool = False
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
min_prominence: float | None = None
class ExtractorSegmentation(BaseModel):
"""Generic extractor-driven segmentation.
Wraps :func:`~neuropose.analyzer.segment.segment_predictions` with
a caller-supplied :class:`~neuropose.io.ExtractorSpec`. Use this
when the signal of interest is not the vertical heel trace, e.g.
wrist-hip distance for a reach-and-grasp task, or elbow flexion
angle for a range-of-motion trial.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["extractor"]
extractor: ExtractorSpec
label: str = Field(
default="segmentation",
description="Key under which the resulting Segmentation is stored.",
)
person_index: int | None = Field(
default=None,
description=(
"Overrides preprocessing.person_index for this stage. "
"None defers to the global preprocessing setting."
),
)
min_distance_seconds: float | None = Field(default=None, ge=0.0)
min_prominence: float | None = None
min_height: float | None = None
pad_seconds: float = Field(default=0.0, ge=0.0)
SegmentationStage = Annotated[
GaitCyclesSegmentation | GaitCyclesBilateralSegmentation | ExtractorSegmentation,
Field(discriminator="kind"),
]
"""Discriminated-union alias for the three segmentation variants.
Pydantic dispatches on the ``kind`` field. A config without a
``segmentation`` key at all skips this stage entirely
(see :class:`AnalysisConfig.segmentation`).
"""
# ---------------------------------------------------------------------------
# Analysis stage
# ---------------------------------------------------------------------------
class DtwAnalysis(BaseModel):
"""Dynamic Time Warping between the primary and reference inputs.
Dispatches to one of :func:`~neuropose.analyzer.dtw.dtw_all`,
:func:`~neuropose.analyzer.dtw.dtw_per_joint`, or
:func:`~neuropose.analyzer.dtw.dtw_relation` per the ``method``
field. Cross-field invariants (``method="dtw_relation"`` requires
``joint_i`` and ``joint_j``; ``representation="angles"`` requires
a non-empty ``angle_triplets``) are enforced at parse time.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["dtw"]
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] = "dtw_all"
align: AlignMode = "none"
representation: Representation = "coords"
angle_triplets: list[tuple[int, int, int]] | None = None
joint_i: int | None = None
joint_j: int | None = None
nan_policy: NanPolicy = "propagate"
@model_validator(mode="after")
def _check_method_fields(self) -> DtwAnalysis:
if self.method == "dtw_relation":
if self.joint_i is None or self.joint_j is None:
raise ValueError("method='dtw_relation' requires joint_i and joint_j")
if self.representation != "coords":
raise ValueError(
"method='dtw_relation' only supports representation='coords' "
"(a two-joint displacement is not a joint-angle signal)"
)
if self.representation == "angles":
if not self.angle_triplets:
raise ValueError("representation='angles' requires a non-empty angle_triplets list")
if self.method == "dtw_relation":
# Guarded by the earlier branch, but make the invariant explicit.
raise ValueError(
"representation='angles' is incompatible with method='dtw_relation'"
)
return self
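Each entry of ``angle_triplets`` names three joints whose middle joint carries the angle. The exact convention lives in :mod:`neuropose.analyzer.dtw`; a plain-geometry sketch of the usual triplet angle (angle at the middle joint, in radians) is:

```python
import math

def triplet_angle(p_a, p_b, p_c):
    # Angle at the middle joint p_b, formed by the two limb
    # segments p_b -> p_a and p_b -> p_c.
    v1 = tuple(a - b for a, b in zip(p_a, p_b))
    v2 = tuple(c - b for c, b in zip(p_c, p_b))
    dot = sum(x * y for x, y in zip(v1, v2))
    n1 = math.sqrt(sum(x * x for x in v1))
    n2 = math.sqrt(sum(x * x for x in v2))
    # Clamp to guard against floating-point drift outside [-1, 1].
    return math.acos(max(-1.0, min(1.0, dot / (n1 * n2))))
```

A fully extended limb yields pi, a right-angle bend yields pi/2; whether the analyzer reports the interior or exterior angle is not shown here.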
class StatsAnalysis(BaseModel):
"""Summary statistics over a scalar signal extracted from the primary input.
Runs :func:`~neuropose.analyzer.segment.extract_signal` with the
caller-supplied :class:`~neuropose.io.ExtractorSpec`, then
computes :func:`~neuropose.analyzer.features.extract_feature_statistics`
on each segment (or on the full trial if no segmentation stage
runs).
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["stats"]
extractor: ExtractorSpec
class NoAnalysis(BaseModel):
"""Terminal stage placeholder; produces no per-segment results.
Useful when the pipeline's goal is just to segment the input and
persist the :class:`~neuropose.io.Segmentation` plus an
:class:`AnalysisReport` with provenance; the ``none`` analysis
kind makes that explicit rather than requiring the absence of the
stage.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["none"]
AnalysisStage = Annotated[
DtwAnalysis | StatsAnalysis | NoAnalysis,
Field(discriminator="kind"),
]
"""Discriminated-union alias for the three analysis variants.
Pydantic dispatches on ``kind``. One of the three must always be
present in a valid config.
"""
# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------
class OutputConfig(BaseModel):
"""Where :func:`run_analysis` should write its :class:`AnalysisReport`.
Kept as a sub-object rather than a bare path so downstream
extensions (figure paths, supplementary distance-matrix files)
can land here without changing the config's top-level shape.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
report: Path
# ---------------------------------------------------------------------------
# Top-level config
# ---------------------------------------------------------------------------
class AnalysisConfig(BaseModel):
"""Declarative description of a full analysis run.
Parsed from YAML (via :func:`load_config`) or JSON (via
:meth:`pydantic.BaseModel.model_validate`). Every field is
cross-validated at parse time so a typo in a nested sub-field
fails in milliseconds rather than after a multi-minute
predictions load.
Attributes
----------
config_version
Schema version for the config itself. Only ``1`` is valid at
this release. Future config-format breaks bump this and a
sibling migration registry handles legacy YAML in place.
inputs
Predictions-file paths.
preprocessing
Per-input preprocessing (person-index selection today).
segmentation
Optional segmentation stage. ``None`` skips segmentation
entirely and analysis runs over each full trial as a single
"segment".
analysis
Required analysis stage. Exactly one of
:class:`DtwAnalysis` / :class:`StatsAnalysis` / :class:`NoAnalysis`.
output
Output paths.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
config_version: Literal[1] = 1
inputs: InputsConfig
preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
segmentation: SegmentationStage | None = None
analysis: AnalysisStage
output: OutputConfig
@model_validator(mode="after")
def _check_cross_stage_invariants(self) -> AnalysisConfig:
# DTW is comparative — it needs a reference input.
if isinstance(self.analysis, DtwAnalysis) and self.inputs.reference is None:
raise ValueError("analysis.kind='dtw' requires inputs.reference to be set")
# Stats is non-comparative — a reference without a use is
# almost certainly an operator error.
if isinstance(self.analysis, StatsAnalysis) and self.inputs.reference is not None:
raise ValueError(
"analysis.kind='stats' operates on inputs.primary only; "
"remove inputs.reference or switch analysis.kind to 'dtw'"
)
return self
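Under these invariants, a hypothetical minimal DTW config (all paths illustrative; the shipped `examples/analysis/minimal.yaml` is the authoritative reference) must name both inputs:

```yaml
config_version: 1
inputs:
  primary: predictions/patient_walk.json     # hypothetical path
  reference: predictions/control_walk.json   # required because analysis.kind is dtw
analysis:
  kind: dtw
  method: dtw_all
  representation: coords
output:
  report: reports/patient_vs_control.json
```

Omitting ``preprocessing`` and ``segmentation`` takes their defaults (person index 0, no segmentation), so DTW runs over the full trial as a single segment.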
# ---------------------------------------------------------------------------
# Report pieces
# ---------------------------------------------------------------------------
class InputSummary(BaseModel):
"""Capsule of an input predictions file's headline metadata.
Stored in the :class:`AnalysisReport` so a reader of the report
can tell at a glance what was analysed without having to load the
underlying predictions JSONs.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
path: Path
frame_count: int = Field(ge=0)
fps: float = Field(ge=0.0)
provenance: Provenance | None = None
class FeatureSummary(BaseModel):
"""Pydantic twin of :class:`~neuropose.analyzer.features.FeatureStatistics`.
The dataclass is used throughout the analyzer for ad-hoc Python
consumption; the report path needs a pydantic model for
round-tripping through JSON.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
mean: float
std: float
min: float
max: float
range: float
class DtwResults(BaseModel):
"""DTW results attached to an :class:`AnalysisReport`.
``distances`` is parallel to ``segment_labels``. For an
unsegmented run the lists have length 1 and the label is
``"full_trial"``. For a segmented run each label takes the form
``"<segmentation_key>[<index>]"`` (e.g. ``"left_heel_strikes[3]"``).
``per_joint_distances`` carries a per-unit breakdown for
``method="dtw_per_joint"`` only; its outer length matches
``distances``, inner length matches either ``num_joints`` (coords)
or ``len(angle_triplets)`` (angles).
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["dtw"]
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"]
distances: list[float]
paths: list[list[tuple[int, int]]]
per_joint_distances: list[list[float]] | None = None
segment_labels: list[str]
summary: dict[str, float]
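Each entry of ``distances`` is a standard DTW alignment cost. A minimal pure-Python sketch of the dynamic program (absolute-difference local cost on 1-D signals; the analyzer's actual cost function may differ) shows what one distance represents:

```python
def dtw_distance(a, b):
    # O(len(a) * len(b)) dynamic program. cost[i][j] is the best
    # alignment cost of a[:i] against b[:j].
    inf = float("inf")
    n, m = len(a), len(b)
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            cost[i][j] = d + min(
                cost[i - 1][j],      # primary advances
                cost[i][j - 1],      # reference advances
                cost[i - 1][j - 1],  # both advance
            )
    return cost[n][m]
```

Identical sequences score 0; time-warped copies of the same shape also score 0, which is exactly why DTW suits gait comparisons at differing cadences.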
class StatsResults(BaseModel):
"""Feature-statistics results attached to an :class:`AnalysisReport`.
``statistics`` is parallel to ``segment_labels``; see
:class:`DtwResults` for the labelling convention.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["stats"]
statistics: list[FeatureSummary]
segment_labels: list[str]
class NoResults(BaseModel):
"""Empty results payload for ``analysis.kind='none'`` runs."""
model_config = ConfigDict(extra="forbid", frozen=True)
kind: Literal["none"]
AnalysisResults = Annotated[
DtwResults | StatsResults | NoResults,
Field(discriminator="kind"),
]
"""Discriminated-union alias for the three analysis-result shapes.
Mirrors :data:`AnalysisStage` one-for-one: ``DtwAnalysis`` produces
:class:`DtwResults`, etc.
"""
# ---------------------------------------------------------------------------
# Top-level report
# ---------------------------------------------------------------------------
class AnalysisReport(BaseModel):
"""Self-describing output artifact of :func:`run_analysis`.
Serialised to JSON on disk. Carries the originating config, the
:class:`~neuropose.io.Provenance` envelope (with the config
serialised into :attr:`~neuropose.io.Provenance.analysis_config`
so the report is self-describing even if the YAML is lost), each
input's headline metadata plus its own provenance if available,
any segmentations produced, and the analysis results themselves.
Lives in the schema-migration registry under ``"AnalysisReport"``
at ``CURRENT_VERSION``; see :mod:`neuropose.migrations`.
"""
model_config = ConfigDict(extra="forbid", frozen=True)
schema_version: int = Field(default=CURRENT_VERSION, ge=1)
config: AnalysisConfig
provenance: Provenance | None = None
primary: InputSummary
reference: InputSummary | None = None
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
results: AnalysisResults
def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
"""Serialise an :class:`AnalysisConfig` to a JSON-safe dict.
Returned shape is identical to what pydantic would produce via
:meth:`~pydantic.BaseModel.model_dump` in ``mode="json"``: paths
become strings, tuples become lists, enums become their values.
Useful for stamping
:attr:`~neuropose.io.Provenance.analysis_config` on the
:class:`AnalysisReport`'s provenance envelope.
"""
return config.model_dump(mode="json")
# ---------------------------------------------------------------------------
# Load / save
# ---------------------------------------------------------------------------
def load_config(path: Path) -> AnalysisConfig:
"""Load and validate an :class:`AnalysisConfig` from a YAML file.
Parameters
----------
path
Filesystem path to a YAML file conforming to the
:class:`AnalysisConfig` schema.
Returns
-------
AnalysisConfig
The fully validated config. Cross-field invariants have
already been checked.
Raises
------
pydantic.ValidationError
On any schema violation: unknown keys, wrong types, or
failed cross-field invariants.
yaml.YAMLError
On malformed YAML.
"""
with path.open("r", encoding="utf-8") as f:
raw = yaml.safe_load(f)
if raw is None:
raw = {}
return AnalysisConfig.model_validate(raw)
def save_report(path: Path, report: AnalysisReport) -> None:
"""Serialise an :class:`AnalysisReport` to ``path`` atomically.
Writes to a sibling ``<path>.tmp`` first, then renames over
``path`` so a crash mid-write cannot leave behind a truncated
file. The parent directory is created if it does not exist.
"""
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
payload = report.model_dump(mode="json")
with tmp.open("w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
tmp.replace(path)
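The tmp-then-replace sequence can be exercised standalone; this sketch mirrors the same pattern (filenames hypothetical):

```python
import json
import tempfile
from pathlib import Path

def save_json_atomic(path: Path, payload: dict) -> None:
    # Same pattern as save_report: write a sibling .tmp first, then
    # rename over the destination so a reader never observes a
    # partially written file.
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with tmp.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    tmp.replace(path)  # atomic when tmp and path share a filesystem

with tempfile.TemporaryDirectory() as d:
    target = Path(d) / "reports" / "run.json"
    save_json_atomic(target, {"schema_version": 1})
    assert json.loads(target.read_text()) == {"schema_version": 1}
```

The rename is atomic only when the temp file and destination live on the same filesystem, which the sibling-path construction guarantees.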
def load_report(path: Path) -> AnalysisReport:
"""Load and validate an :class:`AnalysisReport` JSON file.
Runs the payload through :func:`~neuropose.migrations.migrate_analysis_report`
before pydantic validation so future schema bumps can upgrade
legacy reports transparently.
"""
with path.open("r", encoding="utf-8") as f:
data: Any = json.load(f)
if isinstance(data, dict):
data = migrate_analysis_report(data)
return AnalysisReport.model_validate(data)
# ---------------------------------------------------------------------------
# Executor
# ---------------------------------------------------------------------------
def run_analysis(config: AnalysisConfig) -> AnalysisReport:
"""Execute the pipeline described by ``config`` end-to-end.
Loads the predictions files named in ``config.inputs``, applies
the configured preprocessing + segmentation + analysis stages,
and returns an :class:`AnalysisReport` whose
:attr:`~AnalysisReport.provenance` inherits the inference-time
provenance of the primary input with
:attr:`~neuropose.io.Provenance.analysis_config` populated so the
report is self-describing even if the YAML config is later lost.
Parameters
----------
config
The pre-validated pipeline configuration.
Returns
-------
AnalysisReport
Fully populated report. Not yet written to disk; the caller
passes it to :func:`save_report` (or inspects it directly).
Notes
-----
For DTW runs with a segmentation stage, segments are paired
one-to-one by index across primary and reference, truncating to
``min(len_primary, len_reference)``. Bilateral segmentations
produce distances for each side independently, labelled under
their segmentation key (e.g. ``"left_heel_strikes[3]"``).
"""
primary_preds = load_video_predictions(config.inputs.primary)
reference_preds: VideoPredictions | None = None
if config.inputs.reference is not None:
reference_preds = load_video_predictions(config.inputs.reference)
person_index = config.preprocessing.person_index
primary_seq = predictions_to_numpy(primary_preds, person_index=person_index)
reference_seq: np.ndarray | None = None
if reference_preds is not None:
reference_seq = predictions_to_numpy(reference_preds, person_index=person_index)
primary_segmentations: dict[str, Segmentation] = {}
reference_segmentations: dict[str, Segmentation] = {}
if config.segmentation is not None:
primary_segmentations = _run_segmentation(primary_preds, config.segmentation, person_index)
if reference_preds is not None:
reference_segmentations = _run_segmentation(
reference_preds, config.segmentation, person_index
)
results = _run_analysis_stage(
config.analysis,
primary_seq=primary_seq,
reference_seq=reference_seq,
primary_segmentations=primary_segmentations,
reference_segmentations=reference_segmentations,
)
analysis_config_dump = analysis_config_to_dict(config)
report_provenance: Provenance | None = None
if primary_preds.provenance is not None:
report_provenance = primary_preds.provenance.model_copy(
update={"analysis_config": analysis_config_dump}
)
primary_summary = InputSummary(
path=config.inputs.primary,
frame_count=primary_preds.metadata.frame_count,
fps=primary_preds.metadata.fps,
provenance=primary_preds.provenance,
)
reference_summary: InputSummary | None = None
if reference_preds is not None and config.inputs.reference is not None:
reference_summary = InputSummary(
path=config.inputs.reference,
frame_count=reference_preds.metadata.frame_count,
fps=reference_preds.metadata.fps,
provenance=reference_preds.provenance,
)
return AnalysisReport(
config=config,
provenance=report_provenance,
primary=primary_summary,
reference=reference_summary,
segmentations=primary_segmentations,
results=results,
)
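The index-pairing rule from the Notes above can be illustrated in isolation (the segment tuples below are hypothetical ``(start, end)`` frame windows):

```python
def pair_segments(key, primary_segments, reference_segments):
    # Pair by index, truncating to the shorter list, and label each
    # pair "<key>[<index>]" as DtwResults.segment_labels does.
    count = min(len(primary_segments), len(reference_segments))
    return [
        (f"{key}[{i}]", primary_segments[i], reference_segments[i])
        for i in range(count)
    ]
```

A patient trial with three detected cycles against a reference with two therefore yields two pairs; the unmatched third cycle is silently dropped rather than compared against nothing.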
# ---------------------------------------------------------------------------
# Internal dispatch helpers
# ---------------------------------------------------------------------------
def _run_segmentation(
predictions: VideoPredictions,
stage: SegmentationStage, # type: ignore[valid-type]
person_index: int,
) -> dict[str, Segmentation]:
"""Apply a segmentation stage to a :class:`VideoPredictions`.
Returns a dict keyed by a stage-appropriate label: single-side
gait cycles use ``"<joint>_cycles"``, bilateral gait cycles use
``"left_heel_strikes"`` / ``"right_heel_strikes"``, and extractor
segmentation uses the caller-supplied
:attr:`~ExtractorSegmentation.label`.
"""
if isinstance(stage, GaitCyclesSegmentation):
seg = segment_gait_cycles(
predictions,
joint=stage.joint,
axis=stage.axis,
invert=stage.invert,
min_cycle_seconds=stage.min_cycle_seconds,
min_prominence=stage.min_prominence,
)
return {f"{stage.joint}_cycles": seg}
if isinstance(stage, GaitCyclesBilateralSegmentation):
return segment_gait_cycles_bilateral(
predictions,
axis=stage.axis,
invert=stage.invert,
min_cycle_seconds=stage.min_cycle_seconds,
min_prominence=stage.min_prominence,
)
if isinstance(stage, ExtractorSegmentation):
effective_person_index = (
stage.person_index if stage.person_index is not None else person_index
)
seg = segment_predictions(
predictions,
stage.extractor,
person_index=effective_person_index,
min_distance_seconds=stage.min_distance_seconds,
min_prominence=stage.min_prominence,
min_height=stage.min_height,
pad_seconds=stage.pad_seconds,
)
return {stage.label: seg}
raise TypeError(f"unknown segmentation stage: {type(stage).__name__}")
def _run_analysis_stage(
stage: AnalysisStage, # type: ignore[valid-type]
*,
primary_seq: np.ndarray,
reference_seq: np.ndarray | None,
primary_segmentations: dict[str, Segmentation],
reference_segmentations: dict[str, Segmentation],
) -> AnalysisResults: # type: ignore[valid-type]
"""Dispatch to the appropriate analysis executor per ``stage.kind``."""
if isinstance(stage, DtwAnalysis):
if reference_seq is None:
# AnalysisConfig's cross-stage validator should prevent
# this; duplicate the check here so a direct programmatic
# call can't slip through.
raise ValueError("DtwAnalysis requires a reference sequence")
return _run_dtw(
stage,
primary_seq=primary_seq,
reference_seq=reference_seq,
primary_segmentations=primary_segmentations,
reference_segmentations=reference_segmentations,
)
if isinstance(stage, StatsAnalysis):
return _run_stats(
stage,
primary_seq=primary_seq,
primary_segmentations=primary_segmentations,
)
if isinstance(stage, NoAnalysis):
return NoResults(kind="none")
raise TypeError(f"unknown analysis stage: {type(stage).__name__}")
def _run_dtw(
stage: DtwAnalysis,
*,
primary_seq: np.ndarray,
reference_seq: np.ndarray,
primary_segmentations: dict[str, Segmentation],
reference_segmentations: dict[str, Segmentation],
) -> DtwResults:
"""Execute a DTW analysis stage, returning :class:`DtwResults`."""
labels: list[str] = []
distances: list[float] = []
paths: list[list[tuple[int, int]]] = []
per_joint_distances: list[list[float]] | None = [] if stage.method == "dtw_per_joint" else None
pairs: list[tuple[str, np.ndarray, np.ndarray]] = []
if primary_segmentations:
for key, primary_seg in primary_segmentations.items():
reference_seg = reference_segmentations.get(key)
if reference_seg is None:
# Same config was applied to both, so this should not
# happen unless the segmentation depends on the input
# length in some unexpected way. Skip the mismatched key
# rather than crash the whole run.
continue
pair_count = min(len(primary_seg.segments), len(reference_seg.segments))
for i in range(pair_count):
p_seg = primary_seg.segments[i]
r_seg = reference_seg.segments[i]
pairs.append(
(
f"{key}[{i}]",
primary_seq[p_seg.start : p_seg.end],
reference_seq[r_seg.start : r_seg.end],
)
)
else:
pairs.append(("full_trial", primary_seq, reference_seq))
for label, primary_slice, reference_slice in pairs:
labels.append(label)
if stage.method == "dtw_all":
result = dtw_all(
primary_slice,
reference_slice,
align=stage.align,
representation=stage.representation,
angle_triplets=stage.angle_triplets,
nan_policy=stage.nan_policy,
)
distances.append(result.distance)
paths.append(result.path)
elif stage.method == "dtw_per_joint":
assert per_joint_distances is not None
per_joint_results = dtw_per_joint(
primary_slice,
reference_slice,
align=stage.align,
representation=stage.representation,
angle_triplets=stage.angle_triplets,
nan_policy=stage.nan_policy,
)
# "distance" for a per-joint run is the sum across units;
# "per_joint_distances" carries the full breakdown.
per_unit = [r.distance for r in per_joint_results]
distances.append(float(sum(per_unit)))
per_joint_distances.append(per_unit)
# Store just the first joint's path as a representative;
# the per-joint call returns one warping path per unit, and
# reporting all of them on disk is almost always overkill.
paths.append(per_joint_results[0].path if per_joint_results else [])
else: # "dtw_relation"
assert stage.joint_i is not None
assert stage.joint_j is not None
result = _invoke_dtw_relation(
primary_slice,
reference_slice,
joint_i=stage.joint_i,
joint_j=stage.joint_j,
align=stage.align,
nan_policy=stage.nan_policy,
)
distances.append(result.distance)
paths.append(result.path)
return DtwResults(
kind="dtw",
method=stage.method,
distances=distances,
paths=paths,
per_joint_distances=per_joint_distances,
segment_labels=labels,
summary=_summarize_distances(distances),
)
def _invoke_dtw_relation(
primary_slice: np.ndarray,
reference_slice: np.ndarray,
*,
joint_i: int,
joint_j: int,
align: AlignMode,
nan_policy: NanPolicy,
) -> DTWResult:
"""Isolating thin wrapper so test fakes can replace the call site cleanly."""
return dtw_relation(
primary_slice,
reference_slice,
joint_i,
joint_j,
align=align,
nan_policy=nan_policy,
)
def _run_stats(
stage: StatsAnalysis,
*,
primary_seq: np.ndarray,
primary_segmentations: dict[str, Segmentation],
) -> StatsResults:
"""Execute a stats analysis stage, returning :class:`StatsResults`."""
labels: list[str] = []
stats: list[FeatureSummary] = []
if primary_segmentations:
for key, seg in primary_segmentations.items():
for i, segment in enumerate(seg.segments):
labels.append(f"{key}[{i}]")
signal = extract_signal(
primary_seq[segment.start : segment.end],
stage.extractor,
)
stats.append(_feature_summary(signal))
else:
labels.append("full_trial")
signal = extract_signal(primary_seq, stage.extractor)
stats.append(_feature_summary(signal))
return StatsResults(kind="stats", statistics=stats, segment_labels=labels)
def _feature_summary(signal: np.ndarray) -> FeatureSummary:
"""Wrap :func:`extract_feature_statistics` output in a pydantic model."""
raw = extract_feature_statistics(signal)
return FeatureSummary(
mean=raw.mean,
std=raw.std,
min=raw.min,
max=raw.max,
range=raw.range,
)
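A dependency-free sketch of the five headline statistics (population standard deviation is assumed here; ``extract_feature_statistics`` may use a different convention):

```python
import statistics

def feature_summary(signal):
    # mean / std / min / max / range over a scalar signal, mirroring
    # the FeatureSummary field set.
    lo, hi = min(signal), max(signal)
    return {
        "mean": statistics.fmean(signal),
        "std": statistics.pstdev(signal),  # population std (assumed)
        "min": lo,
        "max": hi,
        "range": hi - lo,
    }
```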
def _summarize_distances(distances: list[float]) -> dict[str, float]:
"""Compute mean / p50 / p95 / p99 of a distance list.
Empty inputs return an empty dict so the report's ``summary``
field still round-trips through JSON without special cases.
"""
if not distances:
return {}
arr = np.asarray(distances, dtype=float)
return {
"mean": float(arr.mean()),
"p50": float(np.percentile(arr, 50)),
"p95": float(np.percentile(arr, 95)),
"p99": float(np.percentile(arr, 99)),
}
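A stdlib-only equivalent of this summary, using the same linear interpolation between closest ranks that `numpy.percentile` defaults to:

```python
def percentile(values, p):
    # Linear interpolation between closest ranks (NumPy's default
    # "linear" method).
    xs = sorted(values)
    k = (len(xs) - 1) * p / 100.0
    lo, hi = int(k), min(int(k) + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (k - lo)

def summarize(distances):
    if not distances:
        return {}
    return {
        "mean": sum(distances) / len(distances),
        "p50": percentile(distances, 50),
        "p95": percentile(distances, 95),
        "p99": percentile(distances, 99),
    }
```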


@@ -30,6 +30,12 @@ Three layers of API are provided, in increasing order of convenience:
:class:`~neuropose.io.ExtractorSpec`, converts time-based parameters
to frame counts using ``metadata.fps``, and returns a full
:class:`~neuropose.io.Segmentation` ready to attach to the predictions.
- :func:`segment_gait_cycles` and :func:`segment_gait_cycles_bilateral`:
clinical convenience wrappers over :func:`segment_predictions`
that pre-fill a :func:`joint_axis` extractor with gait-appropriate
defaults (heel joint, Y axis, 0.4 s minimum cycle). The bilateral
variant returns both sides under ``"left_heel_strikes"`` and
``"right_heel_strikes"`` keys.
- :func:`slice_predictions`: splits a :class:`~neuropose.io.VideoPredictions`
into one per-repetition :class:`~neuropose.io.VideoPredictions`,
useful when downstream code wants per-rep objects rather than windows
@@ -66,6 +72,7 @@ installed; a clear :class:`ImportError` surfaces at the first call to
from __future__ import annotations
from collections.abc import Sequence
from typing import Literal
import numpy as np
@@ -83,6 +90,11 @@ from neuropose.io import (
VideoPredictions,
)
AxisLetter = Literal["x", "y", "z"]
"""Axis selector used by gait-cycle segmentation helpers."""
_AXIS_INDICES: dict[AxisLetter, int] = {"x": 0, "y": 1, "z": 2}
# ---------------------------------------------------------------------------
# berkeley_mhad_43 joint names
# ---------------------------------------------------------------------------
@@ -522,6 +534,156 @@ def segment_predictions(
return Segmentation(config=config, segments=segments)
# ---------------------------------------------------------------------------
# Gait-cycle segmentation
# ---------------------------------------------------------------------------
def segment_gait_cycles(
predictions: VideoPredictions,
*,
joint: str = "rhee",
axis: AxisLetter = "y",
invert: bool = False,
min_cycle_seconds: float = 0.4,
min_prominence: float | None = None,
) -> Segmentation:
"""Segment gait cycles from a single heel's vertical trace.
Runs valley-to-valley peak detection (the same engine used by
:func:`segment_predictions`) on the chosen joint's coordinate along
the chosen spatial axis. By default, each detected peak corresponds
to one heel-strike (the frame where the heel reaches its lowest
point under the Y-down MeTRAbs world-coordinate convention), and the
returned :class:`~neuropose.io.Segment` windows span one full gait
cycle, from the preceding toe-off valley to the following toe-off
valley.
The function is a **thin wrapper** over :func:`segment_predictions`
with a :func:`joint_axis` extractor; it exists to give clinical
callers a gait-specific entry point with meaningful defaults
(``joint="rhee"``, ``axis="y"``, ``min_cycle_seconds=0.4``)
rather than forcing them to construct the extractor by hand.
Parameters
----------
predictions
Per-video predictions to segment. ``metadata.fps`` is used to
translate ``min_cycle_seconds`` into a sample-count distance
threshold.
joint
Joint name in the berkeley_mhad_43 skeleton, typically
``"rhee"`` (right heel) or ``"lhee"`` (left heel). Resolved
via :func:`joint_index`.
axis
Spatial axis to track, as ``"x"``, ``"y"``, or ``"z"``. The
default ``"y"`` matches the vertical axis in MeTRAbs's output
(Y-down world coordinates).
invert
If ``True``, negate the extracted signal so that minima
become peaks. Needed when the recording convention makes a
heel-strike appear as a *decrease* in the chosen coordinate;
for example, a camera orientation where the vertical axis
runs bottom-to-top instead of MeTRAbs's default top-to-bottom.
min_cycle_seconds
Minimum gait-cycle duration. Used as scipy's
``find_peaks(distance=...)`` parameter after conversion to
frame count via ``metadata.fps``. Defaults to ``0.4`` seconds,
which rejects noise peaks on even the fastest human gaits
(~120 strides/min) while retaining every real cadence.
min_prominence
Forwarded to :func:`segment_by_peaks` to filter out shallow
local maxima that aren't real heel-strikes. In MeTRAbs units
(millimetres) a threshold of 20 to 50 mm is typical for
able-bodied gait; leave ``None`` to accept every peak scipy
identifies.
Returns
-------
Segmentation
A :class:`~neuropose.io.Segmentation` paired with the full
:class:`~neuropose.io.SegmentationConfig` that produced it, so
the output is self-describing when persisted. The segments
list is **empty** rather than an exception when no peaks are
detected, a common outcome for shuffling gaits or
walker-assisted trials.
Raises
------
KeyError
If ``joint`` is not a known berkeley_mhad_43 joint name.
ValueError
If ``axis`` is not one of ``"x"``, ``"y"``, ``"z"``, or if
``predictions`` has zero frames, or if ``metadata.fps`` is
non-positive.
ImportError
If :mod:`scipy` is not installed.
"""
if axis not in _AXIS_INDICES:
raise ValueError(f"axis must be one of 'x', 'y', 'z'; got {axis!r}")
joint_idx = joint_index(joint)
axis_idx = _AXIS_INDICES[axis]
extractor = joint_axis(joint_idx, axis_idx, invert=invert)
return segment_predictions(
predictions,
extractor,
min_distance_seconds=min_cycle_seconds,
min_prominence=min_prominence,
)
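The peak-detection step ultimately rides on scipy's ``find_peaks``. A dependency-free sketch of the same minimum-distance idea (greedy, taller-peak-wins suppression; not scipy's exact algorithm) clarifies what ``min_cycle_seconds`` buys after conversion to a frame count:

```python
def find_cycle_peaks(signal, min_distance):
    # Local maxima first, then greedy suppression of any peak closer
    # than min_distance samples to an already-kept taller peak.
    candidates = [
        i for i in range(1, len(signal) - 1)
        if signal[i - 1] < signal[i] >= signal[i + 1]
    ]
    candidates.sort(key=lambda i: signal[i], reverse=True)
    kept: list[int] = []
    for i in candidates:
        if all(abs(i - j) >= min_distance for j in kept):
            kept.append(i)
    return sorted(kept)
```

With ``min_cycle_seconds=0.4`` at 30 fps, ``min_distance`` would be 12 samples, so jitter peaks inside a single stride are suppressed in favour of the dominant heel-strike peak.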
def segment_gait_cycles_bilateral(
predictions: VideoPredictions,
*,
axis: AxisLetter = "y",
invert: bool = False,
min_cycle_seconds: float = 0.4,
min_prominence: float | None = None,
) -> dict[str, Segmentation]:
"""Segment gait cycles for both heels.
Runs :func:`segment_gait_cycles` twice, once with ``joint="lhee"``
and once with ``joint="rhee"``, and returns the two results under
the keys ``"left_heel_strikes"`` and ``"right_heel_strikes"``. The
returned dict is shape-compatible with
:class:`~neuropose.io.VideoPredictions.segmentations` so it can be
merged directly into a predictions object and persisted to
``results.json`` via the usual save path.
Parameters
----------
predictions, axis, invert, min_cycle_seconds, min_prominence
Forwarded to :func:`segment_gait_cycles`; see that function's
docstring for details.
Returns
-------
dict[str, Segmentation]
Two-keyed mapping with the left and right heel segmentations
under ``"left_heel_strikes"`` and ``"right_heel_strikes"``.
Either side may carry an empty segments list if its heel's
trace contained no detectable strikes.
"""
return {
"left_heel_strikes": segment_gait_cycles(
predictions,
joint="lhee",
axis=axis,
invert=invert,
min_cycle_seconds=min_cycle_seconds,
min_prominence=min_prominence,
),
"right_heel_strikes": segment_gait_cycles(
predictions,
joint="rhee",
axis=axis,
invert=invert,
min_cycle_seconds=min_cycle_seconds,
min_prominence=min_prominence,
),
}
# ---------------------------------------------------------------------------
# Slicing: one VideoPredictions per segment
# ---------------------------------------------------------------------------


@@ -105,9 +105,17 @@ def run_benchmark(
passes: list[PerformanceMetrics] = []
reference_predictions: VideoPredictions | None = None
# Provenance is identical across every pass of a single run (same
# estimator, same model, same environment), so we keep just the
# latest one we see. Doing this on every iteration is cheap — it's
# one attribute read — and means the benchmark result carries
# provenance even when ``capture_reference`` is off.
latest_provenance = None
for i in range(repeats):
result = estimator.process_video(video_path)
passes.append(result.metrics)
if result.predictions.provenance is not None:
latest_provenance = result.predictions.provenance
# Only the *last* measured pass needs to be captured for
# divergence comparison. Earlier passes would just be
# overwritten, so we avoid holding their frame dicts in memory.
@@ -122,6 +130,7 @@ def run_benchmark(
warmup_pass=passes[0],
measured_passes=passes[1:],
aggregate=aggregate,
provenance=latest_provenance,
)
return BenchmarkRunOutcome(
result=benchmark_result,


@@ -1,6 +1,6 @@
"""NeuroPose command-line interface.
Seven subcommands:
Eight subcommands:
- ``neuropose watch`` run the :class:`~neuropose.interfacer.Interfacer`
daemon against the configured input directory.
@@ -12,6 +12,10 @@ Seven subcommands:
- ``neuropose serve`` start the :mod:`~neuropose.monitor` localhost
HTTP dashboard so collaborators can watch a run's progress in a
browser or via ``curl``.
- ``neuropose reset`` stop the daemon and monitor, then wipe pipeline
state (input queue, results, status file, lock file, ingest staging
dirs) for a clean restart. See :mod:`neuropose.reset` for the layered
implementation.
- ``neuropose segment <results>`` post-hoc repetition segmentation of
an existing predictions file. Attaches a named
:class:`~neuropose.io.Segmentation` to every video it contains and
@@ -21,8 +25,11 @@ Seven subcommands:
vs CPU numerical-divergence checks. Prints a human report to stdout
and (optionally) writes a structured :class:`~neuropose.io.BenchmarkResult`
JSON to ``--output``.
- ``neuropose analyze --config <yaml>`` run the declarative analysis
pipeline described in a YAML config. Loads the named predictions
files, applies segmentation + analysis, writes an
:class:`~neuropose.analyzer.pipeline.AnalysisReport` JSON. See
``examples/analysis/*.yaml`` for runnable references.
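A minimal config might look roughly like this. The field names below are a hypothetical sketch except ``output.report`` (which the command reads); the authoritative shape lives in ``AnalysisConfig`` and the ``examples/analysis/`` references:

```yaml
# Hypothetical sketch only; see examples/analysis/minimal.yaml for the real shape.
inputs:
  predictions: results/run_a.json
analysis:
  kind: dtw
output:
  report: reports/run_a_analysis.json
```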
User-facing error handling
--------------------------
@@ -53,6 +60,7 @@ from pathlib import Path
from typing import Annotated
import typer
import yaml
from pydantic import ValidationError
from neuropose import __version__
@@ -426,6 +434,151 @@ def serve(
raise typer.Exit(code=EXIT_USAGE) from exc
# ---------------------------------------------------------------------------
# reset
# ---------------------------------------------------------------------------
@app.command()
def reset(
ctx: typer.Context,
yes: Annotated[
bool,
typer.Option(
"--yes",
"-y",
help="Skip the interactive confirmation prompt.",
),
] = False,
keep_failed: Annotated[
bool,
typer.Option(
"--keep-failed",
help=(
"Preserve $data_dir/failed/ for forensic review. By "
"default the failed-job quarantine is wiped along with "
"in/ and out/."
),
),
] = False,
force_kill: Annotated[
bool,
typer.Option(
"--force-kill",
help=(
"Escalate to SIGKILL on any daemon or monitor still "
"alive after the SIGINT grace period. Necessary if the "
"daemon is mid-inference on a long video and you do "
"not want to wait for the current video to finish."
),
),
] = False,
grace_seconds: Annotated[
float,
typer.Option(
"--grace-seconds",
min=0.0,
help=(
"Seconds to wait after SIGINT before declaring a "
"process a survivor (or escalating to SIGKILL when "
"--force-kill is set)."
),
),
] = 10.0,
dry_run: Annotated[
bool,
typer.Option(
"--dry-run",
"-n",
help="Show what would be killed and removed without doing it.",
),
] = False,
) -> None:
"""Stop the daemon and monitor, then wipe pipeline state.
Discovers running ``neuropose watch`` and ``neuropose serve``
processes, sends SIGINT, waits ``--grace-seconds`` for graceful
shutdown (optionally escalating to SIGKILL with ``--force-kill``),
then removes the contents of ``$data_dir/in/``, ``$data_dir/out/``
(including ``status.json``), ``$data_dir/failed/`` (unless
``--keep-failed``), the daemon lock file, and any leftover
``.ingest_<uuid>/`` staging directories from interrupted ingests.
Refuses to wipe state if any process survives the termination
phase: wiping the data directory out from under an active daemon
would leave it writing into deleted directory entries. Re-run
with ``--force-kill`` or stop the survivor manually.
"""
# Deferred import so reset's psutil scan stays off the watch/process
# hot path. psutil is already a runtime dependency for benchmark
# metrics, so this import is free at install time.
from neuropose.reset import find_neuropose_processes, reset_pipeline, wipe_state
settings: Settings = ctx.obj
discovered = find_neuropose_processes()
preview = wipe_state(settings, keep_failed=keep_failed, dry_run=True)
typer.echo(f"data dir: {settings.data_dir}")
if discovered:
typer.echo(f"would stop: {len(discovered)} process(es)")
for rp in discovered:
typer.echo(f" pid {rp.pid:>7} {rp.role:<7} {rp.cmdline}")
else:
typer.echo("would stop: no daemon or monitor running")
if preview.removed_paths:
size_mb = preview.bytes_freed / (1024 * 1024)
typer.echo(f"would remove: {len(preview.removed_paths)} path(s) ({size_mb:.1f} MB)")
for path in preview.removed_paths:
typer.echo(f" {path}")
else:
typer.echo("would remove: nothing — data dir is already clean")
if dry_run:
typer.echo("(dry-run; no changes made)")
return
if not discovered and not preview.removed_paths:
typer.echo("nothing to do.")
return
if not yes and not typer.confirm("\nproceed?"):
typer.echo("aborted.")
raise typer.Exit(code=EXIT_USAGE)
report = reset_pipeline(
settings,
grace_seconds=grace_seconds,
force_kill=force_kill,
keep_failed=keep_failed,
)
if report.termination.stopped:
typer.echo(f"stopped {len(report.termination.stopped)} process(es) via SIGINT")
if report.termination.force_killed:
typer.echo(
f"force-killed {len(report.termination.force_killed)} process(es) "
f"after {grace_seconds:.0f}s grace period"
)
if report.termination.survivors:
typer.echo(
f"error: {len(report.termination.survivors)} process(es) did not exit:",
err=True,
)
for rp in report.termination.survivors:
typer.echo(f" pid {rp.pid} ({rp.role})", err=True)
if report.wipe_skipped_due_to_survivors:
typer.echo(
" state on disk was NOT wiped — re-run with --force-kill, "
"or stop these processes manually first.",
err=True,
)
raise typer.Exit(code=EXIT_USAGE)
size_mb = report.wipe.bytes_freed / (1024 * 1024)
typer.echo(f"removed {len(report.wipe.removed_paths)} path(s) ({size_mb:.1f} MB freed)")
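The SIGINT-then-SIGKILL escalation that ``reset_pipeline`` performs can be sketched with the stdlib alone. This is a hedged sketch of the pattern, not the actual ``neuropose.reset`` code; ``stop_with_grace`` is a hypothetical helper:

```python
import signal
import subprocess
import time

def stop_with_grace(proc: subprocess.Popen, grace_seconds: float = 10.0,
                    force_kill: bool = False) -> str:
    # SIGINT first so the process can run its graceful-shutdown path.
    proc.send_signal(signal.SIGINT)
    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            return "stopped"
        time.sleep(0.05)
    if force_kill:
        proc.kill()  # escalate to SIGKILL after the grace period
        proc.wait()
        return "force_killed"
    return "survivor"  # caller must refuse to wipe state in this case
```

Returning a status string rather than raising keeps the caller free to decide whether a survivor is an error (as ``reset`` does, by skipping the wipe and exiting nonzero).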
# ---------------------------------------------------------------------------
# segment
# ---------------------------------------------------------------------------
@@ -1047,26 +1200,94 @@ def benchmark(
# ---------------------------------------------------------------------------
# analyze
# ---------------------------------------------------------------------------
@app.command()
def analyze(
ctx: typer.Context,
config: Annotated[
Path,
typer.Option(
"--config",
"-c",
help=(
"Path to a YAML AnalysisConfig file. See examples/analysis/ "
"for runnable references."
),
),
],
output: Annotated[
Path | None,
typer.Option(
"--output",
"-o",
help=(
"Override the report path declared in the config's "
"output.report field. Useful when running the same config "
"against multiple input pairs from a shell loop."
),
),
] = None,
) -> None:
"""Run the declarative analysis pipeline described by a YAML config.
Loads the config, parses it through
:class:`~neuropose.analyzer.pipeline.AnalysisConfig` (so typos fail
immediately with a clear error), executes the pipeline via
:func:`~neuropose.analyzer.pipeline.run_analysis`, and writes the
resulting :class:`~neuropose.analyzer.pipeline.AnalysisReport` to
``--output`` (or to ``output.report`` declared in the config).
Cross-field invariants (for example,
``method='dtw_relation'`` requires ``joint_i`` / ``joint_j``) are
enforced at parse time, so a typo fails before any predictions
are loaded.
"""
del ctx
# Deferred import keeps the CLI module's top-level imports free of
# pipeline dependencies so ``watch`` / ``process`` startup stays
# cheap.
from neuropose.analyzer.pipeline import load_config, run_analysis, save_report
if not config.exists():
typer.echo(f"error: config file not found: {config}", err=True)
raise typer.Exit(code=EXIT_USAGE)
try:
analysis_config = load_config(config)
except ValidationError as exc:
typer.echo(f"error: invalid config {config}:\n{exc}", err=True)
raise typer.Exit(code=EXIT_USAGE) from exc
except yaml.YAMLError as exc:
typer.echo(f"error: could not parse YAML {config}: {exc}", err=True)
raise typer.Exit(code=EXIT_USAGE) from exc
report_path = output if output is not None else analysis_config.output.report
try:
report = run_analysis(analysis_config)
except (FileNotFoundError, ValueError) as exc:
typer.echo(f"error: analysis failed: {exc}", err=True)
raise typer.Exit(code=EXIT_USAGE) from exc
save_report(report_path, report)
typer.echo(f"wrote analysis report to {report_path}")
if report.segmentations:
seg_summary = ", ".join(
f"{name}={len(seg.segments)}" for name, seg in report.segmentations.items()
)
typer.echo(f"segmentations: {seg_summary}")
# Emit a one-line summary of the results regardless of kind.
typer.echo(f"analysis kind: {report.results.kind}")
if report.results.kind == "dtw":
n = len(report.results.distances)
mean = report.results.summary.get("mean", float("nan"))
typer.echo(f"distances computed: {n} (mean={mean:.4f})")
elif report.results.kind == "stats":
typer.echo(f"statistic blocks computed: {len(report.results.statistics)}")
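The DTW distances summarised above follow the classic dynamic-programming recurrence; a minimal pure-Python version for 1-D sequences (a sketch for orientation, not the analyzer's implementation) looks like this:

```python
def dtw_distance(a: list[float], b: list[float]) -> float:
    """Classic O(len(a) * len(b)) DTW with absolute-difference local cost."""
    inf = float("inf")
    n, m = len(a), len(b)
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = abs(a[i - 1] - b[j - 1])
            # Extend the cheapest of match / insertion / deletion.
            cost[i][j] = d + min(cost[i - 1][j - 1], cost[i - 1][j], cost[i][j - 1])
    return cost[n][m]

# Warping absorbs the repeated sample, so the first pair is distance 0.
distances = [dtw_distance([0, 1, 2], [0, 1, 1, 2]), dtw_distance([0, 1], [0, 3])]
mean = sum(distances) / len(distances)  # the "mean" the one-line summary prints
```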
def run() -> None:


@@ -34,19 +34,25 @@ model is present raises :class:`ModelNotLoadedError`.
from __future__ import annotations
import logging
import sys
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _pkg_version
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import psutil
from neuropose import __version__ as _neuropose_version
from neuropose._model import load_metrabs_model
from neuropose.io import (
FramePrediction,
PerformanceMetrics,
Provenance,
VideoMetadata,
VideoPredictions,
)
@@ -158,6 +164,12 @@ class Estimator:
# successful ``load_model`` below so the next ``process_video`` can
# pass the real number through into ``PerformanceMetrics``.
self._model_load_seconds: float | None = None
# MeTRAbs artifact identity, set only by ``load_model``. When the
# model was injected via the constructor we have no way to
# fingerprint it, so these remain ``None`` and ``process_video``
# leaves the output's ``provenance`` as ``None`` too.
self._model_sha256: str | None = None
self._model_filename: str | None = None
# -- model lifecycle ----------------------------------------------------
@@ -176,6 +188,21 @@
"""Return ``True`` if a model has been supplied or loaded."""
return self._model is not None
@property
def model_sha256(self) -> str | None:
"""Return the SHA-256 of the loaded MeTRAbs artifact, or ``None``.
``None`` when the model was injected via ``Estimator(model=...)``
rather than loaded via :meth:`load_model`. The value, when
present, is the module-pinned SHA from :mod:`neuropose._model`.
"""
return self._model_sha256
@property
def model_filename(self) -> str | None:
"""Return the basename of the MeTRAbs artifact, or ``None`` if injected."""
return self._model_filename
def load_model(self, cache_dir: Path | None = None) -> None:
"""Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.
@@ -196,9 +223,16 @@
return
logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
start = time.perf_counter()
loaded = load_metrabs_model(cache_dir=cache_dir)
self._model_load_seconds = time.perf_counter() - start
self._model = loaded.model
self._model_sha256 = loaded.sha256
self._model_filename = loaded.filename
logger.info(
"MeTRAbs model loaded in %.2f s (sha256=%s)",
self._model_load_seconds,
loaded.sha256[:12],
)
# -- inference ----------------------------------------------------------
@@ -330,11 +364,53 @@
metrics.active_device,
)
provenance = self._build_provenance(device_info=device_info)
predictions = VideoPredictions(
metadata=metadata,
frames=frames,
provenance=provenance,
)
return ProcessVideoResult(predictions=predictions, metrics=metrics)
# -- internals ----------------------------------------------------------
def _build_provenance(self, *, device_info: _ActiveDeviceInfo) -> Provenance | None:
"""Construct a :class:`~neuropose.io.Provenance` for the current run.
Returns ``None`` when the model was injected via the constructor
rather than loaded via :meth:`load_model`; in that case we
cannot fingerprint the artifact, and a partial provenance would
mislead readers into thinking we could.
The device-info bundle is shared with the :class:`PerformanceMetrics`
construction (one call to :func:`_detect_active_device` per
``process_video`` invocation) so that both artifacts see
identical TF and Metal state.
"""
if self._model_sha256 is None or self._model_filename is None:
return None
metal_version: str | None = None
if device_info.metal_active:
try:
metal_version = _pkg_version("tensorflow-metal")
except PackageNotFoundError:
metal_version = None
python_version = (
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
)
return Provenance(
model_sha256=self._model_sha256,
model_filename=self._model_filename,
tensorflow_version=device_info.tf_version,
tensorflow_metal_version=metal_version,
numpy_version=np.__version__,
neuropose_version=_neuropose_version,
python_version=python_version,
)
def _infer_frame(
self,
model: Any,


@@ -10,6 +10,14 @@ Atomicity: :func:`save_status`, :func:`save_job_results`, and
atomically rename, so a crash mid-write will not leave a partially-written
file behind. This matches the crash-resilience guarantee the interfacer
daemon makes to callers.
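The write-temp-then-rename pattern those savers rely on looks roughly like this (a stdlib sketch of the technique; ``save_json_atomic`` is a hypothetical name, not this module's API):

```python
import json
import os
import tempfile

def save_json_atomic(path: str, payload: dict) -> None:
    """Write payload to path so a crash mid-write never leaves a partial file."""
    directory = os.path.dirname(os.path.abspath(path))
    # Temp file must live on the same filesystem for the rename to be atomic.
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(payload, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)  # atomic rename on POSIX
    except BaseException:
        os.unlink(tmp)
        raise
```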
Schema versioning: :class:`VideoPredictions` and :class:`BenchmarkResult`
each carry a ``schema_version`` integer. On load, the raw JSON dict is
passed through :mod:`neuropose.migrations` before pydantic validation so
that files written by earlier versions upgrade transparently. :class:`JobResults`
is a ``RootModel`` with no envelope of its own, so its loader runs the
per-video migration on each entry of its mapping. See
:mod:`neuropose.migrations` for the migration-registration pattern.
"""
from __future__ import annotations
@@ -23,6 +31,13 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
from neuropose.migrations import (
CURRENT_VERSION,
migrate_benchmark_result,
migrate_job_results,
migrate_video_predictions,
)
class JobStatus(StrEnum):
"""Lifecycle state of a single processing job."""
@@ -157,6 +172,104 @@ class PerformanceMetrics(BaseModel):
)
class Provenance(BaseModel):
"""Reproducibility-grade record of the environment that produced a payload.
Populated by the estimator on every inference run when the MeTRAbs
model was loaded through
:meth:`neuropose.estimator.Estimator.load_model` (the production
path). ``None`` when the model was injected directly via the
``Estimator(model=...)`` constructor (the test-fixture path), since
NeuroPose has no way to fingerprint a model it did not load itself.
Paper C's reproducibility story rests on this envelope: two runs
that produced equal ``Provenance`` objects against the same input
are expected to produce equal output (modulo non-determinism
controlled by ``deterministic``). Reviewers who want to re-derive a
figure from raw video need exactly these fields.
Frozen so a captured ``Provenance`` cannot be mutated after it has
been attached to a result; this matches the invariant that
provenance is a property of the run, not of the reader.
``protected_namespaces=()`` silences pydantic's ``model_*`` field
warning: the ``model_sha256`` / ``model_filename`` names refer to
the MeTRAbs model artifact, not to pydantic's internal
``model_validate`` / ``model_dump`` namespace, so the collision is
cosmetic.
"""
model_config = ConfigDict(extra="forbid", frozen=True, protected_namespaces=())
model_sha256: str = Field(
description=(
"SHA-256 of the MeTRAbs model tarball (hex-encoded, lowercase). "
"Pinned at build time in :mod:`neuropose._model` and verified on "
"first download. Identifies the exact model weights used."
),
)
model_filename: str = Field(
description=(
"Canonical basename of the MeTRAbs tarball, e.g. "
"``metrabs_eff2l_y4_384px_800k_28ds.tar.gz``. Human-readable "
"companion to ``model_sha256``."
),
)
tensorflow_version: str = Field(
description="Value of ``tensorflow.__version__`` at the time of the run.",
)
tensorflow_metal_version: str | None = Field(
default=None,
description=(
"Version of the ``tensorflow-metal`` PyPI package when installed; "
"``None`` on platforms without Metal GPU acceleration."
),
)
numpy_version: str = Field(
description="Value of ``numpy.__version__`` at the time of the run.",
)
neuropose_version: str = Field(
description="Value of ``neuropose.__version__`` at the time of the run.",
)
python_version: str = Field(
description=(
"Python version as ``MAJOR.MINOR.MICRO``, e.g. ``3.11.14``. The "
"full ``sys.version`` string is intentionally not captured; the "
"three-component form is stable across patch builds and avoids "
"embedding compiler and build-date metadata."
),
)
seed: int | None = Field(
default=None,
description=(
"Random seed used for the run if one was set, else ``None``. "
"MeTRAbs inference is deterministic on a given device up to "
"floating-point associativity, so seeding mostly matters for "
"downstream analysis that introduces randomness (bootstraps, "
"learned metrics)."
),
)
deterministic: bool = Field(
default=False,
description=(
"``True`` if ``tf.config.experimental.enable_op_determinism()`` "
"was active during the run. Track 2 deterministic-inference "
"mode; the field exists in Phase 0 so payloads can record "
"whether the run *was* deterministic without requiring a "
"schema change when the toggle lands."
),
)
analysis_config: dict[str, Any] | None = Field(
default=None,
description=(
"Parsed YAML dict if this payload was produced by ``neuropose "
"analyze --config <file>``. ``None`` for direct-library or "
"``neuropose watch`` invocations. Reserved for the Phase 0 "
"YAML-configurable analysis pipeline."
),
)
class BenchmarkAggregate(BaseModel):
"""Distributional statistics aggregated across benchmark passes.
@@ -255,6 +368,16 @@ class BenchmarkResult(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
schema_version: int = Field(
default=CURRENT_VERSION,
ge=1,
description=(
"Schema version of this BenchmarkResult payload. Fresh writes "
"stamp :data:`neuropose.migrations.CURRENT_VERSION`; older files "
"are migrated on load via :mod:`neuropose.migrations` before "
"pydantic validation."
),
)
video_name: str = Field(
description="Basename of the benchmarked video (no directory components).",
)
@@ -280,6 +403,14 @@
)
aggregate: BenchmarkAggregate
cpu_comparison: CpuComparisonResult | None = None
provenance: Provenance | None = Field(
default=None,
description=(
"Reproducibility envelope from the benchmark run. ``None`` in "
"tests where the model was injected directly via "
"``Estimator(model=...)``."
),
)
class JointAxisExtractor(BaseModel):
@@ -469,9 +600,30 @@ class VideoPredictions(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
schema_version: int = Field(
default=CURRENT_VERSION,
ge=1,
description=(
"Schema version of this VideoPredictions payload. Fresh writes "
"stamp :data:`neuropose.migrations.CURRENT_VERSION`; files written "
"by older NeuroPose versions are migrated to the current version "
"by :mod:`neuropose.migrations` before pydantic validation."
),
)
metadata: VideoMetadata
frames: dict[str, FramePrediction]
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
provenance: Provenance | None = Field(
default=None,
description=(
"Reproducibility envelope populated by the estimator on runs "
"where the MeTRAbs model was loaded via "
":meth:`neuropose.estimator.Estimator.load_model`. ``None`` on "
"test paths where the model was injected via "
"``Estimator(model=...)``, because no model SHA is known in "
"that case."
),
)
def frame_names(self) -> list[str]:
"""Return frame identifiers in insertion order."""
@@ -623,9 +775,16 @@ class StatusFile(RootModel[dict[str, JobStatusEntry]]):
def load_video_predictions(path: Path) -> VideoPredictions:
"""Load and validate a per-video predictions JSON file.
Runs the payload through :func:`neuropose.migrations.migrate_video_predictions`
before pydantic validation so files written by older NeuroPose versions
upgrade to the current schema transparently.
"""
with path.open("r", encoding="utf-8") as f:
data: Any = json.load(f)
if isinstance(data, dict):
data = migrate_video_predictions(data)
return VideoPredictions.model_validate(data)
@@ -636,9 +795,17 @@ def save_video_predictions(path: Path, predictions: VideoPredictions) -> None:
def load_job_results(path: Path) -> JobResults:
"""Load and validate an aggregated per-job results JSON file.
Runs each video's payload through
:func:`neuropose.migrations.migrate_video_predictions` before pydantic
validation. :class:`JobResults` is a ``RootModel`` with no envelope of
its own, so migration happens per-entry rather than at the top level.
"""
with path.open("r", encoding="utf-8") as f:
data: Any = json.load(f)
if isinstance(data, dict):
data = migrate_job_results(data)
return JobResults.model_validate(data)
@@ -649,9 +816,16 @@ def save_job_results(path: Path, results: JobResults) -> None:
def load_benchmark_result(path: Path) -> BenchmarkResult:
"""Load and validate a benchmark-result JSON file.
Runs the payload through :func:`neuropose.migrations.migrate_benchmark_result`
before pydantic validation so files written by older NeuroPose versions
upgrade transparently.
"""
with path.open("r", encoding="utf-8") as f:
data: Any = json.load(f)
if isinstance(data, dict):
data = migrate_benchmark_result(data)
return BenchmarkResult.model_validate(data)

src/neuropose/migrations.py (new file, 318 lines)

@@ -0,0 +1,318 @@
"""Schema migration infrastructure for serialised NeuroPose payloads.
Every top-level JSON schema that NeuroPose persists to disk
(:class:`~neuropose.io.VideoPredictions`,
:class:`~neuropose.io.JobResults`, and
:class:`~neuropose.io.BenchmarkResult`) carries a ``schema_version``
integer. When those files are read back, the raw dict is passed
through :func:`migrate_video_predictions` /
:func:`migrate_job_results` / :func:`migrate_benchmark_result`
*before* pydantic validation runs, so each schema version can be
brought up to the current one transparently.
The pattern is deliberately small: one integer version counter shared
across all top-level schemas, plus a per-schema registry of
``{from_version: migration_fn}``. Each migration is a pure function
``dict -> dict`` responsible for stamping the new ``schema_version``
on its output. The framework chains them.
This module is intentionally separate from :mod:`neuropose.io` so
that migration registrations cannot accidentally import the pydantic
models they migrate; migrations must operate on raw dicts to be
robust to schema drift (a field a migration references may not exist
on the pydantic model by the time CURRENT_VERSION has moved past
it).
Adding a new migration
----------------------
When a schema change lands:
1. Bump :data:`CURRENT_VERSION`.
2. Register a migration from the *previous* version to the new one
via :func:`register_video_predictions_migration` (or the sibling
for benchmark results). The function receives the raw dict at the
old version and must return a dict at the new version *including*
the updated ``schema_version`` stamp.
3. Update the pydantic model in :mod:`neuropose.io` to reflect the
new field set.
4. Add a unit test verifying that a fixture at the old version
round-trips through ``load_*`` to the expected new-version shape.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
logger = logging.getLogger(__name__)
CURRENT_VERSION = 2
"""The current schema version for all NeuroPose-persisted JSON payloads.
Shared across :class:`~neuropose.io.VideoPredictions`,
:class:`~neuropose.io.JobResults`, and
:class:`~neuropose.io.BenchmarkResult` so that coordinated schema
changes (for example, adding a ``provenance`` field to all three at
once) bump a single counter rather than three parallel ones.
Version history
---------------
- **v1:** initial schema, pre-Phase-0.
- **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions`
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope).
:class:`~neuropose.analyzer.pipeline.AnalysisReport` also enters the registry at v2
(no legacy v1 payloads ever existed for it, so no migration is registered)."""
class MigrationError(Exception):
"""Base class for schema-migration failures."""
class FutureSchemaError(MigrationError):
"""Raised when a payload's ``schema_version`` exceeds :data:`CURRENT_VERSION`.
Produced when a newer NeuroPose version writes a file and an older
version tries to read it. The fix is upgrading NeuroPose; silently
stripping fields would corrupt the payload.
"""
class MigrationNotFoundError(MigrationError):
"""Raised when no migration is registered for an intermediate version.
Indicates a bug: :data:`CURRENT_VERSION` was bumped past a version
for which no migration function was registered. Should only surface
in tests or on a corrupted install.
"""
# Per-schema registries. Keys are the *source* version of the migration;
# the value is a callable that takes a dict at that version and returns a
# dict at ``source + 1``.
_VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
_BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
_ANALYSIS_REPORT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
def register_video_predictions_migration(
from_version: int,
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
"""Register a :class:`~neuropose.io.VideoPredictions` migration.
Usage::
@register_video_predictions_migration(from_version=1)
def _v1_to_v2(payload: dict) -> dict:
payload = dict(payload)
payload["provenance"] = None
payload["schema_version"] = 2
return payload
The decorator registers the function into the per-schema migration
registry and returns it unchanged, so it can still be called
directly from tests.
"""
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
if from_version in _VIDEO_PREDICTIONS_MIGRATIONS:
raise RuntimeError(
f"video-predictions migration already registered from version {from_version}"
)
_VIDEO_PREDICTIONS_MIGRATIONS[from_version] = fn
return fn
return wrap
def register_benchmark_result_migration(
from_version: int,
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
"""Register a :class:`~neuropose.io.BenchmarkResult` migration.
See :func:`register_video_predictions_migration` for usage; this
is the same pattern for the benchmark-result registry.
"""
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
if from_version in _BENCHMARK_RESULT_MIGRATIONS:
raise RuntimeError(
f"benchmark-result migration already registered from version {from_version}"
)
_BENCHMARK_RESULT_MIGRATIONS[from_version] = fn
return fn
return wrap
def register_analysis_report_migration(
from_version: int,
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
"""Register a :class:`~neuropose.analyzer.pipeline.AnalysisReport` migration.
See :func:`register_video_predictions_migration` for usage; this
is the same pattern for the analysis-report registry. Unlike the
other two schemas, :class:`AnalysisReport` first appeared at
:data:`CURRENT_VERSION = 2`, so no ``from_version=1`` migration
exists (and none is expected).
"""
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
if from_version in _ANALYSIS_REPORT_MIGRATIONS:
raise RuntimeError(
f"analysis-report migration already registered from version {from_version}"
)
_ANALYSIS_REPORT_MIGRATIONS[from_version] = fn
return fn
return wrap
def migrate_video_predictions(payload: dict) -> dict:
"""Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current.
Parameters
----------
payload
Raw JSON-loaded dict. Must not yet have been through pydantic
validation. A missing ``schema_version`` key is interpreted as
version ``1`` (the earliest tracked version, shipped before
the migration infrastructure existed).
Returns
-------
dict
The payload at :data:`CURRENT_VERSION`. Ready to be passed to
``VideoPredictions.model_validate``.
Raises
------
FutureSchemaError
If the payload declares a ``schema_version`` higher than
:data:`CURRENT_VERSION`.
MigrationNotFoundError
If an intermediate migration is missing from the registry.
"""
return _migrate(payload, _VIDEO_PREDICTIONS_MIGRATIONS, schema_name="VideoPredictions")
def migrate_benchmark_result(payload: dict) -> dict:
"""Migrate a raw :class:`~neuropose.io.BenchmarkResult` dict to current.
See :func:`migrate_video_predictions` for semantics. This is the
sibling function for benchmark-result payloads.
"""
return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult")
def migrate_analysis_report(payload: dict) -> dict:
"""Migrate a raw :class:`~neuropose.analyzer.pipeline.AnalysisReport` dict.
See :func:`migrate_video_predictions` for semantics. Because
:class:`AnalysisReport` first shipped at schema_version 2, a
payload missing the key still defaults to 1 (and would require a
not-yet-registered v1 → v2 migration); this is only reachable for
deliberately malformed inputs.
"""
return _migrate(payload, _ANALYSIS_REPORT_MIGRATIONS, schema_name="AnalysisReport")
def migrate_job_results(payload: dict) -> dict:
"""Migrate a :class:`~neuropose.io.JobResults` root dict to current.
``JobResults`` is a ``RootModel`` whose root is a mapping of video
name to :class:`~neuropose.io.VideoPredictions` payload. It has no
envelope of its own, so the migration is "run
:func:`migrate_video_predictions` on every value in the mapping."
Parameters
----------
payload
Raw JSON-loaded dict of ``{video_name: VideoPredictions-shaped dict}``.
Returns
-------
dict
The same mapping with each video payload migrated to the
current schema version.
"""
return {name: migrate_video_predictions(video) for name, video in payload.items()}
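Sketched standalone, with a stand-in for the per-video migration (not the real ``migrate_video_predictions``), the per-entry treatment is a plain comprehension:

```python
def migrate_video(video: dict) -> dict:
    # Stand-in per-video migration: v1 payloads gain provenance=None.
    video = dict(video)
    video.setdefault("provenance", None)
    video["schema_version"] = 2
    return video

def migrate_results_mapping(payload: dict) -> dict:
    # No top-level envelope to version-check: migrate each value independently.
    return {name: migrate_video(video) for name, video in payload.items()}
```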
def _migrate(
payload: dict,
migrations: dict[int, Callable[[dict], dict]],
*,
schema_name: str,
) -> dict:
"""Walk the migration chain to :data:`CURRENT_VERSION` and return the migrated payload.
Shared driver for :func:`migrate_video_predictions`,
:func:`migrate_benchmark_result`, and
:func:`migrate_analysis_report`. Looks up the incoming
``schema_version`` (defaulting to 1 when absent), walks the migration
chain until reaching :data:`CURRENT_VERSION`, and returns the
migrated payload. Logs at INFO each time it actually advances a
version so operators see the upgrade happen.
"""
version = payload.get("schema_version", 1)
if not isinstance(version, int) or version < 1:
raise MigrationError(
f"{schema_name} payload has invalid schema_version {version!r}; must be an integer >= 1"
)
if version > CURRENT_VERSION:
raise FutureSchemaError(
f"{schema_name} payload declares schema_version {version}, which is newer "
f"than this build's CURRENT_VERSION ({CURRENT_VERSION}). Upgrade NeuroPose."
)
while version < CURRENT_VERSION:
if version not in migrations:
raise MigrationNotFoundError(
f"no {schema_name} migration registered from schema_version {version}"
)
logger.info(
"Migrating %s payload from schema_version %d to %d",
schema_name,
version,
version + 1,
)
payload = migrations[version](payload)
version += 1
return payload
# ---------------------------------------------------------------------------
# Registered migrations
# ---------------------------------------------------------------------------
#
# Keep registrations *below* the driver so the module's public API surfaces
# at the top and the version-specific diffs live together at the bottom where
# they are easiest to audit chronologically.
@register_video_predictions_migration(from_version=1)
def _video_predictions_v1_to_v2(payload: dict) -> dict:
"""v1 → v2: add the optional ``provenance`` field (Phase 0).
Phase 0 introduces the :class:`~neuropose.io.Provenance` envelope
for Paper C reproducibility. v1 files predate it, so we stamp
``provenance = None`` on load; the field is optional on the
pydantic model and ``None`` correctly indicates "we don't have
provenance metadata for this payload."
"""
payload = dict(payload)
payload.setdefault("provenance", None)
payload["schema_version"] = 2
return payload
@register_benchmark_result_migration(from_version=1)
def _benchmark_result_v1_to_v2(payload: dict) -> dict:
"""v1 → v2: add the optional ``provenance`` field (Phase 0).
Sibling of :func:`_video_predictions_v1_to_v2` for benchmark
payloads; same rationale.
"""
payload = dict(payload)
payload.setdefault("provenance", None)
payload["schema_version"] = 2
return payload
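The registry-plus-driver pattern above can be condensed into a runnable sketch. This is illustrative only, assuming simplified names (`_MIGRATIONS`, `register`, `migrate`) rather than the module's actual decorators and error types:

```python
# Minimal sketch of the migration registry and chain driver. Names here
# (_MIGRATIONS, register, migrate) are illustrative, not the module's API.
from typing import Callable

CURRENT_VERSION = 2
_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}

def register(from_version: int):
    def deco(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
        _MIGRATIONS[from_version] = fn
        return fn
    return deco

@register(from_version=1)
def _v1_to_v2(payload: dict) -> dict:
    payload = dict(payload)              # copy: never mutate the caller's dict
    payload.setdefault("provenance", None)
    payload["schema_version"] = 2
    return payload

def migrate(payload: dict) -> dict:
    version = payload.get("schema_version", 1)   # v1 files predate the field
    if version > CURRENT_VERSION:
        raise ValueError(f"schema_version {version} is newer than this build")
    while version < CURRENT_VERSION:
        if version not in _MIGRATIONS:
            raise KeyError(f"no migration from schema_version {version}")
        payload = _MIGRATIONS[version](payload)
        version = payload["schema_version"]
    return payload
```

A v1 payload (no ``schema_version`` key at all) walks the chain to the current version; an already-current payload passes through untouched; a future version is rejected rather than silently misread.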

388
src/neuropose/reset.py Normal file

@ -0,0 +1,388 @@
"""Pipeline-wide reset utility.
Tear down a running NeuroPose deployment back to a clean state: stop
the watch daemon and the monitor server, then wipe the job queue,
results, status file, and ingest staging directories. Intended for
the rapid iteration loop that comes up during benchmark and validation
work, where you want an empty ``$data_dir/in`` without manually
``rm -rf``-ing five separate paths and pkill'ing two processes.
The module is split into three independently-callable layers so each
piece is testable in isolation and reusable from non-CLI contexts:
- :func:`find_neuropose_processes` enumerates running ``neuropose
watch`` and ``neuropose serve`` processes by scanning the OS process
table. Pure read; no side effects.
- :func:`terminate_processes` signals the discovered processes (SIGINT
first, optionally SIGKILL after a grace period) and reports
survivors.
- :func:`wipe_state` removes the data-directory paths that the daemon
and monitor produce. Idempotent; safe against a fresh install.
The top-level :func:`reset_pipeline` orchestrates all three and
returns a :class:`ResetReport` summarizing what happened.
Safety
------
:func:`reset_pipeline` refuses to wipe state while *any* discovered
process is still alive after the termination phase. Wiping
``$data_dir`` out from under an active daemon would leave the daemon
writing into deleted directory entries, a guaranteed mess. Callers
that hit a survivor must either raise the grace period, opt into
``force_kill=True`` (SIGKILL), or kill the survivor manually before
re-running.
"""
from __future__ import annotations
import contextlib
import logging
import os
import shutil
import signal
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
import psutil
from neuropose.config import Settings
from neuropose.interfacer import LOCK_FILENAME
logger = logging.getLogger(__name__)
# Cmdline substrings that identify the daemon and monitor.
# Matched against the joined argv so ``uv run neuropose watch -v`` and
# ``python -m neuropose watch`` both hit. The substring choice
deliberately includes the subcommand name so a generic
``neuropose --help`` invocation doesn't get caught.
_DAEMON_MARKER = "neuropose watch"
_MONITOR_MARKER = "neuropose serve"
DEFAULT_GRACE_SECONDS = 10.0
"""How long :func:`terminate_processes` waits after SIGINT before
declaring a process a survivor (or escalating to SIGKILL when
``force_kill`` is set). Long enough for an idle daemon to finish its
current poll iteration; short enough that an interactive ``reset``
invocation doesn't feel hung. Override per-call when waiting on a
multi-minute inference."""
_POLL_INTERVAL_SECONDS = 0.2
ProcessRole = Literal["daemon", "monitor"]
@dataclass(frozen=True)
class RunningProcess:
"""A neuropose process discovered in the OS process table."""
pid: int
role: ProcessRole
cmdline: str
@dataclass
class TerminationReport:
"""Outcome of trying to stop a set of running processes."""
stopped: list[RunningProcess] = field(default_factory=list)
survivors: list[RunningProcess] = field(default_factory=list)
force_killed: list[RunningProcess] = field(default_factory=list)
@dataclass
class WipeReport:
"""Outcome of wiping data-directory state."""
removed_paths: list[Path] = field(default_factory=list)
bytes_freed: int = 0
@dataclass
class ResetReport:
"""Aggregate report from a full pipeline reset."""
discovered: list[RunningProcess]
termination: TerminationReport
wipe: WipeReport
dry_run: bool
wipe_skipped_due_to_survivors: bool = False
def find_neuropose_processes(*, exclude_self: bool = True) -> list[RunningProcess]:
"""Scan the process table for ``neuropose watch`` / ``neuropose serve``.
Parameters
----------
exclude_self
Skip the current process. The ``neuropose reset`` command
itself has ``"neuropose"`` in its argv and would otherwise see
itself in the result. Set to ``False`` only in tests where the
caller has constructed a process table that should match
verbatim.
Returns
-------
list[RunningProcess]
Processes whose joined argv contains either marker substring.
Empty list when nothing matches.
"""
self_pid = os.getpid()
found: list[RunningProcess] = []
for proc in psutil.process_iter(["pid", "cmdline"]):
try:
pid = proc.info["pid"]
cmdline = proc.info["cmdline"] or []
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
if exclude_self and pid == self_pid:
continue
joined = " ".join(cmdline)
# Daemon check first because "neuropose serve" and
# "neuropose watch" cannot both appear in a single process.
if _DAEMON_MARKER in joined:
found.append(RunningProcess(pid=pid, role="daemon", cmdline=joined))
elif _MONITOR_MARKER in joined:
found.append(RunningProcess(pid=pid, role="monitor", cmdline=joined))
return found
def terminate_processes(
processes: list[RunningProcess],
*,
grace_seconds: float = DEFAULT_GRACE_SECONDS,
force_kill: bool = False,
) -> TerminationReport:
"""Send SIGINT to each process; optionally escalate to SIGKILL.
Parameters
----------
processes
The processes to stop. Pass an empty list for a no-op.
grace_seconds
Maximum time to wait for processes to exit after SIGINT.
force_kill
When ``True``, any process still alive after ``grace_seconds``
is sent SIGKILL. When ``False``, survivors are reported back
to the caller untouched.
Notes
-----
SIGTERM is *not* used as an intermediate escalation step. The
interfacer's signal handler treats SIGINT and SIGTERM identically
(both call :meth:`Interfacer.stop`), so SIGTERM accomplishes
nothing that SIGINT did not already attempt. The only escalation
that actually forces a stuck daemon to exit is SIGKILL, which
bypasses the handler entirely.
"""
report = TerminationReport()
if not processes:
return report
for rp in processes:
with contextlib.suppress(ProcessLookupError, PermissionError):
os.kill(rp.pid, signal.SIGINT)
logger.info("sent SIGINT to pid %d (%s)", rp.pid, rp.role)
survivors = _wait_for_exit(processes, grace_seconds)
stopped = [p for p in processes if p not in survivors]
report.stopped.extend(stopped)
if not survivors:
return report
if not force_kill:
report.survivors.extend(survivors)
return report
for rp in survivors:
with contextlib.suppress(ProcessLookupError, PermissionError):
os.kill(rp.pid, signal.SIGKILL)
logger.warning("escalated to SIGKILL for pid %d (%s)", rp.pid, rp.role)
# SIGKILL is delivered synchronously enough that a short final
# poll is sufficient — any remaining "survivor" at this point is
# a permission error or a kernel-side hang, not graceful shutdown.
final_survivors = _wait_for_exit(survivors, grace_seconds=2.0)
killed = [p for p in survivors if p not in final_survivors]
report.force_killed.extend(killed)
report.survivors.extend(final_survivors)
return report
def _wait_for_exit(
processes: list[RunningProcess],
grace_seconds: float,
) -> list[RunningProcess]:
"""Poll until every process exits or the deadline passes."""
deadline = time.monotonic() + grace_seconds
while time.monotonic() < deadline:
survivors = [p for p in processes if _is_alive(p.pid)]
if not survivors:
return []
time.sleep(_POLL_INTERVAL_SECONDS)
return [p for p in processes if _is_alive(p.pid)]
def _is_alive(pid: int) -> bool:
"""Return ``True`` if ``pid`` is still running and not a zombie."""
try:
proc = psutil.Process(pid)
except psutil.NoSuchProcess:
return False
try:
return proc.status() != psutil.STATUS_ZOMBIE
except psutil.NoSuchProcess:
return False
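The SIGINT-then-poll shutdown loop can be demonstrated end-to-end with a throwaway child process instead of psutil. A POSIX-only sketch, assuming nothing from the module itself:

```python
# Demonstration of the SIGINT-first, poll-until-deadline pattern that
# terminate_processes/_wait_for_exit implement, on a disposable child.
import os
import signal
import subprocess
import sys
import time

def wait_for_exit(proc: subprocess.Popen, grace_seconds: float,
                  poll: float = 0.05) -> bool:
    """Return True if the process exited before the deadline."""
    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            return True
        time.sleep(poll)
    return proc.poll() is not None

child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
os.kill(child.pid, signal.SIGINT)   # polite first, as in terminate_processes
exited = wait_for_exit(child, grace_seconds=5.0)
if not exited:                      # escalation path, mirroring force_kill=True
    child.kill()
    child.wait()
```

The child's default SIGINT disposition terminates it well inside the grace window, so the escalation branch stays cold here; a daemon with a signal handler mid-write is the case that actually needs it.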
def wipe_state(
settings: Settings,
*,
keep_failed: bool = False,
dry_run: bool = False,
) -> WipeReport:
"""Remove data-directory paths produced by the daemon and monitor.
Removes the *contents* of ``in/``, ``out/``, and (unless
``keep_failed`` is set) ``failed/``, plus the daemon lock file and
any leftover ``.ingest_<uuid>/`` staging directories. The
container directories themselves are preserved so the daemon does
not need to recreate them on next startup.
Parameters
----------
settings
Resolved :class:`~neuropose.config.Settings`. Determines all
target paths via the ``input_dir`` / ``output_dir`` /
``failed_dir`` properties.
keep_failed
Preserve ``$data_dir/failed/`` for forensic review of past
crashes. Default removes it along with the rest of the
pipeline state.
dry_run
Compute the report without actually deleting anything. Useful
for previewing the blast radius before confirming a reset.
"""
report = WipeReport()
targets: list[Path] = []
if settings.input_dir.exists():
targets.extend(settings.input_dir.iterdir())
if settings.output_dir.exists():
targets.extend(settings.output_dir.iterdir())
if not keep_failed and settings.failed_dir.exists():
targets.extend(settings.failed_dir.iterdir())
lock_path = settings.data_dir / LOCK_FILENAME
if lock_path.exists():
targets.append(lock_path)
if settings.data_dir.exists():
targets.extend(settings.data_dir.glob(".ingest_*"))
for target in targets:
size = _path_size(target)
if not dry_run:
_remove(target)
report.removed_paths.append(target)
report.bytes_freed += size
return report
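The "remove the contents, preserve the container" behaviour can be sketched standalone with a temp directory. `wipe_contents` is an illustrative reduction of `wipe_state`, not the module's function:

```python
# Sketch of wipe_state's contract: delete everything inside a directory
# while leaving the directory itself in place for the daemon's next start.
import shutil
import tempfile
from pathlib import Path

def wipe_contents(directory: Path) -> list[Path]:
    """Delete every entry inside `directory`; keep `directory` itself."""
    removed = []
    for entry in directory.iterdir():
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()
        removed.append(entry)
    return removed

root = Path(tempfile.mkdtemp())
(root / "job.json").write_text("{}")
(root / "staging").mkdir()
(root / "staging" / "frame.bin").write_bytes(b"\x00" * 16)
removed = wipe_contents(root)
```

Nested content counts as one removed entry (the whole `staging/` tree), matching how the report above counts top-level targets rather than individual files.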
def _path_size(path: Path) -> int:
"""Return the cumulative size of ``path``, recursing into directories."""
if path.is_symlink() or path.is_file():
try:
return path.stat().st_size
except OSError:
return 0
total = 0
for sub in path.rglob("*"):
try:
if sub.is_file() and not sub.is_symlink():
total += sub.stat().st_size
except OSError:
continue
return total
def _remove(path: Path) -> None:
"""Remove ``path`` whether file, symlink, or directory."""
if path.is_dir() and not path.is_symlink():
shutil.rmtree(path)
else:
path.unlink()
def reset_pipeline(
settings: Settings,
*,
grace_seconds: float = DEFAULT_GRACE_SECONDS,
force_kill: bool = False,
keep_failed: bool = False,
dry_run: bool = False,
) -> ResetReport:
"""Stop daemon + monitor, then wipe pipeline state.
Composes :func:`find_neuropose_processes`,
:func:`terminate_processes`, and :func:`wipe_state` into a single
operation. The wipe phase is *skipped* if any process survives
termination: wiping ``$data_dir`` out from under an active
daemon would corrupt its in-flight writes. The returned
:class:`ResetReport` flags this case via
``wipe_skipped_due_to_survivors``.
Parameters
----------
settings
Resolved :class:`~neuropose.config.Settings`.
grace_seconds
Maximum time to wait for SIGINT to take effect.
force_kill
Escalate to SIGKILL on any process still alive after
``grace_seconds``. Necessary when the daemon is mid-inference
on a long video and you do not want to wait for the current
video to finish.
keep_failed
Preserve ``$data_dir/failed/`` during the wipe.
dry_run
Discover and report without killing anything or removing any
paths. Termination phase is skipped entirely.
"""
discovered = find_neuropose_processes()
if dry_run:
wipe = wipe_state(settings, keep_failed=keep_failed, dry_run=True)
return ResetReport(
discovered=discovered,
termination=TerminationReport(),
wipe=wipe,
dry_run=True,
)
termination = terminate_processes(
discovered,
grace_seconds=grace_seconds,
force_kill=force_kill,
)
if termination.survivors:
return ResetReport(
discovered=discovered,
termination=termination,
wipe=WipeReport(),
dry_run=False,
wipe_skipped_due_to_survivors=True,
)
wipe = wipe_state(settings, keep_failed=keep_failed, dry_run=False)
return ResetReport(
discovered=discovered,
termination=termination,
wipe=wipe,
dry_run=False,
)


@ -0,0 +1,205 @@
"""Example-config sanity integration tests.
Every YAML config under ``examples/analysis/`` is exercised here in
two ways:
1. ``test_<name>_parses``: :func:`load_config` accepts the YAML,
i.e. the file matches the current :class:`AnalysisConfig` schema
including cross-field invariants. Catches silent drift between the
example configs and the schema they claim to exercise.
2. ``test_<name>_runs``: the example's pipeline runs end-to-end
against synthetic predictions. Paths in the YAML are overwritten
with test fixtures before :func:`run_analysis` is invoked;
everything else (stages, thresholds, triplets) is used verbatim.
These tests are deliberately not marked ``slow``: they use synthetic
fixtures and do not touch the MeTRAbs SavedModel, so they run in the
default unit-test suite.
"""
from __future__ import annotations
import math
from pathlib import Path
import numpy as np
import pytest
from neuropose.analyzer.pipeline import (
AnalysisConfig,
AnalysisReport,
DtwResults,
load_config,
run_analysis,
)
from neuropose.analyzer.segment import JOINT_INDEX
from neuropose.io import VideoPredictions, save_video_predictions
EXAMPLES_DIR = Path(__file__).resolve().parents[2] / "examples" / "analysis"
NUM_JOINTS = 43
def _sinusoid(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
total = num_cycles * frames_per_cycle
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
return (np.sin(t) * amplitude + amplitude).astype(float)
def _write_trial(
path: Path,
*,
num_cycles: int = 4,
frames_per_cycle: int = 30,
seed: int = 0,
) -> Path:
"""Write a synthetic VideoPredictions with every joint oscillating.
Joint ``0`` gets a reproducible RNG-driven trace so Procrustes has
something non-degenerate to align. All other joints get their own
phase-shifted sinusoid so joint-angle triplets and per-joint DTW
have signal to act on.
"""
rng = np.random.default_rng(seed)
base = _sinusoid(num_cycles, frames_per_cycle)
total = base.shape[0]
frames: dict[str, dict] = {}
for frame_idx in range(total):
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
for j in range(NUM_JOINTS):
# Unique per-joint position so no triplet is degenerate.
phase = rng.uniform(0.0, 2.0 * math.pi)
amplitude = 30.0 + 10.0 * (j % 5)
offset = float(j) * 15.0
poses[0][j][0] = offset + amplitude * math.cos(
2.0 * math.pi * frame_idx / frames_per_cycle + phase
)
poses[0][j][1] = offset * 0.5 + base[frame_idx] + 5.0 * j
poses[0][j][2] = 3.0 * j
frames[f"frame_{frame_idx:06d}"] = {
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
"poses3d": poses,
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
}
preds = VideoPredictions.model_validate(
{
"metadata": {
"frame_count": total,
"fps": float(frames_per_cycle),
"width": 640,
"height": 480,
},
"frames": frames,
}
)
save_video_predictions(path, preds)
return path
@pytest.fixture
def example_fixtures(tmp_path: Path) -> tuple[Path, Path, Path]:
"""Return (primary_path, reference_path, report_path) under ``tmp_path``."""
primary = _write_trial(tmp_path / "primary.json", seed=1)
reference = _write_trial(tmp_path / "reference.json", seed=2)
report = tmp_path / "report.json"
return primary, reference, report
def _rewrite_paths(
example_path: Path, primary: Path, reference: Path, report: Path
) -> AnalysisConfig:
"""Load an example YAML and rewrite inputs/output paths to fixtures.
Tests run against synthetic predictions in ``tmp_path``; the
example YAML's hardcoded ``data/*.json`` paths would never resolve
otherwise.
"""
config = load_config(example_path)
update: dict = {
"inputs": config.inputs.model_copy(update={"primary": primary, "reference": reference}),
"output": config.output.model_copy(update={"report": report}),
}
return config.model_copy(update=update)
class TestMinimalExample:
def test_minimal_parses(self) -> None:
config = load_config(EXAMPLES_DIR / "minimal.yaml")
assert isinstance(config, AnalysisConfig)
assert config.analysis.kind == "dtw"
assert config.segmentation is None
def test_minimal_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
primary, reference, report = example_fixtures
config = _rewrite_paths(EXAMPLES_DIR / "minimal.yaml", primary, reference, report)
result = run_analysis(config)
assert isinstance(result, AnalysisReport)
assert isinstance(result.results, DtwResults)
# Unsegmented → one distance.
assert result.results.segment_labels == ["full_trial"]
assert len(result.results.distances) == 1
class TestPaperCExample:
def test_paper_c_parses(self) -> None:
config = load_config(EXAMPLES_DIR / "paper_c_headline.yaml")
assert config.analysis.kind == "dtw"
assert config.segmentation is not None
assert config.segmentation.kind == "gait_cycles_bilateral"
# Joint triplets must be in range for berkeley_mhad_43.
assert config.analysis.kind == "dtw"
angle_triplets = config.analysis.angle_triplets # type: ignore[union-attr]
assert angle_triplets is not None
for a, b, c in angle_triplets:
for idx in (a, b, c):
assert 0 <= idx < NUM_JOINTS
def test_paper_c_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
primary, reference, report = example_fixtures
config = _rewrite_paths(EXAMPLES_DIR / "paper_c_headline.yaml", primary, reference, report)
result = run_analysis(config)
assert isinstance(result.results, DtwResults)
# Bilateral segmentation → distances labelled per side.
assert any(lbl.startswith("left_heel_strikes") for lbl in result.results.segment_labels)
assert any(lbl.startswith("right_heel_strikes") for lbl in result.results.segment_labels)
def test_paper_c_uses_documented_knee_triplets(self) -> None:
"""The Paper C config must target knee-flexion joint triplets.
Safety net: if someone edits the YAML and breaks the joint
references, this test catches it before the example silently
starts computing the wrong angles.
"""
config = load_config(EXAMPLES_DIR / "paper_c_headline.yaml")
assert config.analysis.kind == "dtw"
triplets = config.analysis.angle_triplets # type: ignore[union-attr]
assert triplets is not None
# Left knee flex = hip → knee → ankle.
assert (
JOINT_INDEX["lhipb"],
JOINT_INDEX["lkne"],
JOINT_INDEX["lank"],
) in triplets or (
JOINT_INDEX["lhipf"],
JOINT_INDEX["lkne"],
JOINT_INDEX["lank"],
) in triplets
class TestPerJointDebugExample:
def test_per_joint_debug_parses(self) -> None:
config = load_config(EXAMPLES_DIR / "per_joint_debug.yaml")
assert config.analysis.kind == "dtw"
assert config.analysis.method == "dtw_per_joint" # type: ignore[union-attr]
def test_per_joint_debug_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
primary, reference, report = example_fixtures
config = _rewrite_paths(EXAMPLES_DIR / "per_joint_debug.yaml", primary, reference, report)
result = run_analysis(config)
assert isinstance(result.results, DtwResults)
# dtw_per_joint → per_joint_distances populated.
assert result.results.per_joint_distances is not None
# Inner length must match num_joints for the coords
# representation.
for per_seg in result.results.per_joint_distances:
assert len(per_seg) == NUM_JOINTS


@ -81,26 +81,29 @@ class TestMetrabsLoader:
"""Exercises the loader's download → verify → extract → load path."""
def test_download_and_load(self, shared_model_cache_dir: Path) -> None:
model = load_metrabs_model(cache_dir=shared_model_cache_dir)
assert model is not None
loaded = load_metrabs_model(cache_dir=shared_model_cache_dir)
assert loaded.model is not None
assert loaded.sha256
assert loaded.filename
for attr in ("detect_poses", "per_skeleton_joint_names", "per_skeleton_joint_edges"):
assert hasattr(model, attr), f"loaded model is missing {attr}"
assert hasattr(loaded.model, attr), f"loaded model is missing {attr}"
def test_second_call_uses_cache(self, shared_model_cache_dir: Path) -> None:
"""Idempotent: second call should return the cached model cheaply."""
model_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
model_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
loaded_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
loaded_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
# tf.saved_model.load returns a new Python object each call, so
# identity comparison doesn't work — but both should still
# expose the MeTRAbs interface.
assert hasattr(model_a, "detect_poses")
assert hasattr(model_b, "detect_poses")
# expose the MeTRAbs interface, and the SHA should match.
assert hasattr(loaded_a.model, "detect_poses")
assert hasattr(loaded_b.model, "detect_poses")
assert loaded_a.sha256 == loaded_b.sha256
def test_berkeley_mhad_skeleton_is_present(self, shared_model_cache_dir: Path) -> None:
"""The estimator pins skeleton='berkeley_mhad_43'; verify it exists."""
model = load_metrabs_model(cache_dir=shared_model_cache_dir)
joint_names = model.per_skeleton_joint_names["berkeley_mhad_43"]
joint_edges = model.per_skeleton_joint_edges["berkeley_mhad_43"]
loaded = load_metrabs_model(cache_dir=shared_model_cache_dir)
joint_names = loaded.model.per_skeleton_joint_names["berkeley_mhad_43"]
joint_edges = loaded.model.per_skeleton_joint_edges["berkeley_mhad_43"]
# MeTRAbs exposes these as tf.Tensor objects; just verify we
# can pull a shape out.
assert joint_names.shape[0] == 43


@ -50,8 +50,8 @@ def test_joint_names_match_pinned_model(metrabs_model_cache_dir: Path) -> None:
commit that bumps the model pin in :mod:`neuropose._model`.
2. Cross-check any CLI or docs that embed hardcoded joint names.
"""
model = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
tensor = model.per_skeleton_joint_names["berkeley_mhad_43"]
loaded = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
tensor = loaded.model.per_skeleton_joint_names["berkeley_mhad_43"]
model_names = tuple(tensor.numpy().astype(str).tolist())
assert model_names == JOINT_NAMES, (
"JOINT_NAMES drift detected — the hardcoded tuple in "


@ -131,3 +131,250 @@ class TestDtwRelation:
b = np.zeros((3, 2, 3))
with pytest.raises(ValueError, match="joint count"):
dtw_relation(a, b, joint_i=0, joint_j=1)
# ---------------------------------------------------------------------------
# representation="angles"
# ---------------------------------------------------------------------------
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
c, s = np.cos(angle_rad), np.sin(angle_rad)
return np.array(
[
[c, -s, 0.0],
[s, c, 0.0],
[0.0, 0.0, 1.0],
]
)
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
"""A three-joint arm opening from a right angle to straight.
Joints laid out as [shoulder, elbow, wrist], forming an angle at
the elbow that linearly opens from pi/2 to pi across ``num_frames``.
"""
sequence = np.zeros((num_frames, 3, 3))
angles = np.linspace(np.pi / 2, np.pi, num_frames)
for i, theta in enumerate(angles):
sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder
sequence[i, 1] = [0.0, 0.0, 0.0] # elbow
sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist
return sequence
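The fixture's geometry can be checked by hand: the angle at the elbow, recovered from the three joint positions, opens from pi/2 at the first frame to pi at the last. A pure-Python sketch (`joint_angle` is an illustrative helper, not the library's `extract_joint_angles`):

```python
# Verify the elbow angle encoded by _three_joint_arm's joint positions.
# joint_angle is an illustrative helper, not the library's API.
import math

def joint_angle(a, b, c) -> float:
    """Angle at joint b formed by segments b->a and b->c, in radians."""
    v1 = [a[i] - b[i] for i in range(3)]
    v2 = [c[i] - b[i] for i in range(3)]
    dot = sum(x * y for x, y in zip(v1, v2))
    n1 = math.sqrt(sum(x * x for x in v1))
    n2 = math.sqrt(sum(x * x for x in v2))
    return math.acos(dot / (n1 * n2))

shoulder, elbow = (-1.0, 0.0, 0.0), (0.0, 0.0, 0.0)
theta = math.pi / 2   # first frame's parameter in _three_joint_arm
wrist = (math.cos(theta - math.pi), math.sin(theta - math.pi), 0.0)
first_frame_angle = joint_angle(shoulder, elbow, wrist)
```

At ``theta = pi`` the wrist lands at ``(1, 0, 0)`` and the recovered angle is exactly pi, i.e. a straight arm.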
class TestDtwAllAngles:
def test_angles_identical_sequences_distance_zero(self) -> None:
seq = _three_joint_arm()
result = dtw_all(
seq,
seq,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_angles_invariant_to_global_rotation(self) -> None:
"""Angle-space DTW must not change under a global rotation."""
seq = _three_joint_arm()
rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T
baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)])
under_rotation = dtw_all(
seq,
rotated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6)
def test_angles_translation_invariant(self) -> None:
seq = _three_joint_arm()
translated = seq + np.array([10.0, -5.0, 2.0])
result = dtw_all(
seq,
translated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_angles_detects_different_motion(self) -> None:
# A sequence whose angle is constant vs. one that opens.
constant = np.zeros((6, 3, 3))
constant[:, 0] = [-1.0, 0.0, 0.0]
constant[:, 1] = [0.0, 0.0, 0.0]
constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout
opening = _three_joint_arm()
result = dtw_all(
constant,
opening,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert result.distance > 0.0
def test_angles_without_triplets_rejected(self) -> None:
seq = _three_joint_arm()
with pytest.raises(ValueError, match="angle_triplets"):
dtw_all(seq, seq, representation="angles")
class TestDtwPerJointAngles:
def test_returns_one_result_per_triplet(self) -> None:
seq = _three_joint_arm()
triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose
results = dtw_per_joint(
seq,
seq,
representation="angles",
angle_triplets=triplets,
)
assert len(results) == 2
for result in results:
assert result.distance == pytest.approx(0.0, abs=1e-9)
def test_per_triplet_distinct_paths(self) -> None:
# Two triplets covering different angles; with different motion
# per triplet, the per-unit results should differ.
seq_a = np.zeros((5, 4, 3))
seq_b = np.zeros((5, 4, 3))
# joint 0: pivot, joint 1/2/3: arm endpoints
for i in range(5):
seq_a[i, 0] = [0.0, 0.0, 0.0]
seq_a[i, 1] = [1.0, 0.0, 0.0]
seq_a[i, 2] = [0.0, 1.0, 0.0]
seq_a[i, 3] = [0.0, 0.0, 1.0]
seq_b[i, 0] = [0.0, 0.0, 0.0]
seq_b[i, 1] = [1.0, 0.0, 0.0]
seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating
seq_b[i, 3] = [0.0, 0.0, 1.0]
results = dtw_per_joint(
seq_a,
seq_b,
representation="angles",
angle_triplets=[(1, 0, 2), (1, 0, 3)],
)
assert len(results) == 2
# First triplet tracks the rotation, second is stationary.
assert results[0].distance > 0.0
assert results[1].distance == pytest.approx(0.0, abs=1e-9)
# ---------------------------------------------------------------------------
# nan_policy
# ---------------------------------------------------------------------------
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
"""Three collinear joints — the angle at the middle joint is degenerate."""
seq = np.zeros((num_frames, 3, 3))
seq[:, 0] = [-1.0, 0.0, 0.0]
# Middle joint at (0,0,0); but because the outer joints are collinear
# through the origin, we need one joint overlapping with the middle
# to force a zero-length vector. Place joint 2 AT joint 1 to trigger
# the degenerate case in extract_joint_angles.
seq[:, 1] = [0.0, 0.0, 0.0]
seq[:, 2] = [0.0, 0.0, 0.0]
return seq
class TestNanPolicy:
def test_propagate_surfaces_error(self) -> None:
# Degenerate triplet produces NaN angles for every frame.
# With nan_policy="propagate" the NaN reaches fastdtw, which
# validates via numpy.asarray_chkfinite and raises ValueError —
# the intended behaviour ("make the problem visible").
seq = _collinear_sequence(num_frames=4)
other = _three_joint_arm(num_frames=4)
with pytest.raises(ValueError, match="infs or NaNs"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="propagate",
)
def test_interpolate_fills_isolated_nan(self) -> None:
# One bad frame in a 5-frame sequence — the other four are
# finite anchors to interpolate between.
good = _three_joint_arm(num_frames=5)
# Inject a degenerate middle frame.
good[2, 2] = good[2, 1] # force zero-length vector → NaN angle
# Reference is the same arm without injection.
reference = _three_joint_arm(num_frames=5)
result = dtw_all(
good,
reference,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="interpolate",
)
assert not np.isnan(result.distance)
def test_interpolate_all_nan_column_rejected(self) -> None:
seq = _collinear_sequence(num_frames=5)
other = _three_joint_arm(num_frames=5)
with pytest.raises(ValueError, match="all values are NaN"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="interpolate",
)
def test_drop_removes_nan_frames(self) -> None:
good = _three_joint_arm(num_frames=6)
good[2, 2] = good[2, 1] # inject NaN at frame 2
good[4, 2] = good[4, 1] # inject NaN at frame 4
reference = _three_joint_arm(num_frames=6)
result = dtw_all(
good,
reference,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="drop",
)
# The 4 remaining finite frames should align cleanly with
# their counterparts in the reference.
assert not np.isnan(result.distance)
def test_drop_empties_sequence_rejected(self) -> None:
seq = _collinear_sequence(num_frames=5)
other = _three_joint_arm(num_frames=5)
with pytest.raises(ValueError, match="every frame"):
dtw_all(
seq,
other,
representation="angles",
angle_triplets=[(0, 1, 2)],
nan_policy="drop",
)
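The three policies the tests above exercise can be sketched on a plain 1-D angle series, independent of the library: propagate leaves NaNs in place for downstream code to trip over, interpolate fills isolated gaps linearly from finite anchors, drop removes the offending frames. `clean_series` is illustrative only, not the module's implementation:

```python
# Sketch of the nan_policy semantics on a 1-D series. clean_series is an
# illustrative helper, not the library's API.
import math

def clean_series(values: list[float], policy: str) -> list[float]:
    finite = [(i, v) for i, v in enumerate(values) if not math.isnan(v)]
    if policy == "propagate":
        return list(values)            # caller hits the NaN downstream
    if not finite:
        raise ValueError("all values are NaN")
    if policy == "drop":
        return [v for _, v in finite]  # remove offending frames outright
    if policy == "interpolate":
        out = list(values)
        for i, v in enumerate(values):
            if math.isnan(v):
                left = max(((j, w) for j, w in finite if j < i), default=None)
                right = min(((j, w) for j, w in finite if j > i), default=None)
                if left and right:     # linear fill between finite anchors
                    t = (i - left[0]) / (right[0] - left[0])
                    out[i] = left[1] + t * (right[1] - left[1])
                else:                  # edge NaN: hold the nearest anchor
                    out[i] = (left or right)[1]
        return out
    raise ValueError(f"unknown policy {policy!r}")
```

The all-NaN rejection mirrors `test_interpolate_all_nan_column_rejected` and `test_drop_empties_sequence_rejected`: with no finite anchor there is nothing defensible to fill or keep.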
# ---------------------------------------------------------------------------
# align + representation composition
# ---------------------------------------------------------------------------
class TestAlignWithAngles:
def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
"""Procrustes on angle-space DTW should be redundant but safe."""
seq = _three_joint_arm()
rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T
with_align = dtw_all(
seq,
rotated,
align="procrustes_per_sequence",
representation="angles",
angle_triplets=[(0, 1, 2)],
)
without_align = dtw_all(
seq,
rotated,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)


@ -8,6 +8,7 @@ import numpy as np
import pytest
from neuropose.analyzer.features import (
AlignmentDiagnostics,
FeatureStatistics,
extract_feature_statistics,
extract_joint_angles,
@ -15,6 +16,7 @@ from neuropose.analyzer.features import (
normalize_pose_sequence,
pad_sequences,
predictions_to_numpy,
procrustes_align,
)
from neuropose.io import VideoPredictions
@ -297,3 +299,177 @@ class TestFindPeaks:
def test_rejects_2d_input(self) -> None:
with pytest.raises(ValueError, match="1D"):
find_peaks(np.zeros((5, 5)))
# ---------------------------------------------------------------------------
# procrustes_align
# ---------------------------------------------------------------------------
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
"""Rotation matrix about the Z axis."""
c, s = np.cos(angle_rad), np.sin(angle_rad)
return np.array(
[
[c, -s, 0.0],
[s, c, 0.0],
[0.0, 0.0, 1.0],
]
)
def _skeleton(num_joints: int = 8, seed: int = 0) -> np.ndarray:
"""A deterministic, non-degenerate single-frame skeleton."""
rng = np.random.default_rng(seed)
return rng.standard_normal((num_joints, 3))
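The rigid rotation-plus-translation recovery these tests exercise is the classic Kabsch construction. A standalone sketch, not the module's `procrustes_align` (diagnostics and the `scale` option are omitted):

```python
# Kabsch alignment sketch: recover the rotation + translation mapping a
# source skeleton onto a reference. Standard algorithm, not the module's
# procrustes_align implementation.
import numpy as np

def kabsch_align(source: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Rigidly align `source` (n, 3) onto `reference` (n, 3)."""
    src_c = source - source.mean(axis=0)
    ref_c = reference - reference.mean(axis=0)
    h = src_c.T @ ref_c                      # 3x3 cross-covariance
    u, _, vt = np.linalg.svd(h)
    d = np.sign(np.linalg.det(vt.T @ u.T))   # guard against reflections
    rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    return src_c @ rot.T + reference.mean(axis=0)

rng = np.random.default_rng(0)
reference = rng.standard_normal((8, 3))      # non-degenerate skeleton
angle = np.deg2rad(37.0)
c, s = np.cos(angle), np.sin(angle)
rot_z = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
source = reference @ rot_z.T + np.array([10.0, -4.5, 2.25])
aligned = kabsch_align(source, reference)
```

For a pure rotation plus translation the recovery is exact up to floating-point error, which is exactly what `test_recovers_combined_rotation_and_translation` asserts against the real implementation.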
class TestProcrustesAlignPerSequence:
def test_identical_sequences_yield_identity_transform(self) -> None:
sequence = _skeleton()[np.newaxis, :, :].repeat(3, axis=0) # (3, 8, 3)
aligned, target, diag = procrustes_align(sequence, sequence, mode="per_sequence")
np.testing.assert_allclose(aligned, sequence, atol=1e-10)
np.testing.assert_array_equal(target, sequence)
assert diag.mode == "per_sequence"
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-6)
assert diag.translation == pytest.approx(0.0, abs=1e-9)
assert diag.scale == pytest.approx(1.0)
def test_recovers_known_rotation(self) -> None:
# Build a reference sequence; construct the source by rotating it
# about Z, then verify alignment returns the reference up to
# floating-point error.
rotation = _rotation_matrix_z(np.deg2rad(37.0))
reference = _skeleton(num_joints=10)[np.newaxis, :, :].repeat(4, axis=0)
source = reference @ rotation.T
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
np.testing.assert_allclose(aligned, reference, atol=1e-8)
# The recovered rotation's magnitude should be the original 37°.
assert diag.rotation_deg == pytest.approx(37.0, abs=1e-4)
def test_recovers_known_translation(self) -> None:
reference = _skeleton()[np.newaxis, :, :].repeat(5, axis=0)
translation = np.array([10.0, -4.5, 2.25])
source = reference + translation
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
np.testing.assert_allclose(aligned, reference, atol=1e-9)
# rotation_deg may be numerically tiny but not exactly 0.
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-4)
assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-6)
def test_recovers_combined_rotation_and_translation(self) -> None:
rotation = _rotation_matrix_z(np.deg2rad(-12.0))
translation = np.array([1.0, 2.0, 3.0])
reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(3, axis=0)
source = reference @ rotation.T + translation
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
np.testing.assert_allclose(aligned, reference, atol=1e-8)
assert diag.rotation_deg == pytest.approx(12.0, abs=1e-4)
assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-4)
def test_scale_flag_recovers_known_scale(self) -> None:
reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
source = reference * 0.5
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=True)
np.testing.assert_allclose(aligned, reference, atol=1e-8)
assert diag.scale == pytest.approx(2.0, rel=1e-6)
def test_scale_flag_off_leaves_scale_at_one(self) -> None:
reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
source = reference * 0.5
_, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=False)
assert diag.scale == pytest.approx(1.0)
def test_rejects_mismatched_shapes(self) -> None:
a = np.zeros((4, 8, 3))
b = np.zeros((4, 7, 3))
with pytest.raises(ValueError, match="same shape"):
procrustes_align(a, b)
def test_rejects_wrong_trailing_axis(self) -> None:
a = np.zeros((4, 8, 2))
b = np.zeros((4, 8, 2))
with pytest.raises(ValueError, match="joints, 3"):
procrustes_align(a, b)
def test_rejects_unknown_mode(self) -> None:
a = np.zeros((2, 4, 3))
with pytest.raises(ValueError, match="unknown mode"):
procrustes_align(a, a, mode="nope") # type: ignore[arg-type]
def test_does_not_mutate_inputs(self) -> None:
source = _skeleton()[np.newaxis, :, :].repeat(3, axis=0).copy()
target = (source @ _rotation_matrix_z(np.deg2rad(10.0)).T).copy()
source_before = source.copy()
target_before = target.copy()
procrustes_align(source, target, mode="per_sequence")
np.testing.assert_array_equal(source, source_before)
np.testing.assert_array_equal(target, target_before)
def test_returns_alignment_diagnostics_dataclass(self) -> None:
a = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
_, _, diag = procrustes_align(a, a)
assert isinstance(diag, AlignmentDiagnostics)
class TestProcrustesAlignPerFrame:
def test_per_frame_recovers_varying_rotations(self) -> None:
# Each frame is rotated by a different angle; per_frame alignment
# should recover each frame independently.
num_frames = 4
reference_frame = _skeleton(num_joints=6)
angles = np.deg2rad([5.0, -10.0, 20.0, 45.0])
reference = np.stack([reference_frame for _ in range(num_frames)], axis=0)
source = np.stack([reference_frame @ _rotation_matrix_z(a).T for a in angles], axis=0)
aligned, _, diag = procrustes_align(source, reference, mode="per_frame")
np.testing.assert_allclose(aligned, reference, atol=1e-8)
assert diag.mode == "per_frame"
# The max rotation across frames should be 45°.
assert diag.rotation_deg_max == pytest.approx(45.0, abs=1e-4)
# The mean rotation across frames should be 20°.
assert diag.rotation_deg == pytest.approx(20.0, abs=1e-4)
def test_per_frame_with_identical_sequences_yields_zero(self) -> None:
sequence = _skeleton(num_joints=5)[np.newaxis, :, :].repeat(3, axis=0)
aligned, _, diag = procrustes_align(sequence, sequence, mode="per_frame")
np.testing.assert_allclose(aligned, sequence, atol=1e-10)
# Per-frame SVD on a symmetric covariance is numerically ambiguous
# in axis selection, so the fitted rotation can be a few
# microdegrees off zero; the residual positions are still exact.
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-3)
assert diag.rotation_deg_max == pytest.approx(0.0, abs=1e-3)
assert diag.translation == pytest.approx(0.0, abs=1e-9)
# ---------------------------------------------------------------------------
# DTW with align= (integration)
# ---------------------------------------------------------------------------
class TestDtwAlignIntegration:
"""Smoke tests: align= routes through procrustes_align correctly.
Depth tests of the DTW path itself live in test_analyzer_dtw.
"""
def test_dtw_all_with_alignment_cancels_rigid_offset(self) -> None:
pytest.importorskip("fastdtw")
from neuropose.analyzer.dtw import dtw_all
rotation = _rotation_matrix_z(np.deg2rad(30.0))
translation = np.array([5.0, -2.0, 1.0])
reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(4, axis=0)
source = reference @ rotation.T + translation
baseline = dtw_all(source, reference, align="none")
aligned_result = dtw_all(source, reference, align="procrustes_per_sequence")
assert baseline.distance > 0.0
assert aligned_result.distance == pytest.approx(0.0, abs=1e-6)
def test_dtw_align_rejects_mismatched_frame_counts(self) -> None:
pytest.importorskip("fastdtw")
from neuropose.analyzer.dtw import dtw_all
a = np.zeros((5, 3, 3))
b = np.zeros((6, 3, 3))
with pytest.raises(ValueError, match="matching frame counts"):
dtw_all(a, b, align="procrustes_per_sequence")
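The Procrustes tests above assert exact recovery of a known rotation and translation. For intuition, here is a self-contained sketch of that recovery via the classic Kabsch solution (centre both point sets, SVD of the cross-covariance, sign-corrected rotation); this is a textbook reference, not necessarily how `procrustes_align` is implemented:

```python
import numpy as np


def kabsch_align(source: np.ndarray, target: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Rigidly align (N, 3) source points onto target; returns (aligned, R, t)."""
    src_centroid = source.mean(axis=0)
    tgt_centroid = target.mean(axis=0)
    src_c = source - src_centroid
    tgt_c = target - tgt_centroid
    # Cross-covariance, then SVD; flip the last singular axis if det(R) would be -1.
    h = src_c.T @ tgt_c
    u, _, vt = np.linalg.svd(h)
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    translation = tgt_centroid - rotation @ src_centroid
    return source @ rotation.T + translation, rotation, translation


# Recover a known 37° Z rotation, mirroring test_recovers_known_rotation.
angle = np.deg2rad(37.0)
c, s = np.cos(angle), np.sin(angle)
rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
rng = np.random.default_rng(0)
target = rng.standard_normal((8, 3))
source = target @ rot.T

aligned, recovered, _ = kabsch_align(source, target)
assert np.allclose(aligned, target, atol=1e-8)

# Rotation magnitude from the trace: cos(theta) = (tr(R) - 1) / 2.
theta = np.degrees(np.arccos(np.clip((np.trace(recovered) - 1.0) / 2.0, -1.0, 1.0)))
assert abs(theta - 37.0) < 1e-6
```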


@ -0,0 +1,881 @@
"""Tests for :mod:`neuropose.analyzer.pipeline`.
Covers both halves of the pipeline:
- **Schemas**: :class:`AnalysisConfig` parsing (discriminated unions
for segmentation and analysis stages, cross-field invariants),
:class:`AnalysisReport` construction + JSON round-trip (including
the migration hook on ``schema_version``), and
:func:`analysis_config_to_dict` JSON-safety.
- **Executor**: :func:`run_analysis` dispatches to each analysis kind
(dtw / stats / none) with and without segmentation; provenance is
inherited from the primary input with ``analysis_config``
populated; :func:`load_config`, :func:`save_report`, and
:func:`load_report` round-trip.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import numpy as np
import pytest
import yaml
from pydantic import ValidationError
from neuropose.analyzer.pipeline import (
AnalysisConfig,
AnalysisReport,
DtwAnalysis,
DtwResults,
ExtractorSegmentation,
FeatureSummary,
GaitCyclesBilateralSegmentation,
GaitCyclesSegmentation,
InputsConfig,
InputSummary,
NoAnalysis,
NoResults,
OutputConfig,
PreprocessingConfig,
StatsAnalysis,
StatsResults,
analysis_config_to_dict,
load_config,
load_report,
run_analysis,
save_report,
)
from neuropose.io import Provenance, VideoPredictions, save_video_predictions
from neuropose.migrations import CURRENT_VERSION
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
"""A minimal AnalysisConfig dict with dtw_all + reference."""
return {
"inputs": {
"primary": str(tmp_path / "primary.json"),
"reference": str(tmp_path / "reference.json"),
},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "report.json")},
}
def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
return {
"inputs": {"primary": str(tmp_path / "primary.json")},
"analysis": {
"kind": "stats",
"extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False},
},
"output": {"report": str(tmp_path / "report.json")},
}
# ---------------------------------------------------------------------------
# InputsConfig / PreprocessingConfig / OutputConfig
# ---------------------------------------------------------------------------
class TestInputsConfig:
def test_primary_only(self, tmp_path: Path) -> None:
cfg = InputsConfig(primary=tmp_path / "a.json")
assert cfg.reference is None
def test_primary_and_reference(self, tmp_path: Path) -> None:
cfg = InputsConfig(
primary=tmp_path / "a.json",
reference=tmp_path / "b.json",
)
assert cfg.reference == tmp_path / "b.json"
def test_extra_field_rejected(self, tmp_path: Path) -> None:
with pytest.raises(ValidationError, match="Extra inputs"):
InputsConfig.model_validate(
{
"primary": str(tmp_path / "a.json"),
"extra": "field",
}
)
def test_frozen(self, tmp_path: Path) -> None:
cfg = InputsConfig(primary=tmp_path / "a.json")
with pytest.raises(ValidationError, match="frozen"):
cfg.primary = tmp_path / "b.json" # type: ignore[misc]
class TestPreprocessingConfig:
def test_default_person_index_zero(self) -> None:
cfg = PreprocessingConfig()
assert cfg.person_index == 0
def test_negative_person_index_rejected(self) -> None:
with pytest.raises(ValidationError):
PreprocessingConfig(person_index=-1)
class TestOutputConfig:
def test_report_path(self, tmp_path: Path) -> None:
cfg = OutputConfig(report=tmp_path / "out.json")
assert cfg.report == tmp_path / "out.json"
# ---------------------------------------------------------------------------
# Segmentation stage discriminated union
# ---------------------------------------------------------------------------
class TestSegmentationStage:
def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["segmentation"] = {
"kind": "gait_cycles",
"joint": "lhee",
"axis": "y",
"min_cycle_seconds": 0.5,
}
cfg = AnalysisConfig.model_validate(config_dict)
assert isinstance(cfg.segmentation, GaitCyclesSegmentation)
assert cfg.segmentation.joint == "lhee"
assert cfg.segmentation.min_cycle_seconds == 0.5
def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["segmentation"] = {"kind": "gait_cycles_bilateral"}
cfg = AnalysisConfig.model_validate(config_dict)
assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation)
def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["segmentation"] = {
"kind": "extractor",
"extractor": {
"kind": "joint_axis",
"joint": 15,
"axis": 1,
"invert": False,
},
"label": "wrist_cycles",
"min_distance_seconds": 0.5,
}
cfg = AnalysisConfig.model_validate(config_dict)
assert isinstance(cfg.segmentation, ExtractorSegmentation)
assert cfg.segmentation.label == "wrist_cycles"
def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["segmentation"] = {"kind": "unknown_method"}
with pytest.raises(ValidationError):
AnalysisConfig.model_validate(config_dict)
def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
assert cfg.segmentation is None
def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["segmentation"] = {
"kind": "gait_cycles",
"min_cycle_seconds": 0.0, # must be > 0
}
with pytest.raises(ValidationError):
AnalysisConfig.model_validate(config_dict)
# ---------------------------------------------------------------------------
# Analysis stage discriminated union + cross-field invariants
# ---------------------------------------------------------------------------
class TestDtwAnalysisValidation:
def test_dtw_relation_requires_joints(self) -> None:
with pytest.raises(ValidationError, match="joint_i and joint_j"):
DtwAnalysis(kind="dtw", method="dtw_relation")
def test_dtw_relation_rejects_angles_representation(self) -> None:
with pytest.raises(ValidationError, match="only supports representation='coords'"):
DtwAnalysis(
kind="dtw",
method="dtw_relation",
joint_i=0,
joint_j=1,
representation="angles",
angle_triplets=[(0, 1, 2)],
)
def test_angles_requires_triplets(self) -> None:
with pytest.raises(ValidationError, match="angle_triplets"):
DtwAnalysis(
kind="dtw",
method="dtw_all",
representation="angles",
)
def test_angles_with_empty_triplets_rejected(self) -> None:
with pytest.raises(ValidationError, match="angle_triplets"):
DtwAnalysis(
kind="dtw",
method="dtw_all",
representation="angles",
angle_triplets=[],
)
def test_happy_path_dtw_all_coords(self) -> None:
analysis = DtwAnalysis(
kind="dtw",
method="dtw_all",
align="procrustes_per_sequence",
nan_policy="interpolate",
)
assert analysis.align == "procrustes_per_sequence"
def test_happy_path_dtw_all_angles(self) -> None:
analysis = DtwAnalysis(
kind="dtw",
method="dtw_all",
representation="angles",
angle_triplets=[(0, 1, 2), (3, 4, 5)],
)
assert len(analysis.angle_triplets or []) == 2
def test_happy_path_dtw_relation(self) -> None:
analysis = DtwAnalysis(
kind="dtw",
method="dtw_relation",
joint_i=15,
joint_j=23,
)
assert analysis.joint_i == 15
class TestAnalysisCrossStage:
def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
config_dict = {
"inputs": {"primary": str(tmp_path / "a.json")},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "out.json")},
}
with pytest.raises(ValidationError, match=r"inputs\.reference"):
AnalysisConfig.model_validate(config_dict)
def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
config_dict = _minimal_stats_config(tmp_path)
config_dict["inputs"]["reference"] = str(tmp_path / "b.json")
with pytest.raises(ValidationError, match="primary only"):
AnalysisConfig.model_validate(config_dict)
def test_none_analysis_accepts_missing_reference(self, tmp_path: Path) -> None:
# NoAnalysis is fine with reference present or absent; this
# covers the reference-absent case.
config_dict = {
"inputs": {"primary": str(tmp_path / "a.json")},
"analysis": {"kind": "none"},
"output": {"report": str(tmp_path / "out.json")},
}
cfg = AnalysisConfig.model_validate(config_dict)
assert isinstance(cfg.analysis, NoAnalysis)
# ---------------------------------------------------------------------------
# Top-level AnalysisConfig
# ---------------------------------------------------------------------------
class TestAnalysisConfig:
def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
assert cfg.config_version == 1
assert isinstance(cfg.analysis, DtwAnalysis)
assert cfg.preprocessing.person_index == 0 # default
def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
cfg = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
assert isinstance(cfg.analysis, StatsAnalysis)
def test_config_version_must_be_1(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["config_version"] = 99
with pytest.raises(ValidationError):
AnalysisConfig.model_validate(config_dict)
def test_round_trip_json(self, tmp_path: Path) -> None:
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
serialised = original.model_dump_json()
restored = AnalysisConfig.model_validate_json(serialised)
assert restored == original
def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
config_dict = _minimal_dtw_config(tmp_path)
config_dict["unknown_key"] = "typo"
with pytest.raises(ValidationError):
AnalysisConfig.model_validate(config_dict)
# ---------------------------------------------------------------------------
# analysis_config_to_dict
# ---------------------------------------------------------------------------
class TestAnalysisConfigToDict:
def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
dumped = analysis_config_to_dict(cfg)
# Paths must have become strings.
assert isinstance(dumped["inputs"]["primary"], str)
assert isinstance(dumped["output"]["report"], str)
def test_round_trips_through_dict(self, tmp_path: Path) -> None:
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
dumped = analysis_config_to_dict(original)
restored = AnalysisConfig.model_validate(dumped)
assert restored == original
# ---------------------------------------------------------------------------
# Result sub-schemas
# ---------------------------------------------------------------------------
class TestDtwResults:
def test_minimal_construction(self) -> None:
res = DtwResults(
kind="dtw",
method="dtw_all",
distances=[0.5],
paths=[[(0, 0), (1, 1)]],
segment_labels=["full_trial"],
summary={"mean": 0.5},
)
assert res.kind == "dtw"
def test_per_joint_distances_shape_is_free(self) -> None:
# No validator enforces that per_joint_distances outer length
# matches distances — that's run-time semantics of the
# executor. Still, verify the field round-trips.
res = DtwResults(
kind="dtw",
method="dtw_per_joint",
distances=[0.1, 0.2],
paths=[[(0, 0)], [(0, 0)]],
per_joint_distances=[[0.05, 0.05], [0.1, 0.1]],
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
summary={"mean": 0.15},
)
assert res.per_joint_distances is not None
class TestStatsResults:
def test_round_trip(self) -> None:
res = StatsResults(
kind="stats",
statistics=[
FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
],
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
)
dumped = res.model_dump_json()
restored = StatsResults.model_validate_json(dumped)
assert restored == res
class TestNoResults:
def test_construction(self) -> None:
res = NoResults(kind="none")
assert res.kind == "none"
# ---------------------------------------------------------------------------
# AnalysisReport
# ---------------------------------------------------------------------------
def _make_report(tmp_path: Path) -> AnalysisReport:
config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
return AnalysisReport(
config=config,
primary=InputSummary(
path=tmp_path / "primary.json",
frame_count=300,
fps=30.0,
),
reference=InputSummary(
path=tmp_path / "reference.json",
frame_count=300,
fps=30.0,
),
results=DtwResults(
kind="dtw",
method="dtw_all",
distances=[0.42],
paths=[[(0, 0), (1, 1)]],
segment_labels=["full_trial"],
summary={"mean": 0.42, "p50": 0.42},
),
)
class TestAnalysisReport:
def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
report = _make_report(tmp_path)
assert report.schema_version == CURRENT_VERSION
def test_round_trip_json(self, tmp_path: Path) -> None:
report = _make_report(tmp_path)
serialised = report.model_dump_json()
restored = AnalysisReport.model_validate_json(serialised)
assert restored == report
def test_empty_segmentations_default(self, tmp_path: Path) -> None:
report = _make_report(tmp_path)
assert report.segmentations == {}
def test_extra_field_rejected(self, tmp_path: Path) -> None:
report = _make_report(tmp_path)
dumped = report.model_dump(mode="json")
dumped["mystery_field"] = 1
with pytest.raises(ValidationError):
AnalysisReport.model_validate(dumped)
# ---------------------------------------------------------------------------
# Executor helpers
# ---------------------------------------------------------------------------
NUM_JOINTS = 43
def _heel_signal(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
"""Clean sinusoid stand-in for a heel's vertical trace."""
total = num_cycles * frames_per_cycle
t = np.linspace(0.0, num_cycles * 2.0 * np.pi, total, endpoint=False)
return (np.sin(t) * amplitude + amplitude).astype(float)
def _build_predictions(
signal: np.ndarray,
joint: int,
*,
axis: int = 1,
fps: float = 30.0,
provenance: Provenance | None = None,
) -> VideoPredictions:
"""Build a VideoPredictions whose ``joint``'s ``axis`` follows ``signal``."""
frames = {}
for i, value in enumerate(signal):
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
poses[0][joint][axis] = float(value)
frames[f"frame_{i:06d}"] = {
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
"poses3d": poses,
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
}
return VideoPredictions.model_validate(
{
"metadata": {
"frame_count": len(signal),
"fps": fps,
"width": 640,
"height": 480,
},
"frames": frames,
"provenance": provenance.model_dump() if provenance is not None else None,
}
)
def _write_heel_trial(
tmp_path: Path,
filename: str,
*,
joint: int,
num_cycles: int = 4,
frames_per_cycle: int = 30,
amplitude: float = 100.0,
provenance: Provenance | None = None,
) -> Path:
"""Write a heel-trace VideoPredictions JSON and return its path."""
signal = _heel_signal(num_cycles, frames_per_cycle, amplitude=amplitude)
preds = _build_predictions(signal, joint=joint, provenance=provenance)
path = tmp_path / filename
save_video_predictions(path, preds)
return path
def _fake_provenance(sha: str = "a" * 64) -> Provenance:
return Provenance(
model_sha256=sha,
model_filename="fake_model.tar.gz",
tensorflow_version="2.18.0",
numpy_version="1.26.0",
neuropose_version="0.0.0",
python_version="3.11.0",
)
# ---------------------------------------------------------------------------
# Executor: run_analysis
# ---------------------------------------------------------------------------
class TestRunAnalysisDtwFullTrial:
def test_dtw_all_unsegmented_yields_one_distance(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
report_path = tmp_path / "report.json"
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(report_path)},
}
)
report = run_analysis(config)
assert isinstance(report.results, DtwResults)
assert report.results.segment_labels == ["full_trial"]
assert len(report.results.distances) == 1
# Identical inputs → distance 0.
assert report.results.distances[0] == pytest.approx(0.0, abs=1e-9)
def test_dtw_all_different_trials_positive_distance(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], amplitude=100.0)
reference = _write_heel_trial(
tmp_path, "b.json", joint=JOINT_INDEX["rhee"], amplitude=200.0
)
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, DtwResults)
assert report.results.distances[0] > 0.0
assert "mean" in report.results.summary
class TestRunAnalysisDtwSegmented:
def test_dtw_with_gait_cycles_produces_per_segment_distances(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, DtwResults)
# 4 cycles detected on both → 4 paired distances.
assert len(report.results.distances) == 4
assert all(label.startswith("rhee_cycles[") for label in report.results.segment_labels)
def test_dtw_bilateral_produces_distances_per_side(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
# Put the heel trace on both lhee and rhee.
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
frames = {}
for i, value in enumerate(signal):
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
frames[f"frame_{i:06d}"] = {
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
"poses3d": poses,
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
}
preds = VideoPredictions.model_validate(
{
"metadata": {
"frame_count": len(signal),
"fps": 30.0,
"width": 640,
"height": 480,
},
"frames": frames,
}
)
primary = tmp_path / "a.json"
reference = tmp_path / "b.json"
save_video_predictions(primary, preds)
save_video_predictions(reference, preds)
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"segmentation": {"kind": "gait_cycles_bilateral"},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, DtwResults)
# 3 cycles * 2 sides.
assert len(report.results.distances) == 6
left = [lbl for lbl in report.results.segment_labels if lbl.startswith("left_heel_strikes")]
right = [
lbl for lbl in report.results.segment_labels if lbl.startswith("right_heel_strikes")
]
assert len(left) == 3
assert len(right) == 3
# Identical primary and reference → all distances zero.
for d in report.results.distances:
assert d == pytest.approx(0.0, abs=1e-9)
def test_dtw_per_joint_populates_per_joint_distances(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_per_joint"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, DtwResults)
assert report.results.per_joint_distances is not None
assert len(report.results.per_joint_distances) == 1 # unsegmented → one pair
assert len(report.results.per_joint_distances[0]) == NUM_JOINTS
class TestRunAnalysisStats:
def test_stats_unsegmented_single_block(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary)},
"analysis": {
"kind": "stats",
"extractor": {
"kind": "joint_axis",
"joint": JOINT_INDEX["rhee"],
"axis": 1,
"invert": False,
},
},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, StatsResults)
assert report.results.segment_labels == ["full_trial"]
assert len(report.results.statistics) == 1
stat = report.results.statistics[0]
assert isinstance(stat, FeatureSummary)
assert stat.max > stat.min # Signal oscillates.
def test_stats_with_segmentation_emits_per_segment(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary)},
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
"analysis": {
"kind": "stats",
"extractor": {
"kind": "joint_axis",
"joint": JOINT_INDEX["rhee"],
"axis": 1,
"invert": False,
},
},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, StatsResults)
assert len(report.results.statistics) == 3
class TestRunAnalysisNone:
def test_none_analysis_returns_no_results(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary)},
"analysis": {"kind": "none"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, NoResults)
def test_none_with_segmentation_still_emits_segmentations(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary)},
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
"analysis": {"kind": "none"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert isinstance(report.results, NoResults)
assert "rhee_cycles" in report.segmentations
assert len(report.segmentations["rhee_cycles"].segments) > 0
# ---------------------------------------------------------------------------
# Provenance inheritance
# ---------------------------------------------------------------------------
class TestRunAnalysisProvenance:
def test_inherits_primary_provenance_and_stamps_config(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
provenance = _fake_provenance()
primary = _write_heel_trial(
tmp_path, "a.json", joint=JOINT_INDEX["rhee"], provenance=provenance
)
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert report.provenance is not None
# Model SHA inherited from primary.
assert report.provenance.model_sha256 == provenance.model_sha256
# analysis_config populated with the serialised config.
assert report.provenance.analysis_config is not None
assert report.provenance.analysis_config["config_version"] == 1
def test_no_primary_provenance_yields_none_report_provenance(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert report.provenance is None
def test_input_summaries_track_paths_and_metadata(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=5)
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
report = run_analysis(config)
assert report.primary.path == primary
assert report.primary.frame_count == 5 * 30
assert report.reference is not None
assert report.reference.frame_count == 3 * 30
# ---------------------------------------------------------------------------
# load_config / save_report / load_report
# ---------------------------------------------------------------------------
class TestLoadSave:
def test_load_config_parses_yaml(self, tmp_path: Path) -> None:
config_dict = {
"inputs": {
"primary": str(tmp_path / "a.json"),
"reference": str(tmp_path / "b.json"),
},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
yaml_path = tmp_path / "exp.yaml"
yaml_path.write_text(yaml.safe_dump(config_dict))
loaded = load_config(yaml_path)
assert isinstance(loaded.analysis, DtwAnalysis)
def test_load_config_empty_file_fails_cleanly(self, tmp_path: Path) -> None:
empty = tmp_path / "empty.yaml"
empty.write_text("")
with pytest.raises(ValidationError):
load_config(empty)
def test_load_config_rejects_malformed_yaml(self, tmp_path: Path) -> None:
bad = tmp_path / "bad.yaml"
# Unclosed flow-style mapping — yaml.safe_load raises here.
bad.write_text("inputs: {primary: foo\n")
with pytest.raises(yaml.YAMLError):
load_config(bad)
def test_save_report_round_trip(self, tmp_path: Path) -> None:
from neuropose.analyzer.segment import JOINT_INDEX
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
config = AnalysisConfig.model_validate(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "report.json")},
}
)
report = run_analysis(config)
report_path = tmp_path / "report.json"
save_report(report_path, report)
assert report_path.exists()
restored = load_report(report_path)
assert restored == report
def test_save_report_is_atomic(self, tmp_path: Path) -> None:
"""The saver writes via a sibling .tmp path and renames."""
report = _make_report(tmp_path)
report_path = tmp_path / "subdir" / "report.json"
save_report(report_path, report)
# Parent directory was created.
assert report_path.exists()
# No .tmp sibling left behind.
assert not (report_path.with_suffix(report_path.suffix + ".tmp")).exists()
def test_load_report_rejects_future_schema(self, tmp_path: Path) -> None:
"""Future schema_version surfaces as a migration error."""
from neuropose.migrations import FutureSchemaError
future = {"schema_version": CURRENT_VERSION + 1}
path = tmp_path / "future.json"
path.write_text(json.dumps(future))
with pytest.raises(FutureSchemaError):
load_report(path)
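The atomicity contract pinned down by ``test_save_report_is_atomic`` can be sketched with stdlib pieces alone. ``save_atomic`` and its exact layout are assumptions for illustration, not the real ``save_report``:

```python
import json
import os
import tempfile
from pathlib import Path


def save_atomic(path: Path, payload: dict) -> None:
    """Write via a sibling .tmp file, then rename into place."""
    path.parent.mkdir(parents=True, exist_ok=True)  # create missing parents
    tmp = path.with_suffix(path.suffix + ".tmp")    # sibling .tmp path
    tmp.write_text(json.dumps(payload))
    os.replace(tmp, path)  # atomic rename on the same filesystem


target = Path(tempfile.mkdtemp()) / "subdir" / "report.json"
save_atomic(target, {"schema_version": 1})
```

A crash before ``os.replace`` leaves at worst a stale ``.tmp`` sibling; the report path itself is either absent or complete, which is exactly what the test asserts.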

View File

@@ -33,6 +33,8 @@ from neuropose.analyzer.segment import (
joint_pair_distance,
joint_speed,
segment_by_peaks,
segment_gait_cycles,
segment_gait_cycles_bilateral,
segment_predictions,
slice_predictions,
)
@@ -429,3 +431,139 @@ class TestSlicePredictions:
sliced = slice_predictions(preds, segments)[0]
# frame_000000 of the slice must equal frame_000050 of the source
assert sliced["frame_000000"].poses3d == preds["frame_000050"].poses3d
# ---------------------------------------------------------------------------
# segment_gait_cycles / segment_gait_cycles_bilateral
# ---------------------------------------------------------------------------
def _heel_signal(num_cycles: int, frames_per_cycle: int) -> np.ndarray:
"""A clean sinusoid standing in for a heel's vertical trace."""
total = num_cycles * frames_per_cycle
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
# Amplitude chosen so min_prominence tests have a non-trivial range.
return (np.sin(t) * 100.0 + 100.0).astype(float)
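As a standalone sanity check of the synthetic trace (pure NumPy, mirroring ``_heel_signal`` above), each cycle contributes exactly one interior local maximum:

```python
import math

import numpy as np


def heel_signal(num_cycles: int, frames_per_cycle: int) -> np.ndarray:
    total = num_cycles * frames_per_cycle
    t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
    return (np.sin(t) * 100.0 + 100.0).astype(float)


sig = heel_signal(5, 30)
# Interior local maxima: strictly rising into the sample, non-rising out of it,
# so a two-sample plateau still counts only once.
peaks = [i for i in range(1, len(sig) - 1) if sig[i - 1] < sig[i] >= sig[i + 1]]
```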
class TestSegmentGaitCycles:
def test_detects_expected_number_of_cycles(self) -> None:
# 5 cycles at 30 fps, 30 frames per cycle = 1.0 s per stride.
# Well inside the default min_cycle_seconds=0.4 gate.
signal = _heel_signal(num_cycles=5, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
seg = segment_gait_cycles(preds, joint="rhee", axis="y")
assert len(seg.segments) == 5
def test_config_records_inputs(self) -> None:
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
seg = segment_gait_cycles(
preds,
joint="rhee",
axis="y",
min_cycle_seconds=0.5,
min_prominence=10.0,
)
assert isinstance(seg, Segmentation)
assert isinstance(seg.config.extractor, JointAxisExtractor)
assert seg.config.extractor.joint == JOINT_INDEX["rhee"]
assert seg.config.extractor.axis == 1 # "y" → 1
assert seg.config.extractor.invert is False
assert seg.config.min_distance_seconds == 0.5
assert seg.config.min_prominence == 10.0
def test_axis_selection(self) -> None:
# Put the signal on the X axis instead of Y.
signal = _heel_signal(num_cycles=4, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"], axis=0)
seg_y = segment_gait_cycles(preds, joint="rhee", axis="y")
seg_x = segment_gait_cycles(preds, joint="rhee", axis="x")
# Y is all-zeros (flat → no peaks), X carries the signal.
assert len(seg_y.segments) == 0
assert len(seg_x.segments) == 4
def test_invert_flips_peaks_and_valleys(self) -> None:
# Invert the heel trace; with invert=True, the original valleys
# become the peaks detected as heel-strikes.
signal = _heel_signal(num_cycles=4, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
seg_plain = segment_gait_cycles(preds, joint="rhee", axis="y", invert=False)
seg_inverted = segment_gait_cycles(preds, joint="rhee", axis="y", invert=True)
# Both detect four distinct events (peaks in either the signal
# or its negation). Peaks differ by roughly half a cycle.
assert len(seg_plain.segments) == 4
assert len(seg_inverted.segments) == 4
plain_peaks = [s.peak for s in seg_plain.segments]
inverted_peaks = [s.peak for s in seg_inverted.segments]
assert plain_peaks != inverted_peaks
def test_pathological_flat_signal_returns_empty(self) -> None:
# A subject whose heel never leaves the ground — no peaks.
signal = np.zeros(120)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
seg = segment_gait_cycles(preds, joint="rhee", axis="y")
assert seg.segments == []
def test_min_cycle_seconds_rejects_close_peaks(self) -> None:
# 10 cycles in 60 frames @ 30 fps = 0.2 s per cycle.
# min_cycle_seconds=0.4 should reject all but every-other peak.
signal = _heel_signal(num_cycles=10, frames_per_cycle=6)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
seg_permissive = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.0)
seg_strict = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.4)
# Strict mode drops peaks that are too close together.
assert len(seg_strict.segments) < len(seg_permissive.segments)
def test_unknown_joint_raises_key_error(self) -> None:
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
with pytest.raises(KeyError, match="unknown joint"):
segment_gait_cycles(preds, joint="left_heel") # wrong name
def test_invalid_axis_raises_value_error(self) -> None:
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
with pytest.raises(ValueError, match="axis must be one of"):
segment_gait_cycles(preds, joint="rhee", axis="w") # type: ignore[arg-type]
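The ``min_cycle_seconds`` gate tested above converts seconds to a minimum peak spacing in frames (``min_frames = min_cycle_seconds * fps``). A greedy sketch of such a filter follows; this is an illustration only, as the real implementation may delegate spacing to a peak-finding routine:

```python
def enforce_min_distance(peaks: list[int], min_frames: int) -> list[int]:
    """Keep a peak only if it lies at least min_frames after the last kept one."""
    kept: list[int] = []
    for p in peaks:
        if not kept or p - kept[-1] >= min_frames:
            kept.append(p)
    return kept


# 10 peaks, one per 6-frame cycle (0.2 s per cycle at 30 fps).
peaks = list(range(3, 63, 6))
strict = enforce_min_distance(peaks, int(0.4 * 30))  # 0.4 s -> 12 frames
permissive = enforce_min_distance(peaks, 0)
```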
class TestSegmentGaitCyclesBilateral:
def test_returns_both_keys(self) -> None:
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
# Put the same signal on both heels so both sides find cycles.
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
# Rebuild predictions with lhee populated too.
frames = {}
for i, value in enumerate(signal):
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
frames[f"frame_{i:06d}"] = {
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
"poses3d": poses,
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
}
preds = VideoPredictions.model_validate(
{
"metadata": {
"frame_count": len(signal),
"fps": 30.0,
"width": 640,
"height": 480,
},
"frames": frames,
}
)
result = segment_gait_cycles_bilateral(preds)
assert set(result.keys()) == {"left_heel_strikes", "right_heel_strikes"}
assert len(result["left_heel_strikes"].segments) == 3
assert len(result["right_heel_strikes"].segments) == 3
def test_pathological_one_side_returns_empty_for_that_side(self) -> None:
# Only the right heel carries a signal; left heel is flat.
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
result = segment_gait_cycles_bilateral(preds)
assert len(result["right_heel_strikes"].segments) == 3
assert result["left_heel_strikes"].segments == []

View File

@@ -683,9 +683,15 @@ def stub_estimator_with_metrics(monkeypatch: pytest.MonkeyPatch):
"poses2d": np.array([[[0.0, 0.0], [1.0, 1.0]]]),
}
from neuropose._model import LoadedModel
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
del cache_dir
return LoadedModel(
model=RecordingFake(),
sha256="smoke_sha",
filename="metrabs_smoke.tar.gz",
)
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
@@ -770,17 +776,142 @@ class TestBenchmarkSubcommand:
class TestAnalyze:
"""Covers the ``neuropose analyze --config <yaml>`` subcommand.
The execution happy path is exercised in detail in
:mod:`tests.unit.test_analyzer_pipeline`; this file focuses on
the CLI wiring: argument parsing, config-loading error modes, and
an end-to-end smoke test.
"""
def _make_predictions_file(self, tmp_path: Path, name: str, num_frames: int = 30) -> Path:
"""Write a trivial VideoPredictions file to disk for the CLI to load."""
import math
from neuropose.io import VideoPredictions, save_video_predictions
num_joints = 43
frames = {}
for i in range(num_frames):
poses = [[[0.0, 0.0, 0.0] for _ in range(num_joints)]]
poses[0][41][1] = float(math.sin(i * 0.3)) * 100.0 # rhee Y
frames[f"frame_{i:06d}"] = {
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
"poses3d": poses,
"poses2d": [[[0.0, 0.0]] * num_joints],
}
preds = VideoPredictions.model_validate(
{
"metadata": {
"frame_count": num_frames,
"fps": 30.0,
"width": 640,
"height": 480,
},
"frames": frames,
}
)
path = tmp_path / name
save_video_predictions(path, preds)
return path
def _write_dtw_config(
self,
tmp_path: Path,
*,
primary: Path,
reference: Path,
report: Path,
) -> Path:
import yaml as _yaml
config_path = tmp_path / "config.yaml"
config_path.write_text(
_yaml.safe_dump(
{
"inputs": {"primary": str(primary), "reference": str(reference)},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(report)},
}
)
)
return config_path
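For reference, the YAML that ``_write_dtw_config`` emits has this shape (the paths here are placeholders; the field names come from the tests):

```python
import yaml

config = {
    "inputs": {"primary": "a.json", "reference": "b.json"},
    "analysis": {"kind": "dtw", "method": "dtw_all"},
    "output": {"report": "report.json"},
}
text = yaml.safe_dump(config)  # keys sorted: analysis, inputs, output
print(text)
```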
def test_missing_config_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
result = runner.invoke(app, ["analyze", "--config", str(tmp_path / "nope.yaml")])
assert result.exit_code == EXIT_USAGE
assert "config file not found" in result.output
def test_missing_config_flag_is_usage_error(self, runner: CliRunner) -> None:
result = runner.invoke(app, ["analyze"])
assert result.exit_code == EXIT_USAGE
def test_invalid_yaml_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
bad = tmp_path / "bad.yaml"
bad.write_text("inputs: {primary: foo\n") # unclosed flow mapping
result = runner.invoke(app, ["analyze", "--config", str(bad)])
assert result.exit_code == EXIT_USAGE
assert "could not parse YAML" in result.output
def test_schema_violation_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
import yaml as _yaml
bad = tmp_path / "schema.yaml"
bad.write_text(
_yaml.safe_dump(
{
"inputs": {"primary": str(tmp_path / "a.json")},
# dtw without reference — violates cross-field invariant.
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "r.json")},
}
)
)
result = runner.invoke(app, ["analyze", "--config", str(bad)])
assert result.exit_code == EXIT_USAGE
assert "invalid config" in result.output
def test_happy_path_writes_report(self, runner: CliRunner, tmp_path: Path) -> None:
primary = self._make_predictions_file(tmp_path, "a.json")
reference = self._make_predictions_file(tmp_path, "b.json")
report_path = tmp_path / "report.json"
config = self._write_dtw_config(
tmp_path, primary=primary, reference=reference, report=report_path
)
result = runner.invoke(app, ["analyze", "--config", str(config)])
assert result.exit_code == EXIT_OK, result.output
assert report_path.exists()
assert "wrote analysis report" in result.output
assert "analysis kind: dtw" in result.output
def test_output_option_overrides_config_path(self, runner: CliRunner, tmp_path: Path) -> None:
primary = self._make_predictions_file(tmp_path, "a.json")
reference = self._make_predictions_file(tmp_path, "b.json")
# Config points at one report path ...
config = self._write_dtw_config(
tmp_path,
primary=primary,
reference=reference,
report=tmp_path / "declared.json",
)
# ... but --output overrides.
override = tmp_path / "override.json"
result = runner.invoke(app, ["analyze", "--config", str(config), "--output", str(override)])
assert result.exit_code == EXIT_OK, result.output
assert override.exists()
assert not (tmp_path / "declared.json").exists()
def test_missing_predictions_file_is_usage_error(
self, runner: CliRunner, tmp_path: Path
) -> None:
# Config points at a primary that does not exist.
config = self._write_dtw_config(
tmp_path,
primary=tmp_path / "missing_primary.json",
reference=tmp_path / "missing_reference.json",
report=tmp_path / "report.json",
)
result = runner.invoke(app, ["analyze", "--config", str(config)])
assert result.exit_code == EXIT_USAGE
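The exit-code mapping these tests pin down can be sketched independently of the CLI framework. ``analyze_config`` is an illustrative stand-in; the real command also validates through ``AnalysisConfig`` and runs the pipeline:

```python
import pathlib
import tempfile

import yaml

EXIT_OK, EXIT_USAGE = 0, 2


def analyze_config(config_path: pathlib.Path) -> int:
    """Map missing files, YAML errors, and schema violations to EXIT_USAGE."""
    try:
        text = config_path.read_text()
    except FileNotFoundError:
        print(f"error: config file not found: {config_path}")
        return EXIT_USAGE
    try:
        raw = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        print(f"error: could not parse YAML in {config_path}: {exc}")
        return EXIT_USAGE
    if not isinstance(raw, dict) or "inputs" not in raw:
        print(f"error: invalid config: {config_path}")
        return EXIT_USAGE
    return EXIT_OK


tmp = pathlib.Path(tempfile.mkdtemp())
missing_rc = analyze_config(tmp / "nope.yaml")
(tmp / "bad.yaml").write_text("inputs: {primary: foo\n")  # unclosed flow mapping
bad_rc = analyze_config(tmp / "bad.yaml")
```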

View File

@@ -70,17 +70,21 @@ class TestModelGuard:
network: the loader is monkeypatched to return a sentinel, and we
assert it ends up as the estimator's model.
"""
from neuropose._model import LoadedModel
sentinel = object()
called_with: list[Path | None] = []
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
called_with.append(cache_dir)
return LoadedModel(model=sentinel, sha256="deadbeef", filename="fake.tar.gz")
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
estimator = Estimator()
estimator.load_model(cache_dir=Path("/tmp/fake-cache"))
assert estimator.model is sentinel
assert estimator.model_sha256 == "deadbeef"
assert estimator.model_filename == "fake.tar.gz"
assert called_with == [Path("/tmp/fake-cache")]
def test_load_model_is_idempotent_when_already_loaded(
@@ -278,9 +282,15 @@ class TestPerformanceMetrics:
"poses2d": np.array([[[0.0, 0.0]]]),
}
from neuropose._model import LoadedModel
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
del cache_dir
return LoadedModel(
model=Recorder(),
sha256="fake_sha",
filename="metrabs_fake.tar.gz",
)
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
estimator = Estimator()
@@ -312,6 +322,88 @@ class TestPerformanceMetrics:
assert result.metrics.tensorflow_version not in {"", "unknown"}
class TestProvenance:
"""Provenance attachment to VideoPredictions.
Covers the two relevant paths: the injected-model path (no SHA
known ``provenance=None`` on output) and the ``load_model`` path
(SHA is known full ``Provenance`` populated and attached).
"""
def test_injected_model_produces_no_provenance(
self,
synthetic_video: Path,
fake_metrabs_model,
) -> None:
estimator = Estimator(model=fake_metrabs_model)
result = estimator.process_video(synthetic_video)
assert result.predictions.provenance is None
assert estimator.model_sha256 is None
assert estimator.model_filename is None
def test_loaded_model_populates_provenance(
self,
synthetic_video: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
import numpy as np
from neuropose._model import LoadedModel
class Recorder:
def detect_poses(self, image, **kwargs):
del image, kwargs
return {
"boxes": np.array([[0.0, 0.0, 1.0, 1.0, 0.9]]),
"poses3d": np.array([[[0.0, 0.0, 0.0]]]),
"poses2d": np.array([[[0.0, 0.0]]]),
}
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
del cache_dir
return LoadedModel(
model=Recorder(),
sha256="e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
filename="metrabs_stub.tar.gz",
)
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
estimator = Estimator()
estimator.load_model()
result = estimator.process_video(synthetic_video)
prov = result.predictions.provenance
assert prov is not None
assert prov.model_sha256.startswith("e3b0c44")
assert prov.model_filename == "metrabs_stub.tar.gz"
assert prov.numpy_version == np.__version__
assert prov.python_version.count(".") == 2 # MAJOR.MINOR.MICRO
# neuropose_version should match the package's __version__
from neuropose import __version__ as pkg_version
assert prov.neuropose_version == pkg_version
# tensorflow_version should also be real (TF is in dev deps).
assert prov.tensorflow_version not in {"", "unknown"}
def test_model_sha256_and_filename_properties_after_load(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
from neuropose._model import LoadedModel
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
del cache_dir
return LoadedModel(model=object(), sha256="abcd", filename="x.tar.gz")
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
estimator = Estimator()
assert estimator.model_sha256 is None
assert estimator.model_filename is None
estimator.load_model()
assert estimator.model_sha256 == "abcd"
assert estimator.model_filename == "x.tar.gz"
class TestErrors:
def test_missing_video(
self,

View File

@@ -22,6 +22,7 @@ from neuropose.io import (
JointPairDistanceExtractor,
JointSpeedExtractor,
PerformanceMetrics,
Provenance,
Segment,
Segmentation,
SegmentationConfig,
@@ -278,6 +279,102 @@ class TestPerformanceMetricsModel:
m.total_seconds = 2.0
def _minimal_provenance() -> Provenance:
return Provenance(
model_sha256="a" * 64,
model_filename="metrabs_fake.tar.gz",
tensorflow_version="2.18.1",
numpy_version="2.0.2",
neuropose_version="0.1.0.dev0",
python_version="3.11.14",
)
class TestProvenanceModel:
"""Schema-level behaviour of :class:`neuropose.io.Provenance`."""
def test_roundtrip_through_json(self) -> None:
p = Provenance(
model_sha256="a" * 64,
model_filename="metrabs_fake.tar.gz",
tensorflow_version="2.18.1",
tensorflow_metal_version="1.2.0",
numpy_version="2.0.2",
neuropose_version="0.1.0.dev0",
python_version="3.11.14",
seed=42,
deterministic=True,
analysis_config={"step": "dtw", "nan_policy": "propagate"},
)
rehydrated = Provenance.model_validate(p.model_dump(mode="json"))
assert rehydrated == p
def test_optional_fields_default_to_none_and_false(self) -> None:
p = _minimal_provenance()
assert p.tensorflow_metal_version is None
assert p.seed is None
assert p.deterministic is False
assert p.analysis_config is None
def test_is_frozen(self) -> None:
p = _minimal_provenance()
with pytest.raises(ValidationError):
p.model_sha256 = "different"
def test_extra_fields_forbidden(self) -> None:
# Construct via model_validate so pyright doesn't have to prove the
# keyword doesn't exist on the class at static-type time.
with pytest.raises(ValidationError):
Provenance.model_validate(
{
"model_sha256": "x" * 64,
"model_filename": "x.tar.gz",
"tensorflow_version": "2.18",
"numpy_version": "2.0",
"neuropose_version": "0.1",
"python_version": "3.11.14",
"unknown_field": "bogus",
}
)
class TestVideoPredictionsProvenance:
"""``provenance`` field on :class:`VideoPredictions` round-trips."""
def test_default_is_none(self) -> None:
vp = VideoPredictions(
metadata=VideoMetadata(frame_count=0, fps=30.0, width=32, height=32),
frames={},
)
assert vp.provenance is None
def test_roundtrip_with_provenance(self, tmp_path: Path) -> None:
prov = Provenance(
model_sha256="f" * 64,
model_filename="metrabs.tar.gz",
tensorflow_version="2.18.1",
numpy_version="2.0.2",
neuropose_version="0.1.0.dev0",
python_version="3.11.14",
)
vp = VideoPredictions(
metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32),
frames={
"frame_000000": FramePrediction(
boxes=[[0.0, 0.0, 32.0, 32.0, 0.9]],
poses3d=[[[1.0, 2.0, 3.0]]],
poses2d=[[[10.0, 20.0]]],
)
},
provenance=prov,
)
path = tmp_path / "vp.json"
save_video_predictions(path, vp)
loaded = load_video_predictions(path)
assert loaded == vp
assert loaded.provenance == prov
class TestBenchmarkResultPersistence:
def test_roundtrip_to_disk(self, tmp_path: Path) -> None:
result = BenchmarkResult(

View File

@@ -0,0 +1,488 @@
"""Tests for :mod:`neuropose.migrations`.
Covers both the low-level migration driver (version walking, future/missing
errors, INFO logging) and its integration through the
:mod:`neuropose.io` load helpers (legacy payloads round-trip; future
payloads fail with a clear message).
The migration driver is tested by monkey-patching ``CURRENT_VERSION`` and
the per-schema migration registries, so the tests exercise the full
chain-walking machinery without needing the codebase to actually be on a
non-initial schema version.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
import pytest
from neuropose import migrations
from neuropose.io import (
BenchmarkResult,
FramePrediction,
JobResults,
VideoMetadata,
VideoPredictions,
load_benchmark_result,
load_job_results,
load_video_predictions,
save_benchmark_result,
save_job_results,
save_video_predictions,
)
from neuropose.migrations import (
CURRENT_VERSION,
FutureSchemaError,
MigrationError,
MigrationNotFoundError,
migrate_analysis_report,
migrate_benchmark_result,
migrate_job_results,
migrate_video_predictions,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _minimal_video_predictions_payload() -> dict:
"""A valid VideoPredictions payload at the current schema version."""
return {
"schema_version": CURRENT_VERSION,
"metadata": {
"frame_count": 1,
"fps": 30.0,
"width": 32,
"height": 32,
},
"frames": {
"frame_000000": {
"boxes": [[0.0, 0.0, 32.0, 32.0, 0.95]],
"poses3d": [[[1.0, 2.0, 3.0]]],
"poses2d": [[[10.0, 20.0]]],
}
},
"segmentations": {},
}
def _minimal_video_predictions_object() -> VideoPredictions:
"""Same payload, as a validated pydantic object."""
return VideoPredictions(
metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32),
frames={
"frame_000000": FramePrediction(
boxes=[[0.0, 0.0, 32.0, 32.0, 0.95]],
poses3d=[[[1.0, 2.0, 3.0]]],
poses2d=[[[10.0, 20.0]]],
)
},
)
@pytest.fixture
def fake_two_version_chain(monkeypatch: pytest.MonkeyPatch) -> None:
"""Patch the module to look like CURRENT_VERSION=2 with a v1->v2 migration.
Lets the tests exercise the full migration loop even though the real
codebase is still at CURRENT_VERSION=1.
"""
monkeypatch.setattr(migrations, "CURRENT_VERSION", 2)
def _v1_to_v2(payload: dict) -> dict:
payload = dict(payload)
payload["schema_version"] = 2
payload["added_in_v2"] = "hello"
return payload
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {1: _v1_to_v2})
monkeypatch.setattr(migrations, "_BENCHMARK_RESULT_MIGRATIONS", {1: _v1_to_v2})
@pytest.fixture
def fake_three_version_chain(monkeypatch: pytest.MonkeyPatch) -> None:
"""Patch the module to look like CURRENT_VERSION=3 with v1->v2 and v2->v3.
Exercises multi-step migration chaining.
"""
monkeypatch.setattr(migrations, "CURRENT_VERSION", 3)
def _v1_to_v2(payload: dict) -> dict:
payload = dict(payload)
payload["schema_version"] = 2
payload["added_in_v2"] = "alpha"
return payload
def _v2_to_v3(payload: dict) -> dict:
payload = dict(payload)
payload["schema_version"] = 3
payload["added_in_v3"] = "beta"
return payload
monkeypatch.setattr(
migrations,
"_VIDEO_PREDICTIONS_MIGRATIONS",
{1: _v1_to_v2, 2: _v2_to_v3},
)
# ---------------------------------------------------------------------------
# migrate_video_predictions — driver behavior
# ---------------------------------------------------------------------------
class TestMigrateVideoPredictions:
def test_current_version_payload_is_noop(self) -> None:
payload = {"schema_version": CURRENT_VERSION, "hello": "world"}
result = migrate_video_predictions(payload)
assert result == payload
def test_missing_version_key_treated_as_v1(self) -> None:
"""A payload with no schema_version is treated as version 1.
With CURRENT_VERSION == 2, the legacy payload is run through
the registered v1 → v2 migration (which stamps ``provenance =
None``) on the way to the current version.
"""
payload = {"hello": "world"}
result = migrate_video_predictions(payload)
assert result["hello"] == "world"
assert result["schema_version"] == CURRENT_VERSION
assert result["provenance"] is None
def test_future_version_raises(self) -> None:
payload = {"schema_version": CURRENT_VERSION + 99}
with pytest.raises(FutureSchemaError, match="newer than"):
migrate_video_predictions(payload)
def test_non_integer_version_raises(self) -> None:
payload = {"schema_version": "1.0"}
with pytest.raises(MigrationError, match="invalid schema_version"):
migrate_video_predictions(payload)
def test_zero_version_raises(self) -> None:
payload = {"schema_version": 0}
with pytest.raises(MigrationError, match="invalid schema_version"):
migrate_video_predictions(payload)
def test_single_step_migration(self, fake_two_version_chain: None) -> None:
del fake_two_version_chain
payload = {"schema_version": 1, "original_field": "keep_me"}
result = migrate_video_predictions(payload)
assert result == {
"schema_version": 2,
"original_field": "keep_me",
"added_in_v2": "hello",
}
def test_missing_version_under_patched_chain_migrates_from_v1(
self, fake_two_version_chain: None
) -> None:
del fake_two_version_chain
# Legacy file with no version stamp: should be treated as v1 and
# upgraded to v2.
payload = {"legacy": True}
result = migrate_video_predictions(payload)
assert result["schema_version"] == 2
assert result["added_in_v2"] == "hello"
assert result["legacy"] is True
def test_multi_step_migration_chains(self, fake_three_version_chain: None) -> None:
del fake_three_version_chain
payload = {"schema_version": 1, "original": "yes"}
result = migrate_video_predictions(payload)
assert result == {
"schema_version": 3,
"original": "yes",
"added_in_v2": "alpha",
"added_in_v3": "beta",
}
def test_missing_intermediate_migration_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""If CURRENT advances past a version with no registered migration, fail loud."""
monkeypatch.setattr(migrations, "CURRENT_VERSION", 3)
# Only v1 -> v2 registered; v2 -> v3 is the missing link.
monkeypatch.setattr(
migrations,
"_VIDEO_PREDICTIONS_MIGRATIONS",
{1: lambda p: {**p, "schema_version": 2}},
)
with pytest.raises(MigrationNotFoundError, match="from schema_version 2"):
migrate_video_predictions({"schema_version": 1})
def test_logs_at_info_on_migration(
self,
fake_two_version_chain: None,
caplog: pytest.LogCaptureFixture,
) -> None:
del fake_two_version_chain
caplog.set_level(logging.INFO, logger="neuropose.migrations")
migrate_video_predictions({"schema_version": 1})
assert any("Migrating VideoPredictions" in record.message for record in caplog.records)
def test_starting_from_current_logs_nothing(
self,
fake_two_version_chain: None,
caplog: pytest.LogCaptureFixture,
) -> None:
del fake_two_version_chain
caplog.set_level(logging.INFO, logger="neuropose.migrations")
migrate_video_predictions({"schema_version": 2})
assert not any("Migrating" in record.message for record in caplog.records)
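The chain-walking behaviour pinned down above can be sketched as follows. ``CURRENT_VERSION`` and the registry contents are fabricated for the sketch, and the names mirror the tests rather than the real implementation:

```python
CURRENT_VERSION = 3
_MIGRATIONS = {
    1: lambda p: {**p, "schema_version": 2, "added_in_v2": "alpha"},
    2: lambda p: {**p, "schema_version": 3, "added_in_v3": "beta"},
}


class FutureSchemaError(ValueError):
    pass


class MigrationNotFoundError(KeyError):
    pass


def migrate(payload: dict) -> dict:
    version = payload.get("schema_version", 1)  # legacy files count as v1
    if not isinstance(version, int) or isinstance(version, bool) or version < 1:
        raise ValueError(f"invalid schema_version: {version!r}")
    if version > CURRENT_VERSION:
        raise FutureSchemaError(
            f"schema_version {version} is newer than {CURRENT_VERSION}"
        )
    while version < CURRENT_VERSION:
        step = _MIGRATIONS.get(version)
        if step is None:
            raise MigrationNotFoundError(f"no migration from schema_version {version}")
        payload = step(payload)  # each step stamps the next schema_version
        version = payload["schema_version"]
    return payload


migrated = migrate({"schema_version": 1, "original": "yes"})
```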
# ---------------------------------------------------------------------------
# migrate_benchmark_result — same driver, sibling registry
# ---------------------------------------------------------------------------
class TestMigrateBenchmarkResult:
def test_uses_benchmark_registry_not_video_registry(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Each schema has its own migration registry; they must not cross-pollinate."""
monkeypatch.setattr(migrations, "CURRENT_VERSION", 2)
# Register video migration but NOT benchmark migration.
monkeypatch.setattr(
migrations,
"_VIDEO_PREDICTIONS_MIGRATIONS",
{1: lambda p: {**p, "schema_version": 2, "from_video_registry": True}},
)
monkeypatch.setattr(migrations, "_BENCHMARK_RESULT_MIGRATIONS", {})
# Video migration works:
assert migrate_video_predictions({"schema_version": 1})["from_video_registry"] is True
# Benchmark migration should fail — no entry in its registry.
with pytest.raises(MigrationNotFoundError):
migrate_benchmark_result({"schema_version": 1})
# ---------------------------------------------------------------------------
# migrate_job_results — per-entry dispatch
# ---------------------------------------------------------------------------
class TestMigrateJobResults:
def test_empty_dict_is_noop(self) -> None:
assert migrate_job_results({}) == {}
def test_each_video_is_migrated(self, fake_two_version_chain: None) -> None:
del fake_two_version_chain
payload = {
"video_a.mp4": {"schema_version": 1, "content_a": True},
"video_b.mp4": {"schema_version": 1, "content_b": True},
}
result = migrate_job_results(payload)
assert result["video_a.mp4"]["schema_version"] == 2
assert result["video_a.mp4"]["content_a"] is True
assert result["video_a.mp4"]["added_in_v2"] == "hello"
assert result["video_b.mp4"]["schema_version"] == 2
assert result["video_b.mp4"]["content_b"] is True
# ---------------------------------------------------------------------------
# migrate_analysis_report
# ---------------------------------------------------------------------------
class TestMigrateAnalysisReport:
def test_current_version_is_noop(self) -> None:
"""A payload already at CURRENT_VERSION passes through unchanged."""
payload = {"schema_version": CURRENT_VERSION, "foo": "bar"}
assert migrate_analysis_report(payload) == payload
def test_missing_version_defaults_to_v1_and_fails(self) -> None:
"""AnalysisReport first shipped at CURRENT_VERSION, so a payload
without schema_version (defaulting to 1) would require a
non-existent v1 → v2 migration and fail with a clear error."""
with pytest.raises(MigrationNotFoundError, match="AnalysisReport"):
migrate_analysis_report({})
def test_future_version_rejected(self) -> None:
with pytest.raises(FutureSchemaError, match="AnalysisReport"):
migrate_analysis_report({"schema_version": CURRENT_VERSION + 5})
# ---------------------------------------------------------------------------
# register_video_predictions_migration / register_benchmark_result_migration
# / register_analysis_report_migration
# ---------------------------------------------------------------------------
class TestRegistration:
def test_duplicate_registration_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {})
@migrations.register_video_predictions_migration(from_version=1)
def _first(p: dict) -> dict:
return p
with pytest.raises(RuntimeError, match="already registered"):
@migrations.register_video_predictions_migration(from_version=1)
def _second(p: dict) -> dict:
return p
def test_decorator_returns_callable_unchanged(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""The decorator must not wrap or rename the function."""
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {})
@migrations.register_video_predictions_migration(from_version=1)
def _fn(p: dict) -> dict:
return p
assert _fn.__name__ == "_fn"
assert _fn({"x": 1}) == {"x": 1}
def test_analysis_report_duplicate_registration_raises(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(migrations, "_ANALYSIS_REPORT_MIGRATIONS", {})
@migrations.register_analysis_report_migration(from_version=2)
def _first(p: dict) -> dict:
return p
with pytest.raises(RuntimeError, match="already registered"):
@migrations.register_analysis_report_migration(from_version=2)
def _second(p: dict) -> dict:
return p
# ---------------------------------------------------------------------------
# Integration: load_* functions run migrations before validation
# ---------------------------------------------------------------------------
class TestLoadIntegration:
def test_load_video_predictions_accepts_legacy_payload(self, tmp_path: Path) -> None:
"""A VideoPredictions JSON written before schema_version existed loads cleanly."""
legacy = _minimal_video_predictions_payload()
del legacy["schema_version"] # Pretend this file predates versioning.
path = tmp_path / "legacy.json"
path.write_text(json.dumps(legacy))
loaded = load_video_predictions(path)
assert loaded.schema_version == CURRENT_VERSION
def test_load_video_predictions_rejects_future_version(self, tmp_path: Path) -> None:
payload = _minimal_video_predictions_payload()
payload["schema_version"] = CURRENT_VERSION + 42
path = tmp_path / "future.json"
path.write_text(json.dumps(payload))
with pytest.raises(FutureSchemaError):
load_video_predictions(path)
def test_save_then_load_roundtrips(self, tmp_path: Path) -> None:
obj = _minimal_video_predictions_object()
path = tmp_path / "out.json"
save_video_predictions(path, obj)
loaded = load_video_predictions(path)
assert loaded == obj
assert loaded.schema_version == CURRENT_VERSION
def test_load_job_results_migrates_each_video(self, tmp_path: Path) -> None:
video_a = _minimal_video_predictions_payload()
video_b = _minimal_video_predictions_payload()
# Strip schema_version from both to simulate legacy file.
del video_a["schema_version"]
del video_b["schema_version"]
payload = {"a.mp4": video_a, "b.mp4": video_b}
path = tmp_path / "job.json"
path.write_text(json.dumps(payload))
loaded = load_job_results(path)
assert len(loaded) == 2
for video in ("a.mp4", "b.mp4"):
assert loaded[video].schema_version == CURRENT_VERSION
def test_save_then_load_job_results_roundtrips(self, tmp_path: Path) -> None:
obj = JobResults(root={"video_a.mp4": _minimal_video_predictions_object()})
path = tmp_path / "job.json"
save_job_results(path, obj)
loaded = load_job_results(path)
assert loaded == obj
def test_load_benchmark_result_roundtrips(self, tmp_path: Path) -> None:
"""Save → load round-trip for a realistic benchmark result."""
from neuropose.io import BenchmarkAggregate, PerformanceMetrics
metrics = PerformanceMetrics(
total_seconds=1.0,
per_frame_latencies_ms=[10.0, 11.0],
peak_rss_mb=100.0,
active_device="/CPU:0",
tensorflow_version="2.18.0",
)
aggregate = BenchmarkAggregate(
repeats_measured=1,
warmup_frames_per_pass=0,
mean_frame_latency_ms=10.5,
p50_frame_latency_ms=10.5,
p95_frame_latency_ms=11.0,
p99_frame_latency_ms=11.0,
stddev_frame_latency_ms=0.5,
mean_throughput_fps=95.0,
peak_rss_mb_max=100.0,
active_device="/CPU:0",
tensorflow_version="2.18.0",
)
result = BenchmarkResult(
video_name="test.mp4",
repeats=2,
warmup_frames=0,
warmup_pass=metrics,
measured_passes=[metrics],
aggregate=aggregate,
)
path = tmp_path / "bench.json"
save_benchmark_result(path, result)
loaded = load_benchmark_result(path)
assert loaded == result
assert loaded.schema_version == CURRENT_VERSION
def test_load_benchmark_result_rejects_future_version(self, tmp_path: Path) -> None:
"""Future-versioned benchmark file should raise with a clear message."""
from neuropose.io import BenchmarkAggregate, PerformanceMetrics
metrics = PerformanceMetrics(
total_seconds=1.0,
per_frame_latencies_ms=[10.0],
peak_rss_mb=100.0,
active_device="/CPU:0",
tensorflow_version="2.18.0",
)
aggregate = BenchmarkAggregate(
repeats_measured=1,
warmup_frames_per_pass=0,
mean_frame_latency_ms=10.0,
p50_frame_latency_ms=10.0,
p95_frame_latency_ms=10.0,
p99_frame_latency_ms=10.0,
stddev_frame_latency_ms=0.0,
mean_throughput_fps=100.0,
peak_rss_mb_max=100.0,
active_device="/CPU:0",
tensorflow_version="2.18.0",
)
result = BenchmarkResult(
video_name="x.mp4",
repeats=1,
warmup_frames=0,
warmup_pass=metrics,
measured_passes=[metrics],
aggregate=aggregate,
)
# Serialize then hand-edit to inject a future version.
payload = result.model_dump(mode="json")
payload["schema_version"] = CURRENT_VERSION + 1
path = tmp_path / "bench_future.json"
path.write_text(json.dumps(payload))
with pytest.raises(FutureSchemaError):
load_benchmark_result(path)
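For orientation, the versioned-payload pattern this suite exercises — a per-version migration registry, duplicate-registration guard, stepwise upgrade, and a `FutureSchemaError` for files newer than the reader — can be sketched as follows. This is a hypothetical minimal version for illustration only, not the project's actual `neuropose.io.migrations` module; `CURRENT_VERSION = 3` and all names here are assumptions.

```python
# Hypothetical minimal sketch -- NOT the project's neuropose.io.migrations
# module -- of the versioned-payload pattern the tests above exercise.
from typing import Callable

CURRENT_VERSION = 3  # assumed value, for illustration only

_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}


class FutureSchemaError(ValueError):
    """Payload was written by a newer version of the software."""


def register_migration(from_version: int) -> Callable:
    """Register an upgrader from ``from_version`` to ``from_version + 1``."""
    def decorator(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
        if from_version in _MIGRATIONS:
            raise RuntimeError(f"migration from v{from_version} already registered")
        _MIGRATIONS[from_version] = fn
        return fn  # returned unchanged, as the decorator tests assert
    return decorator


def migrate(payload: dict) -> dict:
    """Upgrade a payload one version at a time until it reaches CURRENT_VERSION."""
    version = payload.get("schema_version", 1)  # legacy files predate versioning
    if version > CURRENT_VERSION:
        raise FutureSchemaError(f"schema_version {version} > {CURRENT_VERSION}")
    while version < CURRENT_VERSION:
        payload = _MIGRATIONS[version](payload)
        version += 1
    payload["schema_version"] = CURRENT_VERSION
    return payload
```

Under this shape, the "legacy payload loads cleanly" and "future version raises" tests above fall out directly: a missing `schema_version` defaults to 1 and walks the chain, while an inflated one fails fast before validation.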

tests/unit/test_reset.py Normal file
@@ -0,0 +1,451 @@
"""Tests for :mod:`neuropose.reset` and the ``neuropose reset`` CLI command.
Coverage:
- :func:`find_neuropose_processes` filters by cmdline marker, classifies
daemon vs monitor, and excludes the calling process.
- :func:`terminate_processes` sends SIGINT, escalates to SIGKILL when
asked, and reports survivors.
- :func:`wipe_state` removes contents of in/, out/, failed/, the lock
file, and ``.ingest_*`` staging dirs; honors ``keep_failed`` and
``dry_run``.
- :func:`reset_pipeline` skips the wipe phase when termination leaves
survivors.
- The ``neuropose reset`` CLI command renders previews, honors
``--dry-run`` and ``--yes``, and exits non-zero on survivors.
The process-killing tests use monkeypatched ``psutil.process_iter``
and ``os.kill`` so the suite never touches the real process table or
sends real signals.
"""
from __future__ import annotations
import signal
from pathlib import Path
from typing import Any
import pytest
from typer.testing import CliRunner
from neuropose.cli import EXIT_OK, EXIT_USAGE, app
from neuropose.config import Settings
from neuropose.interfacer import LOCK_FILENAME
from neuropose.reset import (
DEFAULT_GRACE_SECONDS,
RunningProcess,
TerminationReport,
WipeReport,
find_neuropose_processes,
reset_pipeline,
terminate_processes,
wipe_state,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeProc:
"""Minimal stand-in for ``psutil.Process`` for the discovery tests."""
def __init__(self, pid: int, cmdline: list[str]) -> None:
self.info = {"pid": pid, "cmdline": cmdline}
@pytest.fixture
def settings(tmp_path: Path) -> Settings:
"""A Settings pointing at an isolated temp data dir, with subdirs created."""
s = Settings(data_dir=tmp_path / "jobs", model_cache_dir=tmp_path / "models")
s.ensure_dirs()
return s
# ---------------------------------------------------------------------------
# find_neuropose_processes
# ---------------------------------------------------------------------------
class TestFindNeuroposeProcesses:
def test_classifies_watch_as_daemon(self, monkeypatch: pytest.MonkeyPatch) -> None:
procs = [_FakeProc(1234, ["python", "-m", "neuropose", "watch"])]
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
found = find_neuropose_processes()
assert len(found) == 1
assert found[0].pid == 1234
assert found[0].role == "daemon"
def test_classifies_serve_as_monitor(self, monkeypatch: pytest.MonkeyPatch) -> None:
procs = [_FakeProc(5678, ["uv", "run", "neuropose", "serve", "--port", "8765"])]
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
found = find_neuropose_processes()
assert len(found) == 1
assert found[0].role == "monitor"
def test_ignores_unrelated_processes(self, monkeypatch: pytest.MonkeyPatch) -> None:
procs = [
_FakeProc(1, ["bash"]),
_FakeProc(2, ["python", "-m", "pip", "install", "neuropose"]),
_FakeProc(3, ["neuropose", "--help"]),
]
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
assert find_neuropose_processes() == []
def test_excludes_self(self, monkeypatch: pytest.MonkeyPatch) -> None:
import os
self_pid = os.getpid()
procs = [
_FakeProc(self_pid, ["python", "-m", "neuropose", "watch"]),
_FakeProc(9999, ["python", "-m", "neuropose", "watch"]),
]
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
found = find_neuropose_processes()
assert [rp.pid for rp in found] == [9999]
def test_includes_self_when_disabled(self, monkeypatch: pytest.MonkeyPatch) -> None:
import os
self_pid = os.getpid()
procs = [_FakeProc(self_pid, ["python", "-m", "neuropose", "watch"])]
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
found = find_neuropose_processes(exclude_self=False)
assert [rp.pid for rp in found] == [self_pid]
def test_handles_dead_processes(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""A NoSuchProcess raised mid-iteration must not crash the scan."""
import psutil
class _RaisingProc:
@property
def info(self) -> dict[str, Any]:
raise psutil.NoSuchProcess(pid=1)
procs = [
_RaisingProc(),
_FakeProc(2, ["python", "-m", "neuropose", "watch"]),
]
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
found = find_neuropose_processes()
assert [rp.pid for rp in found] == [2]
# ---------------------------------------------------------------------------
# terminate_processes
# ---------------------------------------------------------------------------
class TestTerminateProcesses:
def test_empty_list_is_noop(self) -> None:
report = terminate_processes([])
assert report.stopped == []
assert report.survivors == []
assert report.force_killed == []
def test_sigint_to_each_process(self, monkeypatch: pytest.MonkeyPatch) -> None:
sent: list[tuple[int, int]] = []
monkeypatch.setattr("os.kill", lambda pid, sig: sent.append((pid, sig)))
# Mark all processes immediately dead so the wait loop exits fast.
monkeypatch.setattr("neuropose.reset._is_alive", lambda pid: False)
rps = [
RunningProcess(pid=10, role="daemon", cmdline="x"),
RunningProcess(pid=20, role="monitor", cmdline="y"),
]
report = terminate_processes(rps, grace_seconds=0.0)
assert sent == [(10, signal.SIGINT), (20, signal.SIGINT)]
assert {p.pid for p in report.stopped} == {10, 20}
assert report.survivors == []
def test_survivors_reported_when_force_kill_off(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("os.kill", lambda pid, sig: None)
# Pid 10 stays alive throughout; pid 20 is never in the set, so it
# appears dead immediately after SIGINT.
alive = {10}
monkeypatch.setattr("neuropose.reset._is_alive", lambda pid: pid in alive)
rps = [
RunningProcess(pid=10, role="daemon", cmdline="x"),
RunningProcess(pid=20, role="monitor", cmdline="y"),
]
report = terminate_processes(rps, grace_seconds=0.0, force_kill=False)
assert {p.pid for p in report.stopped} == {20}
assert {p.pid for p in report.survivors} == {10}
assert report.force_killed == []
def test_force_kill_escalates_to_sigkill(self, monkeypatch: pytest.MonkeyPatch) -> None:
sent: list[tuple[int, int]] = []
monkeypatch.setattr("os.kill", lambda pid, sig: sent.append((pid, sig)))
# Start alive; the fake treats a recorded SIGKILL as proof of death.
alive = {10}
def fake_is_alive(pid: int) -> bool:
if (pid, signal.SIGKILL) in sent:
return False
return pid in alive
monkeypatch.setattr("neuropose.reset._is_alive", fake_is_alive)
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
report = terminate_processes([rp], grace_seconds=0.0, force_kill=True)
assert (10, signal.SIGINT) in sent
assert (10, signal.SIGKILL) in sent
assert {p.pid for p in report.force_killed} == {10}
assert report.survivors == []
# ---------------------------------------------------------------------------
# wipe_state
# ---------------------------------------------------------------------------
class TestWipeState:
def test_no_op_on_empty_dirs(self, settings: Settings) -> None:
report = wipe_state(settings)
assert report.removed_paths == []
assert report.bytes_freed == 0
def test_removes_in_out_failed_contents(self, settings: Settings) -> None:
(settings.input_dir / "job_a").mkdir()
(settings.input_dir / "job_a" / "video.mp4").write_bytes(b"x" * 100)
(settings.output_dir / "status.json").write_text("{}")
(settings.failed_dir / "job_b").mkdir()
report = wipe_state(settings)
names = {p.name for p in report.removed_paths}
assert names == {"job_a", "status.json", "job_b"}
# Containers themselves preserved.
assert settings.input_dir.exists()
assert settings.output_dir.exists()
assert settings.failed_dir.exists()
def test_keep_failed_preserves_failed_contents(self, settings: Settings) -> None:
(settings.input_dir / "job_a").mkdir()
(settings.failed_dir / "job_b").mkdir()
(settings.failed_dir / "job_b" / "evidence.log").write_text("crash")
report = wipe_state(settings, keep_failed=True)
names = {p.name for p in report.removed_paths}
assert "job_a" in names
assert "job_b" not in names
assert (settings.failed_dir / "job_b" / "evidence.log").exists()
def test_removes_lock_file(self, settings: Settings) -> None:
(settings.data_dir / LOCK_FILENAME).write_text("12345\n")
report = wipe_state(settings)
assert (settings.data_dir / LOCK_FILENAME) in report.removed_paths
assert not (settings.data_dir / LOCK_FILENAME).exists()
def test_removes_ingest_staging_dirs(self, settings: Settings) -> None:
staging_a = settings.data_dir / ".ingest_abc123"
staging_b = settings.data_dir / ".ingest_def456"
staging_a.mkdir()
staging_b.mkdir()
(staging_a / "leftover.mp4").write_bytes(b"y" * 50)
report = wipe_state(settings)
assert staging_a in report.removed_paths
assert staging_b in report.removed_paths
assert not staging_a.exists()
assert not staging_b.exists()
def test_dry_run_reports_without_removing(self, settings: Settings) -> None:
(settings.input_dir / "job_a").mkdir()
(settings.input_dir / "job_a" / "video.mp4").write_bytes(b"z" * 200)
report = wipe_state(settings, dry_run=True)
assert len(report.removed_paths) == 1
assert report.bytes_freed == 200
# Nothing actually deleted.
assert (settings.input_dir / "job_a" / "video.mp4").exists()
def test_bytes_freed_recurses_into_subdirs(self, settings: Settings) -> None:
job = settings.input_dir / "job_a"
job.mkdir()
(job / "a.mp4").write_bytes(b"a" * 100)
(job / "nested").mkdir()
(job / "nested" / "b.mp4").write_bytes(b"b" * 250)
report = wipe_state(settings, dry_run=True)
assert report.bytes_freed == 350
# ---------------------------------------------------------------------------
# reset_pipeline
# ---------------------------------------------------------------------------
class TestResetPipeline:
def test_dry_run_skips_termination(
self,
settings: Settings,
monkeypatch: pytest.MonkeyPatch,
) -> None:
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
# Sentinel to detect termination calls.
def _should_not_be_called(*args: Any, **kwargs: Any) -> TerminationReport:
raise AssertionError("dry_run must not invoke terminate_processes")
monkeypatch.setattr("neuropose.reset.terminate_processes", _should_not_be_called)
report = reset_pipeline(settings, dry_run=True)
assert report.dry_run is True
assert report.discovered == [rp]
assert report.termination.stopped == []
def test_skips_wipe_when_survivors_remain(
self,
settings: Settings,
monkeypatch: pytest.MonkeyPatch,
) -> None:
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
monkeypatch.setattr(
"neuropose.reset.terminate_processes",
lambda procs, **_: TerminationReport(survivors=list(procs)),
)
# Seed something to wipe so we can confirm it's untouched.
(settings.input_dir / "job_a").mkdir()
report = reset_pipeline(settings, dry_run=False)
assert report.wipe_skipped_due_to_survivors is True
assert report.wipe.removed_paths == []
assert (settings.input_dir / "job_a").exists()
def test_wipes_when_all_processes_stopped(
self,
settings: Settings,
monkeypatch: pytest.MonkeyPatch,
) -> None:
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
monkeypatch.setattr(
"neuropose.reset.terminate_processes",
lambda procs, **_: TerminationReport(stopped=list(procs)),
)
(settings.input_dir / "job_a").mkdir()
report = reset_pipeline(settings, dry_run=False)
assert report.wipe_skipped_due_to_survivors is False
assert any(p.name == "job_a" for p in report.wipe.removed_paths)
assert not (settings.input_dir / "job_a").exists()
# ---------------------------------------------------------------------------
# CLI: neuropose reset
# ---------------------------------------------------------------------------
@pytest.fixture
def runner() -> CliRunner:
return CliRunner()
@pytest.fixture
def env_data_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
"""Point NEUROPOSE_DATA_DIR at an isolated temp dir for CLI tests."""
data_dir = tmp_path / "jobs"
data_dir.mkdir()
(data_dir / "in").mkdir()
(data_dir / "out").mkdir()
(data_dir / "failed").mkdir()
monkeypatch.setenv("NEUROPOSE_DATA_DIR", str(data_dir))
return data_dir
class TestResetCli:
def test_reset_dry_run_does_not_modify(
self,
runner: CliRunner,
env_data_dir: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
(env_data_dir / "in" / "job_a").mkdir()
# ``list`` called with no args returns []: pretend nothing is running.
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
result = runner.invoke(app, ["reset", "--dry-run"])
assert result.exit_code == EXIT_OK, result.output
assert "would remove" in result.output
assert "(dry-run; no changes made)" in result.output
assert (env_data_dir / "in" / "job_a").exists()
def test_reset_yes_skips_confirmation_and_wipes(
self,
runner: CliRunner,
env_data_dir: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
(env_data_dir / "in" / "job_a").mkdir()
(env_data_dir / "in" / "job_a" / "video.mp4").write_bytes(b"x" * 100)
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
result = runner.invoke(app, ["reset", "--yes"])
assert result.exit_code == EXIT_OK, result.output
assert "removed" in result.output
assert not (env_data_dir / "in" / "job_a").exists()
def test_reset_aborts_on_no_confirmation(
self,
runner: CliRunner,
env_data_dir: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
(env_data_dir / "in" / "job_a").mkdir()
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
# typer.confirm reads from stdin; "n\n" declines.
result = runner.invoke(app, ["reset"], input="n\n")
assert result.exit_code == EXIT_USAGE, result.output
assert "aborted" in result.output
assert (env_data_dir / "in" / "job_a").exists()
def test_reset_clean_state_is_noop(
self,
runner: CliRunner,
env_data_dir: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
del env_data_dir
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
result = runner.invoke(app, ["reset"])
assert result.exit_code == EXIT_OK, result.output
assert "nothing to do" in result.output
def test_reset_reports_survivors_with_nonzero_exit(
self,
runner: CliRunner,
env_data_dir: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
del env_data_dir
rp = RunningProcess(pid=42, role="daemon", cmdline="neuropose watch")
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
monkeypatch.setattr(
"neuropose.reset.terminate_processes",
lambda procs, **_: TerminationReport(survivors=list(procs)),
)
result = runner.invoke(app, ["reset", "--yes"])
assert result.exit_code == EXIT_USAGE, result.output
assert "did not exit" in result.output
assert "pid 42" in result.output
assert "--force-kill" in result.output
def test_default_grace_seconds_constant_is_reasonable() -> None:
"""Pin the default grace period to a sane range so a refactor cannot silently change it."""
assert 5.0 <= DEFAULT_GRACE_SECONDS <= 60.0
def test_wipe_report_default_construction() -> None:
r = WipeReport()
assert r.removed_paths == []
assert r.bytes_freed == 0