Compare commits
No commits in common. "4f3a6241fbb77bb0d4b8ae0175c6c06d15a9b598" and "01b9ed9475652de70eeb81a0210bc1caf05905a9" have entirely different histories.
4f3a6241fb
...
01b9ed9475
197
CHANGELOG.md
197
CHANGELOG.md
|
|
@ -202,170 +202,6 @@ be split into per-release sections once tagging begins.
|
||||||
(with a `.collisions` list of offending names). The running
|
(with a `.collisions` list of offending names). The running
|
||||||
daemon needs no changes — ingested job dirs are picked up on the
|
daemon needs no changes — ingested job dirs are picked up on the
|
||||||
next poll.
|
next poll.
|
||||||
- **`neuropose.migrations`** — schema-migration infrastructure for
|
|
||||||
the three top-level serialised payloads (`VideoPredictions`,
|
|
||||||
`JobResults`, `BenchmarkResult`). Every payload carries a
|
|
||||||
`schema_version` field defaulting to `CURRENT_VERSION`; on load,
|
|
||||||
the raw JSON dict is passed through `migrate_video_predictions` /
|
|
||||||
`migrate_job_results` / `migrate_benchmark_result` *before*
|
|
||||||
pydantic validation so files written by older NeuroPose versions
|
|
||||||
upgrade transparently. One shared `CURRENT_VERSION` counter;
|
|
||||||
per-schema migration registries populated via
|
|
||||||
`register_video_predictions_migration(from_version)` and
|
|
||||||
`register_benchmark_result_migration(from_version)` decorators.
|
|
||||||
`JobResults` is a `RootModel` with no envelope of its own, so its
|
|
||||||
migration runs per-entry across the root mapping. The driver raises
|
|
||||||
`FutureSchemaError` for payloads newer than the current build
|
|
||||||
(clear upgrade-NeuroPose message), `MigrationNotFoundError` for
|
|
||||||
missing chain links (indicates a `CURRENT_VERSION` bump that forgot
|
|
||||||
its migration), and logs at INFO on each version advance. Currently
|
|
||||||
at `CURRENT_VERSION = 2`, with registered v1 → v2 migrations for
|
|
||||||
`VideoPredictions` and `BenchmarkResult` that add the optional
|
|
||||||
`provenance` field.
|
|
||||||
- **`neuropose.analyzer.features.procrustes_align`** — Kabsch
|
|
||||||
rigid-alignment helper for pose sequences, plus a
|
|
||||||
`ProcrustesMode` literal (`"per_frame"` | `"per_sequence"`) and a
|
|
||||||
frozen `AlignmentDiagnostics` dataclass (`rotation_deg`,
|
|
||||||
`rotation_deg_max`, `translation`, `translation_max`, `scale`,
|
|
||||||
plus the mode that produced them). Per-sequence mode fits one
|
|
||||||
rigid transform across the whole trial; per-frame fits an
|
|
||||||
independent transform per frame. Optional `scale=True` fits a
|
|
||||||
uniform scale factor for cross-subject comparisons. Wired into
|
|
||||||
every DTW entry point in `neuropose.analyzer.dtw` via a new
|
|
||||||
keyword-only `align: AlignMode = "none"` parameter — `"none"`
|
|
||||||
preserves the 0.1 raw-coordinate behaviour, while
|
|
||||||
`"procrustes_per_frame"` and `"procrustes_per_sequence"` route
|
|
||||||
inputs through `procrustes_align` before DTW runs so the returned
|
|
||||||
distance is rotation- and translation-invariant. Paper C's
|
|
||||||
pipeline is expected to set `align="procrustes_per_sequence"`;
|
|
||||||
see `TECHNICAL.md` Phase 0.
|
|
||||||
- **`neuropose.analyzer.dtw.Representation`** and
|
|
||||||
**`neuropose.analyzer.dtw.NanPolicy`** — two new Literal types
|
|
||||||
exposing orthogonal DTW preprocessing knobs on every entry point.
|
|
||||||
`representation` (on `dtw_all` and `dtw_per_joint`) switches the
|
|
||||||
per-frame feature vector between `"coords"` (the 0.1 default) and
|
|
||||||
`"angles"`, which runs `extract_joint_angles` on the supplied
|
|
||||||
`angle_triplets` first — yielding distances that are translation-,
|
|
||||||
rotation-, and scale-invariant by construction, and directly
|
|
||||||
interpretable in clinical terms. `nan_policy` (on all three entry
|
|
||||||
points) selects `"propagate"` (surface fastdtw's ValueError on
|
|
||||||
NaN — the default), `"interpolate"` (linear fill per feature
|
|
||||||
column), or `"drop"` (remove NaN frames before DTW); the
|
|
||||||
policy is applied consistently whether NaN originated from the
|
|
||||||
angles pipeline or from corrupted upstream coordinates.
|
|
||||||
`dtw_relation` stays a standalone convenience entry point for
|
|
||||||
two-joint displacement DTW; users who prefer a unified API can
|
|
||||||
express the same computation via `dtw_all` with an appropriate
|
|
||||||
pair of angle triplets or run `dtw_relation` directly.
|
|
||||||
- **`neuropose.analyzer.pipeline`** (schemas) — declarative
|
|
||||||
analysis-pipeline configuration and output artifact, parseable from
|
|
||||||
YAML or JSON via pydantic. `AnalysisConfig` captures a full
|
|
||||||
experiment: inputs (primary + optional reference predictions
|
|
||||||
files), preprocessing (person index, with room to grow),
|
|
||||||
optional segmentation (`gait_cycles` / `gait_cycles_bilateral` /
|
|
||||||
`extractor` discriminated union), and a required analysis stage
|
|
||||||
(`dtw` / `stats` / `none` discriminated union). `AnalysisReport`
|
|
||||||
is the runtime output: carries the originating config, a
|
|
||||||
`Provenance` envelope with `analysis_config` populated, per-input
|
|
||||||
summaries, produced segmentations, and an analysis-result payload
|
|
||||||
that mirrors the stage choice (`DtwResults`, `StatsResults`, or
|
|
||||||
`NoResults`). Cross-field invariants — `method="dtw_relation"`
|
|
||||||
requires `joint_i`/`joint_j`, `representation="angles"` requires
|
|
||||||
non-empty `angle_triplets`, `analysis.kind="dtw"` requires
|
|
||||||
`inputs.reference`, `analysis.kind="stats"` refuses a reference —
|
|
||||||
are enforced at parse time via `model_validator` so typos fail in
|
|
||||||
milliseconds instead of after a multi-minute predictions load.
|
|
||||||
`AnalysisReport` carries a `schema_version` field defaulting to
|
|
||||||
`CURRENT_VERSION = 2`, with a new
|
|
||||||
`register_analysis_report_migration` decorator and
|
|
||||||
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
|
||||||
for future schema changes. `run_analysis(config)` loads the named
|
|
||||||
predictions files, applies the configured segmentation, dispatches
|
|
||||||
to the selected analysis kind (DTW, stats, or none), and emits a
|
|
||||||
fully populated `AnalysisReport` whose `Provenance` inherits the
|
|
||||||
inference-time envelope from the primary input with
|
|
||||||
`analysis_config` stamped in, so the report is self-describing
|
|
||||||
even if the source YAML is lost. For DTW runs with segmentation,
|
|
||||||
segments are paired one-to-one by index across primary and
|
|
||||||
reference, truncating to `min(len_primary, len_reference)`;
|
|
||||||
bilateral segmentations emit per-side distances under
|
|
||||||
`"left_heel_strikes[i]"` / `"right_heel_strikes[i]"` labels.
|
|
||||||
`load_config(path)` parses YAML, `save_report(path, report)`
|
|
||||||
writes atomically, and `load_report(path)` rehydrates via the
|
|
||||||
migration chain. Wired to the CLI as `neuropose analyze --config
|
|
||||||
<yaml> [--output <json>]` — replaces the placeholder stub that
|
|
||||||
previously returned `EXIT_PENDING`. The CLI surfaces schema
|
|
||||||
violations and YAML parse errors as `EXIT_USAGE=2` with a clear
|
|
||||||
message pointing at the offending file, prints a one-line summary
|
|
||||||
of the run (segmentation counts, analysis kind, per-segment
|
|
||||||
distance count + mean for DTW), and supports `--output`/`-o` to
|
|
||||||
override the report path declared in the config (useful for
|
|
||||||
sweeping a single config over multiple input pairs from a shell
|
|
||||||
loop). Ships three example configs under `examples/analysis/`:
|
|
||||||
`minimal.yaml` (smallest working DTW pipeline), `paper_c_headline.yaml`
|
|
||||||
(representative Paper C config with bilateral gait-cycle
|
|
||||||
segmentation, per-sequence Procrustes, and joint-angle DTW on
|
|
||||||
knee/hip triplets), and `per_joint_debug.yaml` (per-joint DTW
|
|
||||||
breakdown for diagnosing which joint drives an unexpected
|
|
||||||
distance). An integration suite exercises each example against
|
|
||||||
synthetic predictions so schema drift between the YAMLs and the
|
|
||||||
executor fails CI, not silently at run time. Documented in
|
|
||||||
`docs/api/pipeline.md`.
|
|
||||||
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
|
||||||
**`segment_gait_cycles_bilateral`** — clinical convenience
|
|
||||||
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
|
||||||
extractor with gait-appropriate defaults (`joint="rhee"`,
|
|
||||||
`axis="y"`, `min_cycle_seconds=0.4`). The single-side entry point
|
|
||||||
accepts any berkeley_mhad_43 joint name and any spatial axis as a
|
|
||||||
string literal `"x" | "y" | "z"`, plus an `invert` flag for
|
|
||||||
recordings whose vertical axis runs opposite to MeTRAbs's
|
|
||||||
Y-down world-coordinate convention. The bilateral wrapper runs
|
|
||||||
the detection on both `lhee` and `rhee` and returns the two
|
|
||||||
results under `"left_heel_strikes"` / `"right_heel_strikes"`
|
|
||||||
keys — shape-compatible with `VideoPredictions.segmentations` so
|
|
||||||
the dict can be merged in directly. Degrades gracefully on
|
|
||||||
pathological gaits (shuffling, walker-assisted) by returning an
|
|
||||||
empty segments list rather than raising. Closes the gait-cycle
|
|
||||||
segmentation item in `TECHNICAL.md` Phase 0.
|
|
||||||
- **`neuropose.io.Provenance`** — reproducibility envelope for every
|
|
||||||
inference run. Populated automatically by `Estimator.process_video`
|
|
||||||
when the model was loaded via `load_model` (the production path)
|
|
||||||
and attached to the output `VideoPredictions`; propagates from
|
|
||||||
there into `JobResults` (per-video) and `BenchmarkResult` (via the
|
|
||||||
benchmark loop). Captures the MeTRAbs artifact SHA-256 and
|
|
||||||
filename, `tensorflow` / `tensorflow-metal` / `numpy` /
|
|
||||||
`neuropose` / Python versions, and reserved slots for a `seed`,
|
|
||||||
`deterministic` flag (Track 2), and `analysis_config` (Phase 0
|
|
||||||
YAML pipeline). `None` on the injected-model test path where
|
|
||||||
NeuroPose has no way to fingerprint the supplied artifact. Frozen
|
|
||||||
pydantic model with `extra="forbid"` and
|
|
||||||
`protected_namespaces=()` so the `model_*` field names do not
|
|
||||||
collide with pydantic v2's internal namespace. `_model.load_metrabs_model`
|
|
||||||
now returns a `LoadedModel` dataclass bundling the TF handle with
|
|
||||||
the pinned SHA and filename so the estimator can build the
|
|
||||||
`Provenance` without re-hashing the tarball.
|
|
||||||
- **`neuropose.reset`** — pipeline-wide reset utility for the
|
|
||||||
benchmark / iteration loop. `find_neuropose_processes()` scans the
|
|
||||||
OS process table (via `psutil`) for running `neuropose watch` and
|
|
||||||
`neuropose serve` instances and classifies each as `daemon` or
|
|
||||||
`monitor`. `terminate_processes()` SIGINTs them, polls for graceful
|
|
||||||
exit up to a configurable grace period, and optionally escalates
|
|
||||||
to SIGKILL with `force_kill=True`. `wipe_state()` removes the
|
|
||||||
contents of `$data_dir/in/`, `$data_dir/out/` (including
|
|
||||||
`status.json`), `$data_dir/failed/` (unless `keep_failed=True`),
|
|
||||||
the `.neuropose.lock` file, and any leftover `.ingest_<uuid>/`
|
|
||||||
staging dirs from interrupted ingests; container directories
|
|
||||||
themselves are preserved so the daemon does not need to recreate
|
|
||||||
them on next startup. `reset_pipeline()` composes the three with
|
|
||||||
one safety guard: if any process survives termination, the wipe
|
|
||||||
phase is skipped and the returned `ResetReport` flags
|
|
||||||
`wipe_skipped_due_to_survivors`, because removing `$data_dir`
|
|
||||||
out from under an active daemon would corrupt its in-flight
|
|
||||||
writes. Surfaced as `neuropose reset` in the CLI with
|
|
||||||
`--yes/-y`, `--keep-failed`, `--force-kill`, `--grace-seconds`,
|
|
||||||
and `--dry-run/-n` flags; the command always prints a preview
|
|
||||||
before prompting (skipped under `--yes`) and returns
|
|
||||||
`EXIT_USAGE=2` when survivors block the wipe.
|
|
||||||
- **`neuropose.benchmark`** — multi-pass inference benchmarking for
|
- **`neuropose.benchmark`** — multi-pass inference benchmarking for
|
||||||
a single video. `run_benchmark()` runs `process_video` N times
|
a single video. `run_benchmark()` runs `process_video` N times
|
||||||
(default 5), always discards the first pass as warmup (graph
|
(default 5), always discards the first pass as warmup (graph
|
||||||
|
|
@ -385,18 +221,14 @@ be split into per-release sections once tagging begins.
|
||||||
imports for the heavy dependencies:
|
imports for the heavy dependencies:
|
||||||
- `analyzer.dtw` — three DTW entry points (`dtw_all`,
|
- `analyzer.dtw` — three DTW entry points (`dtw_all`,
|
||||||
`dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
|
`dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
|
||||||
`DTWResult` dataclass and three orthogonal preprocessing knobs
|
`DTWResult` dataclass. See `RESEARCH.md` for the ongoing
|
||||||
(`align`, `representation`, `nan_policy`). See `RESEARCH.md`
|
|
||||||
for the ongoing
|
|
||||||
methodology investigation.
|
methodology investigation.
|
||||||
- `analyzer.features` — `predictions_to_numpy`,
|
- `analyzer.features` — `predictions_to_numpy`,
|
||||||
`normalize_pose_sequence` (uniform and axis-wise),
|
`normalize_pose_sequence` (uniform and axis-wise),
|
||||||
`pad_sequences` (edge-padding), `procrustes_align` (Kabsch
|
`pad_sequences` (edge-padding), `extract_joint_angles` (NaN on
|
||||||
rigid alignment, per-frame or per-sequence, optional uniform
|
degenerate vectors), `extract_feature_statistics`
|
||||||
scaling), `extract_joint_angles` (NaN on degenerate vectors),
|
(`FeatureStatistics` frozen dataclass), and a `find_peaks` thin
|
||||||
`extract_feature_statistics` (`FeatureStatistics` frozen
|
wrapper around `scipy.signal.find_peaks`.
|
||||||
dataclass), and a `find_peaks` thin wrapper around
|
|
||||||
`scipy.signal.find_peaks`.
|
|
||||||
- `analyzer.segment` — repetition segmentation for trials in
|
- `analyzer.segment` — repetition segmentation for trials in
|
||||||
which a subject performs the same movement several times. A
|
which a subject performs the same movement several times. A
|
||||||
three-layer API: `segment_by_peaks` (pure 1D
|
three-layer API: `segment_by_peaks` (pure 1D
|
||||||
|
|
@ -406,11 +238,7 @@ be split into per-release sections once tagging begins.
|
||||||
time-based parameters to frame counts via `metadata.fps`), and
|
time-based parameters to frame counts via `metadata.fps`), and
|
||||||
`slice_predictions` (split a `VideoPredictions` into one per
|
`slice_predictions` (split a `VideoPredictions` into one per
|
||||||
detected repetition with re-keyed frame names and a rewritten
|
detected repetition with re-keyed frame names and a rewritten
|
||||||
`frame_count`). Gait-specific convenience wrappers
|
`frame_count`). Ships four extractor factories —
|
||||||
`segment_gait_cycles` (single heel) and
|
|
||||||
`segment_gait_cycles_bilateral` (both heels, returning a dict
|
|
||||||
keyed by `"left_heel_strikes"` / `"right_heel_strikes"`) sit
|
|
||||||
above `segment_predictions` with clinical defaults. Ships four extractor factories —
|
|
||||||
`joint_axis`, `joint_pair_distance`, `joint_speed`, and
|
`joint_axis`, `joint_pair_distance`, `joint_speed`, and
|
||||||
`joint_angle` — plus a `JOINT_NAMES` constant for the
|
`joint_angle` — plus a `JOINT_NAMES` constant for the
|
||||||
berkeley_mhad_43 skeleton with a `joint_index(name)` lookup,
|
berkeley_mhad_43 skeleton with a `joint_index(name)` lookup,
|
||||||
|
|
@ -420,7 +248,7 @@ be split into per-release sections once tagging begins.
|
||||||
`slow`) loads the real model and asserts the constant still
|
`slow`) loads the real model and asserts the constant still
|
||||||
matches, so any upstream skeleton drift fails CI.
|
matches, so any upstream skeleton drift fails CI.
|
||||||
- **`neuropose.cli`** — Typer-based command-line interface with
|
- **`neuropose.cli`** — Typer-based command-line interface with
|
||||||
eight subcommands: `watch` (run the daemon), `process <video>`
|
seven subcommands: `watch` (run the daemon), `process <video>`
|
||||||
(run the estimator on a single video), `ingest <archive>` (unzip
|
(run the estimator on a single video), `ingest <archive>` (unzip
|
||||||
a video archive into per-video job directories under
|
a video archive into per-video job directories under
|
||||||
`$data_dir/in/` with validation-before-write and atomic
|
`$data_dir/in/` with validation-before-write and atomic
|
||||||
|
|
@ -431,13 +259,6 @@ be split into per-release sections once tagging begins.
|
||||||
KeyboardInterrupt exits with the standard shell-interruption
|
KeyboardInterrupt exits with the standard shell-interruption
|
||||||
code and an `OSError` at bind time is translated to a clean
|
code and an `OSError` at bind time is translated to a clean
|
||||||
usage error with the bind target in the message),
|
usage error with the bind target in the message),
|
||||||
`reset` (stop the daemon and monitor, then wipe pipeline state
|
|
||||||
for a clean restart — wraps `neuropose.reset` with a confirmation
|
|
||||||
prompt, `--dry-run` preview, `--keep-failed` to preserve the
|
|
||||||
forensic quarantine, `--force-kill` to escalate to SIGKILL after
|
|
||||||
the SIGINT grace period, and `--grace-seconds` to tune the wait;
|
|
||||||
refuses to wipe state while any process survives termination so
|
|
||||||
active writes cannot be corrupted),
|
|
||||||
`segment <results>` (post-hoc repetition segmentation — loads a
|
`segment <results>` (post-hoc repetition segmentation — loads a
|
||||||
JobResults or a single VideoPredictions, runs
|
JobResults or a single VideoPredictions, runs
|
||||||
`neuropose.analyzer.segment.segment_predictions` with the chosen
|
`neuropose.analyzer.segment.segment_predictions` with the chosen
|
||||||
|
|
@ -452,9 +273,7 @@ be split into per-release sections once tagging begins.
|
||||||
the resulting `poses3d` arrays, and reports throughput speedup
|
the resulting `poses3d` arrays, and reports throughput speedup
|
||||||
and max divergence in mm — the missing Apple Silicon numerical
|
and max divergence in mm — the missing Apple Silicon numerical
|
||||||
verification answer from `RESEARCH.md`), and
|
verification answer from `RESEARCH.md`), and
|
||||||
`analyze --config <yaml>` (run the declarative analysis
|
`analyze <results>` (stub). The `segment` subcommand accepts
|
||||||
pipeline — see the dedicated entry above for scope). The
|
|
||||||
`segment` subcommand accepts
|
|
||||||
joint specifiers as either berkeley_mhad_43 names (`lwri`,
|
joint specifiers as either berkeley_mhad_43 names (`lwri`,
|
||||||
`rwri`, …) or integer indices, and refuses to overwrite an
|
`rwri`, …) or integer indices, and refuses to overwrite an
|
||||||
existing segmentation of the same name without `--force`.
|
existing segmentation of the same name without `--force`.
|
||||||
|
|
|
||||||
1191
TECHNICAL.md
1191
TECHNICAL.md
File diff suppressed because it is too large
Load Diff
|
|
@ -1,3 +0,0 @@
|
||||||
# `neuropose.migrations`
|
|
||||||
|
|
||||||
::: neuropose.migrations
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
# `neuropose.analyzer.pipeline`
|
|
||||||
|
|
||||||
::: neuropose.analyzer.pipeline
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
# `neuropose.reset`
|
|
||||||
|
|
||||||
::: neuropose.reset
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
# Minimal DTW config: full-trial comparison on raw 3D coordinates,
|
|
||||||
# no Procrustes alignment, no segmentation. The simplest working
|
|
||||||
# example; use this as a starting template and add stages as needed.
|
|
||||||
#
|
|
||||||
# Run:
|
|
||||||
# neuropose analyze --config examples/analysis/minimal.yaml \
|
|
||||||
# --output out/minimal_report.json
|
|
||||||
#
|
|
||||||
# (Substitute real paths for `inputs.primary` and `inputs.reference`
|
|
||||||
# before running.)
|
|
||||||
|
|
||||||
config_version: 1
|
|
||||||
|
|
||||||
inputs:
|
|
||||||
primary: data/trial_a.json
|
|
||||||
reference: data/trial_b.json
|
|
||||||
|
|
||||||
analysis:
|
|
||||||
kind: dtw
|
|
||||||
method: dtw_all
|
|
||||||
align: none
|
|
||||||
representation: coords
|
|
||||||
nan_policy: propagate
|
|
||||||
|
|
||||||
output:
|
|
||||||
report: out/minimal_report.json
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
# Paper C headline config: cycle-segmented joint-angle DTW with
|
|
||||||
# per-sequence Procrustes alignment. This is the representative
|
|
||||||
# Paper C pipeline — bilateral gait cycles drive the segmentation
|
|
||||||
# so distances are reported per-stride per-side, and the angle-space
|
|
||||||
# representation makes the distance clinically interpretable
|
|
||||||
# (knee flexion angle, hip extension angle, etc.).
|
|
||||||
#
|
|
||||||
# Joint-triplet indices below target the berkeley_mhad_43 skeleton:
|
|
||||||
# - (27, 31, 32): left hip → left knee → left ankle (left knee flex)
|
|
||||||
# - (35, 39, 40): right hip → right knee → right ankle (right knee flex)
|
|
||||||
# - (34, 27, 31): back hip ← left hip → left knee (left hip flex)
|
|
||||||
# - (34, 35, 39): back hip ← right hip → right knee (right hip flex)
|
|
||||||
#
|
|
||||||
# See neuropose.analyzer.JOINT_NAMES for the full 43-joint table.
|
|
||||||
#
|
|
||||||
# Run:
|
|
||||||
# neuropose analyze --config examples/analysis/paper_c_headline.yaml \
|
|
||||||
# --output out/paper_c_report.json
|
|
||||||
|
|
||||||
config_version: 1
|
|
||||||
|
|
||||||
inputs:
|
|
||||||
primary: data/subject_trial.json
|
|
||||||
reference: data/mocap_reference.json
|
|
||||||
|
|
||||||
preprocessing:
|
|
||||||
person_index: 0
|
|
||||||
|
|
||||||
segmentation:
|
|
||||||
kind: gait_cycles_bilateral
|
|
||||||
axis: y
|
|
||||||
invert: false
|
|
||||||
min_cycle_seconds: 0.4
|
|
||||||
|
|
||||||
analysis:
|
|
||||||
kind: dtw
|
|
||||||
method: dtw_all
|
|
||||||
align: procrustes_per_sequence
|
|
||||||
representation: angles
|
|
||||||
angle_triplets:
|
|
||||||
- [27, 31, 32] # left knee flexion
|
|
||||||
- [35, 39, 40] # right knee flexion
|
|
||||||
- [34, 27, 31] # left hip flexion
|
|
||||||
- [34, 35, 39] # right hip flexion
|
|
||||||
nan_policy: interpolate
|
|
||||||
|
|
||||||
output:
|
|
||||||
report: out/paper_c_report.json
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
# Per-joint debug config: runs dtw_per_joint on raw coordinates so
|
|
||||||
# the resulting report carries a full (segments × joints) distance
|
|
||||||
# breakdown. Useful when one joint is suspected of driving an
|
|
||||||
# otherwise-unexpected DTW distance — the per-joint numbers make it
|
|
||||||
# obvious which joint's trajectory diverges most.
|
|
||||||
#
|
|
||||||
# Raw coordinates (representation: coords) are used rather than
|
|
||||||
# angles because joint-level debugging is most interpretable in the
|
|
||||||
# native measurement space. Procrustes alignment is on so
|
|
||||||
# translation and rotation between trials do not inflate the numbers.
|
|
||||||
#
|
|
||||||
# Run:
|
|
||||||
# neuropose analyze --config examples/analysis/per_joint_debug.yaml \
|
|
||||||
# --output out/per_joint_report.json
|
|
||||||
|
|
||||||
config_version: 1
|
|
||||||
|
|
||||||
inputs:
|
|
||||||
primary: data/trial_a.json
|
|
||||||
reference: data/trial_b.json
|
|
||||||
|
|
||||||
segmentation:
|
|
||||||
kind: gait_cycles
|
|
||||||
joint: rhee
|
|
||||||
axis: y
|
|
||||||
min_cycle_seconds: 0.4
|
|
||||||
|
|
||||||
analysis:
|
|
||||||
kind: dtw
|
|
||||||
method: dtw_per_joint
|
|
||||||
align: procrustes_per_sequence
|
|
||||||
representation: coords
|
|
||||||
nan_policy: propagate
|
|
||||||
|
|
||||||
output:
|
|
||||||
report: out/per_joint_report.json
|
|
||||||
|
|
@ -95,12 +95,9 @@ nav:
|
||||||
- neuropose.interfacer: api/interfacer.md
|
- neuropose.interfacer: api/interfacer.md
|
||||||
- neuropose.ingest: api/ingest.md
|
- neuropose.ingest: api/ingest.md
|
||||||
- neuropose.monitor: api/monitor.md
|
- neuropose.monitor: api/monitor.md
|
||||||
- neuropose.reset: api/reset.md
|
|
||||||
- neuropose.io: api/io.md
|
- neuropose.io: api/io.md
|
||||||
- neuropose.migrations: api/migrations.md
|
|
||||||
- neuropose.benchmark: api/benchmark.md
|
- neuropose.benchmark: api/benchmark.md
|
||||||
- neuropose.analyzer.segment: api/segment.md
|
- neuropose.analyzer.segment: api/segment.md
|
||||||
- neuropose.analyzer.pipeline: api/pipeline.md
|
|
||||||
- neuropose.visualize: api/visualize.md
|
- neuropose.visualize: api/visualize.md
|
||||||
- Development: development.md
|
- Development: development.md
|
||||||
- Deployment: deployment.md
|
- Deployment: deployment.md
|
||||||
|
|
|
||||||
|
|
@ -41,34 +41,11 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
import tarfile
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class LoadedModel:
|
|
||||||
"""Result of :func:`load_metrabs_model`.
|
|
||||||
|
|
||||||
Bundles the loaded TensorFlow model with the provenance metadata
|
|
||||||
that identifies which artifact it came from. Callers that only want
|
|
||||||
the model reach for :attr:`model`; callers that build a
|
|
||||||
:class:`~neuropose.io.Provenance` (primarily
|
|
||||||
:class:`~neuropose.estimator.Estimator`) pull :attr:`sha256` and
|
|
||||||
:attr:`filename` too.
|
|
||||||
|
|
||||||
Frozen — once :func:`load_metrabs_model` has produced a
|
|
||||||
``LoadedModel``, nothing downstream should edit the identity of
|
|
||||||
the artifact it describes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: Any
|
|
||||||
sha256: str
|
|
||||||
filename: str
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Model artifact: pinned URL and checksum.
|
# Model artifact: pinned URL and checksum.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -97,7 +74,7 @@ _REQUIRED_MODEL_ATTRS = (
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
def load_metrabs_model(cache_dir: Path | None = None) -> Any:
|
||||||
"""Load the MeTRAbs model, downloading and caching on first use.
|
"""Load the MeTRAbs model, downloading and caching on first use.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|
@ -110,11 +87,9 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
LoadedModel
|
object
|
||||||
Bundle containing the TensorFlow SavedModel handle alongside
|
A TensorFlow SavedModel handle exposing ``detect_poses`` and
|
||||||
the pinned artifact SHA-256 and filename that identify which
|
the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
|
||||||
model the handle came from. The handle exposes ``detect_poses``
|
|
||||||
and the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
|
|
||||||
attributes used by :class:`neuropose.estimator.Estimator`.
|
attributes used by :class:`neuropose.estimator.Estimator`.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
|
|
@ -124,18 +99,6 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
||||||
automatic retry), extraction fails, TensorFlow is not
|
automatic retry), extraction fails, TensorFlow is not
|
||||||
installed, or the loaded model does not expose the expected
|
installed, or the loaded model does not expose the expected
|
||||||
interface.
|
interface.
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
The returned ``sha256`` is the module-pinned :data:`_MODEL_SHA256`,
|
|
||||||
not a re-hash of the on-disk tarball. On the cold-cache path this
|
|
||||||
is exactly the hash we verified against before loading. On the
|
|
||||||
warm-cache path the tarball is not re-verified (that would cost a
|
|
||||||
2 GB I/O pass on every daemon startup), so the reported SHA is an
|
|
||||||
attestation of "this is the pinned artifact NeuroPose loads" rather
|
|
||||||
than a direct fingerprint of the on-disk bytes. For the threat
|
|
||||||
model this supports — reproducibility, not tamper-evidence — that
|
|
||||||
is the correct semantics.
|
|
||||||
"""
|
"""
|
||||||
resolved_cache = Path(cache_dir) if cache_dir is not None else _default_cache_dir()
|
resolved_cache = Path(cache_dir) if cache_dir is not None else _default_cache_dir()
|
||||||
resolved_cache.mkdir(parents=True, exist_ok=True)
|
resolved_cache.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -152,11 +115,7 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
||||||
)
|
)
|
||||||
shutil.rmtree(model_dir, ignore_errors=True)
|
shutil.rmtree(model_dir, ignore_errors=True)
|
||||||
else:
|
else:
|
||||||
return LoadedModel(
|
return _tf_load(saved_model_dir)
|
||||||
model=_tf_load(saved_model_dir),
|
|
||||||
sha256=_MODEL_SHA256,
|
|
||||||
filename=_MODEL_ARCHIVE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
tarball = resolved_cache / _MODEL_ARCHIVE_NAME
|
tarball = resolved_cache / _MODEL_ARCHIVE_NAME
|
||||||
|
|
||||||
|
|
@ -176,11 +135,7 @@ def load_metrabs_model(cache_dir: Path | None = None) -> LoadedModel:
|
||||||
|
|
||||||
_extract_tarball(tarball, model_dir)
|
_extract_tarball(tarball, model_dir)
|
||||||
saved_model_dir = _find_saved_model(model_dir)
|
saved_model_dir = _find_saved_model(model_dir)
|
||||||
return LoadedModel(
|
return _tf_load(saved_model_dir)
|
||||||
model=_tf_load(saved_model_dir),
|
|
||||||
sha256=_MODEL_SHA256,
|
|
||||||
filename=_MODEL_ARCHIVE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -28,56 +28,23 @@ here for ergonomic access.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from neuropose.analyzer.dtw import (
|
from neuropose.analyzer.dtw import (
|
||||||
AlignMode,
|
|
||||||
DTWResult,
|
DTWResult,
|
||||||
NanPolicy,
|
|
||||||
Representation,
|
|
||||||
dtw_all,
|
dtw_all,
|
||||||
dtw_per_joint,
|
dtw_per_joint,
|
||||||
dtw_relation,
|
dtw_relation,
|
||||||
)
|
)
|
||||||
from neuropose.analyzer.features import (
|
from neuropose.analyzer.features import (
|
||||||
AlignmentDiagnostics,
|
|
||||||
FeatureStatistics,
|
FeatureStatistics,
|
||||||
ProcrustesMode,
|
|
||||||
extract_feature_statistics,
|
extract_feature_statistics,
|
||||||
extract_joint_angles,
|
extract_joint_angles,
|
||||||
find_peaks,
|
find_peaks,
|
||||||
normalize_pose_sequence,
|
normalize_pose_sequence,
|
||||||
pad_sequences,
|
pad_sequences,
|
||||||
predictions_to_numpy,
|
predictions_to_numpy,
|
||||||
procrustes_align,
|
|
||||||
)
|
|
||||||
from neuropose.analyzer.pipeline import (
|
|
||||||
AnalysisConfig,
|
|
||||||
AnalysisReport,
|
|
||||||
AnalysisResults,
|
|
||||||
AnalysisStage,
|
|
||||||
DtwAnalysis,
|
|
||||||
DtwResults,
|
|
||||||
ExtractorSegmentation,
|
|
||||||
FeatureSummary,
|
|
||||||
GaitCyclesBilateralSegmentation,
|
|
||||||
GaitCyclesSegmentation,
|
|
||||||
InputsConfig,
|
|
||||||
InputSummary,
|
|
||||||
NoAnalysis,
|
|
||||||
NoResults,
|
|
||||||
OutputConfig,
|
|
||||||
PreprocessingConfig,
|
|
||||||
SegmentationStage,
|
|
||||||
StatsAnalysis,
|
|
||||||
StatsResults,
|
|
||||||
analysis_config_to_dict,
|
|
||||||
load_config,
|
|
||||||
load_report,
|
|
||||||
run_analysis,
|
|
||||||
save_report,
|
|
||||||
)
|
)
|
||||||
from neuropose.analyzer.segment import (
|
from neuropose.analyzer.segment import (
|
||||||
JOINT_INDEX,
|
JOINT_INDEX,
|
||||||
JOINT_NAMES,
|
JOINT_NAMES,
|
||||||
AxisLetter,
|
|
||||||
extract_signal,
|
extract_signal,
|
||||||
joint_angle,
|
joint_angle,
|
||||||
joint_axis,
|
joint_axis,
|
||||||
|
|
@ -85,8 +52,6 @@ from neuropose.analyzer.segment import (
|
||||||
joint_pair_distance,
|
joint_pair_distance,
|
||||||
joint_speed,
|
joint_speed,
|
||||||
segment_by_peaks,
|
segment_by_peaks,
|
||||||
segment_gait_cycles,
|
|
||||||
segment_gait_cycles_bilateral,
|
|
||||||
segment_predictions,
|
segment_predictions,
|
||||||
slice_predictions,
|
slice_predictions,
|
||||||
)
|
)
|
||||||
|
|
@ -94,34 +59,8 @@ from neuropose.analyzer.segment import (
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"JOINT_INDEX",
|
"JOINT_INDEX",
|
||||||
"JOINT_NAMES",
|
"JOINT_NAMES",
|
||||||
"AlignMode",
|
|
||||||
"AlignmentDiagnostics",
|
|
||||||
"AnalysisConfig",
|
|
||||||
"AnalysisReport",
|
|
||||||
"AnalysisResults",
|
|
||||||
"AnalysisStage",
|
|
||||||
"AxisLetter",
|
|
||||||
"DTWResult",
|
"DTWResult",
|
||||||
"DtwAnalysis",
|
|
||||||
"DtwResults",
|
|
||||||
"ExtractorSegmentation",
|
|
||||||
"FeatureStatistics",
|
"FeatureStatistics",
|
||||||
"FeatureSummary",
|
|
||||||
"GaitCyclesBilateralSegmentation",
|
|
||||||
"GaitCyclesSegmentation",
|
|
||||||
"InputSummary",
|
|
||||||
"InputsConfig",
|
|
||||||
"NanPolicy",
|
|
||||||
"NoAnalysis",
|
|
||||||
"NoResults",
|
|
||||||
"OutputConfig",
|
|
||||||
"PreprocessingConfig",
|
|
||||||
"ProcrustesMode",
|
|
||||||
"Representation",
|
|
||||||
"SegmentationStage",
|
|
||||||
"StatsAnalysis",
|
|
||||||
"StatsResults",
|
|
||||||
"analysis_config_to_dict",
|
|
||||||
"dtw_all",
|
"dtw_all",
|
||||||
"dtw_per_joint",
|
"dtw_per_joint",
|
||||||
"dtw_relation",
|
"dtw_relation",
|
||||||
|
|
@ -134,17 +73,10 @@ __all__ = [
|
||||||
"joint_index",
|
"joint_index",
|
||||||
"joint_pair_distance",
|
"joint_pair_distance",
|
||||||
"joint_speed",
|
"joint_speed",
|
||||||
"load_config",
|
|
||||||
"load_report",
|
|
||||||
"normalize_pose_sequence",
|
"normalize_pose_sequence",
|
||||||
"pad_sequences",
|
"pad_sequences",
|
||||||
"predictions_to_numpy",
|
"predictions_to_numpy",
|
||||||
"procrustes_align",
|
|
||||||
"run_analysis",
|
|
||||||
"save_report",
|
|
||||||
"segment_by_peaks",
|
"segment_by_peaks",
|
||||||
"segment_gait_cycles",
|
|
||||||
"segment_gait_cycles_bilateral",
|
|
||||||
"segment_predictions",
|
"segment_predictions",
|
||||||
"slice_predictions",
|
"slice_predictions",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,10 @@
|
||||||
|
|
||||||
Three entry points, ordered by increasing precision (and increasing cost):
|
Three entry points, ordered by increasing precision (and increasing cost):
|
||||||
|
|
||||||
- :func:`dtw_all` — DTW on the flattened per-frame feature vector. Fast
|
- :func:`dtw_all` — DTW on the flattened per-frame joint vector. Fast but
|
||||||
but coarse; collapses every joint axis (or every angle triplet) into
|
coarse; collapses every joint axis into a single per-frame vector.
|
||||||
a single per-frame vector.
|
- :func:`dtw_per_joint` — DTW on each joint independently. Preserves
|
||||||
- :func:`dtw_per_joint` — DTW on each joint (or angle triplet)
|
per-joint temporal alignment at the cost of one DTW call per joint.
|
||||||
independently. Preserves per-unit temporal alignment at the cost of
|
|
||||||
one DTW call per unit.
|
|
||||||
- :func:`dtw_relation` — DTW on the displacement vector between two
|
- :func:`dtw_relation` — DTW on the displacement vector between two
|
||||||
specific joints. This is the right tool when the research question is
|
specific joints. This is the right tool when the research question is
|
||||||
about the *relative* motion of a specific pair of joints (e.g. the
|
about the *relative* motion of a specific pair of joints (e.g. the
|
||||||
|
|
@ -18,24 +16,6 @@ and the warping path. Inputs are expected to be ``(frames, joints, 3)``
|
||||||
numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
|
numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
|
||||||
produces.
|
produces.
|
||||||
|
|
||||||
Three orthogonal preprocessing knobs are available on the entry points:
|
|
||||||
|
|
||||||
- **``align``** routes the inputs through
|
|
||||||
:func:`~neuropose.analyzer.features.procrustes_align` before DTW runs,
|
|
||||||
yielding translation- and rotation-invariant distances.
|
|
||||||
``align="none"`` (the default) preserves the raw-coordinate behaviour
|
|
||||||
shipped in 0.1.
|
|
||||||
- **``representation``** (on :func:`dtw_all` and :func:`dtw_per_joint`)
|
|
||||||
selects what each frame is reduced to before DTW. ``"coords"`` uses
|
|
||||||
the raw joint coordinates; ``"angles"`` replaces them with joint
|
|
||||||
angles computed at caller-supplied triplets via
|
|
||||||
:func:`~neuropose.analyzer.features.extract_joint_angles`, giving
|
|
||||||
DTW distances that are directly interpretable as clinical joint-range
|
|
||||||
comparisons.
|
|
||||||
- **``nan_policy``** decides how the DTW path handles non-finite values
|
|
||||||
in its input — typically a concern only for the angle representation,
|
|
||||||
where degenerate (zero-length) vectors produce NaN. See :data:`NanPolicy`.
|
|
||||||
|
|
||||||
Dependency note
|
Dependency note
|
||||||
---------------
|
---------------
|
||||||
This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of
|
This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of
|
||||||
|
|
@ -48,58 +28,11 @@ called.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from neuropose.analyzer.features import extract_joint_angles, procrustes_align
|
|
||||||
|
|
||||||
AlignMode = Literal["none", "procrustes_per_frame", "procrustes_per_sequence"]
|
|
||||||
"""Alignment selector for DTW entry points.
|
|
||||||
|
|
||||||
- ``"none"`` — feed raw coordinates directly to DTW.
|
|
||||||
- ``"procrustes_per_frame"`` — per-frame Kabsch alignment before DTW.
|
|
||||||
- ``"procrustes_per_sequence"`` — single sequence-wide Kabsch
|
|
||||||
alignment before DTW.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Representation = Literal["coords", "angles"]
|
|
||||||
"""Per-frame feature representation for :func:`dtw_all` and :func:`dtw_per_joint`.
|
|
||||||
|
|
||||||
- ``"coords"`` — use the raw joint coordinates (the input's last two
|
|
||||||
axes). Preserves the 0.1 behaviour.
|
|
||||||
- ``"angles"`` — replace joints with joint angles at caller-supplied
|
|
||||||
triplets. Translation- and rotation-invariant by construction,
|
|
||||||
scale-invariant modulo the upstream normalization, and directly
|
|
||||||
interpretable in clinical terms ("knee flexion during swing phase").
|
|
||||||
The ``angle_triplets`` keyword becomes mandatory in this mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
NanPolicy = Literal["propagate", "interpolate", "drop"]
|
|
||||||
"""Per-feature NaN handling for the DTW input.
|
|
||||||
|
|
||||||
NaN typically appears when ``representation="angles"`` encounters a
|
|
||||||
degenerate (zero-length) vector — the angle is undefined and
|
|
||||||
:func:`extract_joint_angles` propagates NaN rather than quietly returning
|
|
||||||
a stand-in value.
|
|
||||||
|
|
||||||
- ``"propagate"`` (default) — pass NaN straight through to the DTW
|
|
||||||
engine. fastdtw validates its input via
|
|
||||||
:func:`numpy.asarray_chkfinite` and raises :class:`ValueError`
|
|
||||||
the moment a NaN appears, which is the safest default because it
|
|
||||||
makes the problem visible instead of quietly corrupting a
|
|
||||||
distance.
|
|
||||||
- ``"interpolate"`` — linearly interpolate NaN frames along each
|
|
||||||
feature column using neighbouring finite values. Reasonable when a
|
|
||||||
small number of frames are corrupted and the surrounding motion is
|
|
||||||
smooth; inappropriate when long stretches are missing.
|
|
||||||
- ``"drop"`` — remove any frame where *any* feature is NaN before DTW
|
|
||||||
runs. Simple, but compresses the time axis, so warping-path indices
|
|
||||||
refer to the *compacted* sequence rather than the original.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class DTWResult:
|
class DTWResult:
|
||||||
|
|
@ -144,44 +77,20 @@ def _require_fastdtw() -> tuple[Callable, Callable]:
|
||||||
return fastdtw, euclidean
|
return fastdtw, euclidean
|
||||||
|
|
||||||
|
|
||||||
def dtw_all(
|
def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult:
|
||||||
a: np.ndarray,
|
"""DTW on the flattened per-frame joint vector.
|
||||||
b: np.ndarray,
|
|
||||||
*,
|
|
||||||
align: AlignMode = "none",
|
|
||||||
representation: Representation = "coords",
|
|
||||||
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
|
|
||||||
nan_policy: NanPolicy = "propagate",
|
|
||||||
) -> DTWResult:
|
|
||||||
"""DTW on the flattened per-frame feature vector.
|
|
||||||
|
|
||||||
Under the default ``representation="coords"`` each frame's joints
|
Each frame's joints are collapsed into a single vector before DTW
|
||||||
are collapsed into a single vector before DTW is applied — fast
|
is applied. This is fast — one DTW call regardless of the joint
|
||||||
(one DTW call regardless of joint count) but coarse, since a small
|
count — but loses per-joint temporal structure, so a small
|
||||||
timing mismatch on one joint can dominate the distance metric.
|
timing mismatch on one joint can dominate the distance metric.
|
||||||
Switching to ``representation="angles"`` computes joint angles at
|
|
||||||
the supplied triplets first and flattens those instead.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
a, b
|
a, b
|
||||||
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
|
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
|
||||||
sequences do not need to have the same number of frames, but
|
sequences do not need to have the same number of frames, but
|
||||||
they must have the same number of joints. When ``align`` is not
|
they must have the same number of joints.
|
||||||
``"none"``, the two sequences must additionally share a frame
|
|
||||||
count (Procrustes requires a 1:1 correspondence).
|
|
||||||
align
|
|
||||||
Procrustes alignment mode applied before DTW. See
|
|
||||||
:data:`AlignMode`.
|
|
||||||
representation
|
|
||||||
Per-frame feature representation. See :data:`Representation`.
|
|
||||||
angle_triplets
|
|
||||||
Required when ``representation="angles"``. Sequence of
|
|
||||||
``(a, b, c)`` joint-index triplets passed through to
|
|
||||||
:func:`~neuropose.analyzer.features.extract_joint_angles`.
|
|
||||||
Ignored otherwise.
|
|
||||||
nan_policy
|
|
||||||
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -192,102 +101,48 @@ def dtw_all(
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If ``a`` and ``b`` do not have the same joint count, if
|
If ``a`` and ``b`` do not have the same joint count.
|
||||||
``align`` requires a matching frame count that is not present,
|
|
||||||
if ``representation="angles"`` is requested without
|
|
||||||
``angle_triplets``, or if ``nan_policy="interpolate"``
|
|
||||||
encounters an all-NaN column.
|
|
||||||
"""
|
"""
|
||||||
_validate_same_joint_count(a, b)
|
_validate_same_joint_count(a, b)
|
||||||
a, b = _maybe_align(a, b, align=align)
|
|
||||||
feat_a = _apply_representation(a, representation, angle_triplets=angle_triplets)
|
|
||||||
feat_b = _apply_representation(b, representation, angle_triplets=angle_triplets)
|
|
||||||
feat_a = _apply_nan_policy(feat_a, nan_policy)
|
|
||||||
feat_b = _apply_nan_policy(feat_b, nan_policy)
|
|
||||||
fastdtw, euclidean = _require_fastdtw()
|
fastdtw, euclidean = _require_fastdtw()
|
||||||
distance, path = fastdtw(feat_a, feat_b, dist=euclidean)
|
a_flat = a.reshape(a.shape[0], -1)
|
||||||
|
b_flat = b.reshape(b.shape[0], -1)
|
||||||
|
distance, path = fastdtw(a_flat, b_flat, dist=euclidean)
|
||||||
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
|
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
|
||||||
|
|
||||||
|
|
||||||
def dtw_per_joint(
|
def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]:
|
||||||
a: np.ndarray,
|
"""DTW on each joint independently.
|
||||||
b: np.ndarray,
|
|
||||||
*,
|
|
||||||
align: AlignMode = "none",
|
|
||||||
representation: Representation = "coords",
|
|
||||||
angle_triplets: Sequence[tuple[int, int, int]] | None = None,
|
|
||||||
nan_policy: NanPolicy = "propagate",
|
|
||||||
) -> list[DTWResult]:
|
|
||||||
"""DTW on each joint (or angle triplet) independently.
|
|
||||||
|
|
||||||
Performs one DTW computation per unit, yielding a list of
|
Performs one DTW computation per joint, yielding a list of
|
||||||
:class:`DTWResult` objects in input order. More precise than
|
:class:`DTWResult` objects in joint-index order. More precise than
|
||||||
:func:`dtw_all` because each unit's temporal alignment is optimised
|
:func:`dtw_all` because each joint's temporal alignment is optimised
|
||||||
separately, at the cost of J times more DTW calls for J units.
|
separately, at the cost of J times more DTW calls for J joints.
|
||||||
|
|
||||||
Under the default ``representation="coords"`` a "unit" is one of
|
|
||||||
the input's joints (xyz treated jointly). Under
|
|
||||||
``representation="angles"`` a "unit" is one scalar angle column
|
|
||||||
computed from one ``angle_triplets`` entry.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
a, b
|
a, b
|
||||||
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
|
Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
|
||||||
sequences do not need to have the same number of frames but
|
sequences do not need to have the same number of frames but
|
||||||
must have the same number of joints. When ``align`` is not
|
must have the same number of joints.
|
||||||
``"none"``, they must additionally share a frame count.
|
|
||||||
align
|
|
||||||
Procrustes alignment mode applied before DTW. See
|
|
||||||
:data:`AlignMode`.
|
|
||||||
representation
|
|
||||||
Per-frame feature representation. See :data:`Representation`.
|
|
||||||
angle_triplets
|
|
||||||
Required when ``representation="angles"``; see
|
|
||||||
:func:`dtw_all` for details.
|
|
||||||
nan_policy
|
|
||||||
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
list[DTWResult]
|
list[DTWResult]
|
||||||
One DTW result per joint or per angle triplet, in input order.
|
One DTW result per joint, in index order.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
Same conditions as :func:`dtw_all`.
|
If ``a`` and ``b`` do not have the same joint count.
|
||||||
"""
|
"""
|
||||||
_validate_same_joint_count(a, b)
|
_validate_same_joint_count(a, b)
|
||||||
a, b = _maybe_align(a, b, align=align)
|
|
||||||
|
|
||||||
if representation == "coords":
|
|
||||||
feat_a = a
|
|
||||||
feat_b = b
|
|
||||||
# (frames, joints, 3) — one DTW per joint over its (frames, 3) slice.
|
|
||||||
num_units = feat_a.shape[1]
|
|
||||||
slicers: list[Callable[[np.ndarray], np.ndarray]] = [
|
|
||||||
(lambda arr, idx=i: arr[:, idx, :]) for i in range(num_units)
|
|
||||||
]
|
|
||||||
else: # "angles"
|
|
||||||
if angle_triplets is None:
|
|
||||||
raise ValueError("representation='angles' requires angle_triplets")
|
|
||||||
feat_a = extract_joint_angles(a, angle_triplets) # (frames, num_triplets)
|
|
||||||
feat_b = extract_joint_angles(b, angle_triplets)
|
|
||||||
num_units = feat_a.shape[1]
|
|
||||||
slicers = [
|
|
||||||
# Scalar columns become 2D for DTW (fastdtw expects a
|
|
||||||
# sequence of vectors, not a sequence of scalars).
|
|
||||||
(lambda arr, idx=i: arr[:, idx : idx + 1])
|
|
||||||
for i in range(num_units)
|
|
||||||
]
|
|
||||||
|
|
||||||
fastdtw, euclidean = _require_fastdtw()
|
fastdtw, euclidean = _require_fastdtw()
|
||||||
results: list[DTWResult] = []
|
results: list[DTWResult] = []
|
||||||
for slicer in slicers:
|
for joint_idx in range(a.shape[1]):
|
||||||
unit_a = _apply_nan_policy(slicer(feat_a), nan_policy)
|
a_joint = a[:, joint_idx, :]
|
||||||
unit_b = _apply_nan_policy(slicer(feat_b), nan_policy)
|
b_joint = b[:, joint_idx, :]
|
||||||
distance, path = fastdtw(unit_a, unit_b, dist=euclidean)
|
distance, path = fastdtw(a_joint, b_joint, dist=euclidean)
|
||||||
results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path]))
|
results.append(DTWResult(distance=float(distance), path=[tuple(p) for p in path]))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
@ -297,9 +152,6 @@ def dtw_relation(
|
||||||
b: np.ndarray,
|
b: np.ndarray,
|
||||||
joint_i: int,
|
joint_i: int,
|
||||||
joint_j: int,
|
joint_j: int,
|
||||||
*,
|
|
||||||
align: AlignMode = "none",
|
|
||||||
nan_policy: NanPolicy = "propagate",
|
|
||||||
) -> DTWResult:
|
) -> DTWResult:
|
||||||
"""DTW on the displacement vector between two specific joints.
|
"""DTW on the displacement vector between two specific joints.
|
||||||
|
|
||||||
|
|
@ -318,14 +170,6 @@ def dtw_relation(
|
||||||
Indices of the two joints whose relative position should be
|
Indices of the two joints whose relative position should be
|
||||||
compared. Must be valid indices into ``a`` and ``b``'s joint
|
compared. Must be valid indices into ``a`` and ``b``'s joint
|
||||||
axis.
|
axis.
|
||||||
align
|
|
||||||
Procrustes alignment mode applied to the full sequences
|
|
||||||
before the displacement vectors are extracted. See
|
|
||||||
:data:`AlignMode`. Note that displacement vectors are already
|
|
||||||
translation-invariant; alignment is still useful for cancelling
|
|
||||||
camera rotation between trials.
|
|
||||||
nan_policy
|
|
||||||
How to handle NaN values in the DTW input. See :data:`NanPolicy`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -335,9 +179,8 @@ def dtw_relation(
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If the sequences have different joint counts, either joint
|
If the sequences have different joint counts or if either joint
|
||||||
index is out of range, or ``align`` requires a matching frame
|
index is out of range.
|
||||||
count that is not present.
|
|
||||||
"""
|
"""
|
||||||
_validate_same_joint_count(a, b)
|
_validate_same_joint_count(a, b)
|
||||||
num_joints = a.shape[1]
|
num_joints = a.shape[1]
|
||||||
|
|
@ -345,10 +188,9 @@ def dtw_relation(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}"
|
f"joint indices must be in [0, {num_joints}); got joint_i={joint_i}, joint_j={joint_j}"
|
||||||
)
|
)
|
||||||
a, b = _maybe_align(a, b, align=align)
|
|
||||||
disp_a = _apply_nan_policy(a[:, joint_j, :] - a[:, joint_i, :], nan_policy)
|
|
||||||
disp_b = _apply_nan_policy(b[:, joint_j, :] - b[:, joint_i, :], nan_policy)
|
|
||||||
fastdtw, euclidean = _require_fastdtw()
|
fastdtw, euclidean = _require_fastdtw()
|
||||||
|
disp_a = a[:, joint_j, :] - a[:, joint_i, :]
|
||||||
|
disp_b = b[:, joint_j, :] - b[:, joint_i, :]
|
||||||
distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
|
distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
|
||||||
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
|
return DTWResult(distance=float(distance), path=[tuple(p) for p in path])
|
||||||
|
|
||||||
|
|
@ -364,95 +206,3 @@ def _validate_same_joint_count(a: np.ndarray, b: np.ndarray) -> None:
|
||||||
f"input arrays disagree on joint count: "
|
f"input arrays disagree on joint count: "
|
||||||
f"a has {a.shape[1]} joints, b has {b.shape[1]} joints"
|
f"a has {a.shape[1]} joints, b has {b.shape[1]} joints"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _maybe_align(
|
|
||||||
a: np.ndarray,
|
|
||||||
b: np.ndarray,
|
|
||||||
*,
|
|
||||||
align: AlignMode,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""Apply Procrustes alignment if ``align`` requests it.
|
|
||||||
|
|
||||||
Procrustes requires a frame-by-frame correspondence, so this
|
|
||||||
helper rejects calls where the two sequences disagree on frame
|
|
||||||
count and ``align`` is not ``"none"``. Pad upstream with
|
|
||||||
:func:`~neuropose.analyzer.features.pad_sequences` if the lengths
|
|
||||||
differ.
|
|
||||||
"""
|
|
||||||
if align == "none":
|
|
||||||
return a, b
|
|
||||||
if a.shape[0] != b.shape[0]:
|
|
||||||
raise ValueError(
|
|
||||||
f"align={align!r} requires matching frame counts; "
|
|
||||||
f"got a with {a.shape[0]} frames and b with {b.shape[0]} frames"
|
|
||||||
)
|
|
||||||
mode = "per_frame" if align == "procrustes_per_frame" else "per_sequence"
|
|
||||||
aligned_a, _target, _diag = procrustes_align(a, b, mode=mode)
|
|
||||||
return aligned_a, b
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_representation(
|
|
||||||
sequence: np.ndarray,
|
|
||||||
representation: Representation,
|
|
||||||
*,
|
|
||||||
angle_triplets: Sequence[tuple[int, int, int]] | None,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Reduce a ``(frames, joints, 3)`` sequence to DTW-ready 2D features.
|
|
||||||
|
|
||||||
``"coords"`` reshapes to ``(frames, joints * 3)``; ``"angles"``
|
|
||||||
runs :func:`extract_joint_angles` to produce
|
|
||||||
``(frames, len(angle_triplets))``.
|
|
||||||
"""
|
|
||||||
if representation == "coords":
|
|
||||||
return sequence.reshape(sequence.shape[0], -1)
|
|
||||||
if representation == "angles":
|
|
||||||
if angle_triplets is None:
|
|
||||||
raise ValueError("representation='angles' requires angle_triplets")
|
|
||||||
return extract_joint_angles(sequence, angle_triplets)
|
|
||||||
raise ValueError(f"unknown representation {representation!r}")
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_nan_policy(features: np.ndarray, policy: NanPolicy) -> np.ndarray:
|
|
||||||
"""Handle NaN values in a ``(frames, features)`` array per ``policy``.
|
|
||||||
|
|
||||||
``"propagate"`` is a no-op. ``"interpolate"`` runs 1D linear
|
|
||||||
interpolation along the frame axis within each feature column,
|
|
||||||
leaving finite data untouched. ``"drop"`` removes any frame where
|
|
||||||
*any* feature is NaN.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If ``"interpolate"`` encounters a column that is entirely NaN
|
|
||||||
(no finite anchors to interpolate between), or if ``"drop"``
|
|
||||||
leaves an empty sequence.
|
|
||||||
"""
|
|
||||||
if policy == "propagate":
|
|
||||||
return features
|
|
||||||
if features.ndim == 1:
|
|
||||||
features = features.reshape(-1, 1)
|
|
||||||
if policy == "drop":
|
|
||||||
keep = np.isfinite(features).all(axis=1)
|
|
||||||
dropped = features[keep]
|
|
||||||
if dropped.shape[0] == 0:
|
|
||||||
raise ValueError(
|
|
||||||
"nan_policy='drop' removed every frame; DTW needs a non-empty sequence"
|
|
||||||
)
|
|
||||||
return dropped
|
|
||||||
if policy == "interpolate":
|
|
||||||
out = features.astype(float, copy=True)
|
|
||||||
num_frames = out.shape[0]
|
|
||||||
indices = np.arange(num_frames, dtype=float)
|
|
||||||
for col in range(out.shape[1]):
|
|
||||||
column = out[:, col]
|
|
||||||
finite = np.isfinite(column)
|
|
||||||
if finite.all():
|
|
||||||
continue
|
|
||||||
if not finite.any():
|
|
||||||
raise ValueError(
|
|
||||||
f"nan_policy='interpolate' cannot fill column {col}: all values are NaN"
|
|
||||||
)
|
|
||||||
out[:, col] = np.interp(indices, indices[finite], column[finite])
|
|
||||||
return out
|
|
||||||
raise ValueError(f"unknown nan_policy {policy!r}")
|
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,6 @@ The following helpers are provided:
|
||||||
fit in the unit cube (either per-axis or uniform).
|
fit in the unit cube (either per-axis or uniform).
|
||||||
- :func:`pad_sequences` — edge-pad a batch of sequences to a common
|
- :func:`pad_sequences` — edge-pad a batch of sequences to a common
|
||||||
length, suitable for downstream tensor-based analysis.
|
length, suitable for downstream tensor-based analysis.
|
||||||
- :func:`procrustes_align` — rigid-align one pose sequence to another
|
|
||||||
via the Kabsch algorithm, with optional uniform scaling.
|
|
||||||
- :func:`extract_joint_angles` — compute joint angles at specified
|
- :func:`extract_joint_angles` — compute joint angles at specified
|
||||||
triplet positions across a pose sequence.
|
triplet positions across a pose sequence.
|
||||||
- :func:`extract_feature_statistics` — summary statistics
|
- :func:`extract_feature_statistics` — summary statistics
|
||||||
|
|
@ -28,19 +26,12 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from neuropose.io import VideoPredictions
|
from neuropose.io import VideoPredictions
|
||||||
|
|
||||||
ProcrustesMode = Literal["per_frame", "per_sequence"]
|
|
||||||
"""Mode selector for :func:`procrustes_align`.
|
|
||||||
|
|
||||||
``per_sequence`` computes a single rigid transform over the whole
|
|
||||||
sequence; ``per_frame`` aligns every frame independently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# VideoPredictions → numpy
|
# VideoPredictions → numpy
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -220,247 +211,6 @@ def pad_sequences(
|
||||||
return padded
|
return padded
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Procrustes alignment (Kabsch)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AlignmentDiagnostics:
|
|
||||||
"""Summary of the rigid transform fitted by :func:`procrustes_align`.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
mode
|
|
||||||
Which alignment mode produced this result; mirrors the ``mode``
|
|
||||||
argument passed to :func:`procrustes_align`.
|
|
||||||
rotation_deg
|
|
||||||
Magnitude of the fitted rotation, in degrees, computed as
|
|
||||||
``arccos((trace(R) - 1) / 2)``. For ``per_frame`` mode this is
|
|
||||||
the mean magnitude across frames.
|
|
||||||
rotation_deg_max
|
|
||||||
Worst-case (maximum) rotation magnitude across frames.
|
|
||||||
Equal to :attr:`rotation_deg` in ``per_sequence`` mode.
|
|
||||||
translation
|
|
||||||
Magnitude of the fitted translation vector, in the same units
|
|
||||||
as the input (millimetres for MeTRAbs output). For ``per_frame``
|
|
||||||
mode this is the mean magnitude across frames.
|
|
||||||
translation_max
|
|
||||||
Worst-case (maximum) translation magnitude across frames.
|
|
||||||
Equal to :attr:`translation` in ``per_sequence`` mode.
|
|
||||||
scale
|
|
||||||
Applied uniform scale factor. Always ``1.0`` when
|
|
||||||
``procrustes_align`` was called with ``scale=False``. In
|
|
||||||
``per_frame`` mode this is the mean scale across frames.
|
|
||||||
"""
|
|
||||||
|
|
||||||
mode: ProcrustesMode
|
|
||||||
rotation_deg: float
|
|
||||||
rotation_deg_max: float
|
|
||||||
translation: float
|
|
||||||
translation_max: float
|
|
||||||
scale: float
|
|
||||||
|
|
||||||
|
|
||||||
def _kabsch_single(
|
|
||||||
source: np.ndarray,
|
|
||||||
target: np.ndarray,
|
|
||||||
*,
|
|
||||||
scale: bool,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]:
|
|
||||||
"""Fit the optimal rigid (+ optional uniform scale) transform.
|
|
||||||
|
|
||||||
Aligns ``source`` to ``target`` via the closed-form Kabsch
|
|
||||||
algorithm and returns ``(aligned_source, R, s, t)`` where
|
|
||||||
``aligned_source = s * (source - centroid_source) @ R.T + centroid_target + t_fine``
|
|
||||||
(with ``t_fine`` absorbed for convenience — aligned points match
|
|
||||||
the target's centroid to within floating-point error).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
source
|
|
||||||
``(N, 3)`` point set to align.
|
|
||||||
target
|
|
||||||
``(N, 3)`` reference point set. Must have the same shape as
|
|
||||||
``source``.
|
|
||||||
scale
|
|
||||||
If ``True``, fit a uniform scale factor; otherwise lock to
|
|
||||||
``1.0``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
aligned_source
|
|
||||||
``(N, 3)`` aligned copy of ``source``.
|
|
||||||
R
|
|
||||||
``(3, 3)`` rotation matrix.
|
|
||||||
s
|
|
||||||
Scalar scale factor (``1.0`` when ``scale=False``).
|
|
||||||
t
|
|
||||||
``(3,)`` translation vector in world coordinates such that
|
|
||||||
``aligned_source[i] = s * R @ source[i] + t``.
|
|
||||||
"""
|
|
||||||
centroid_source = source.mean(axis=0)
|
|
||||||
centroid_target = target.mean(axis=0)
|
|
||||||
source_centered = source - centroid_source
|
|
||||||
target_centered = target - centroid_target
|
|
||||||
|
|
||||||
covariance = source_centered.T @ target_centered
|
|
||||||
u_mat, sigma, vt_mat = np.linalg.svd(covariance)
|
|
||||||
reflection_sign = float(np.sign(np.linalg.det(vt_mat.T @ u_mat.T)))
|
|
||||||
# Guard against the degenerate det == 0 case (coplanar points).
|
|
||||||
if reflection_sign == 0.0:
|
|
||||||
reflection_sign = 1.0
|
|
||||||
diag = np.diag([1.0, 1.0, reflection_sign])
|
|
||||||
rotation = vt_mat.T @ diag @ u_mat.T
|
|
||||||
|
|
||||||
if scale:
|
|
||||||
source_var = float((source_centered**2).sum())
|
|
||||||
if source_var <= 0.0:
|
|
||||||
scale_factor = 1.0
|
|
||||||
else:
|
|
||||||
scale_factor = float((sigma * np.array([1.0, 1.0, reflection_sign])).sum() / source_var)
|
|
||||||
else:
|
|
||||||
scale_factor = 1.0
|
|
||||||
|
|
||||||
translation = centroid_target - scale_factor * rotation @ centroid_source
|
|
||||||
aligned = scale_factor * source @ rotation.T + translation
|
|
||||||
return aligned, rotation, scale_factor, translation
|
|
||||||
|
|
||||||
|
|
||||||
def _rotation_magnitude_deg(rotation: np.ndarray) -> float:
|
|
||||||
"""Return the rotation angle (degrees) represented by ``rotation``.
|
|
||||||
|
|
||||||
Uses the axis-angle relation ``cos(theta) = (trace(R) - 1) / 2``.
|
|
||||||
"""
|
|
||||||
cos_theta = (float(np.trace(rotation)) - 1.0) / 2.0
|
|
||||||
cos_theta = max(-1.0, min(1.0, cos_theta))
|
|
||||||
return float(np.degrees(np.arccos(cos_theta)))
|
|
||||||
|
|
||||||
|
|
||||||
def procrustes_align(
|
|
||||||
source: np.ndarray,
|
|
||||||
target: np.ndarray,
|
|
||||||
*,
|
|
||||||
mode: ProcrustesMode = "per_sequence",
|
|
||||||
scale: bool = False,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, AlignmentDiagnostics]:
|
|
||||||
"""Rigid-align ``source`` to ``target`` via the Kabsch algorithm.
|
|
||||||
|
|
||||||
Fits the optimal rigid transform (optionally including uniform
|
|
||||||
scaling) that minimizes the sum of squared distances between
|
|
||||||
corresponding joints. The transform is always applied to
|
|
||||||
``source``; ``target`` is returned unchanged alongside it for
|
|
||||||
symmetry with downstream DTW callers, which typically consume both
|
|
||||||
aligned arrays as a pair.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
source
|
|
||||||
Pose sequence to align, shape ``(frames, joints, 3)``.
|
|
||||||
target
|
|
||||||
Reference pose sequence, shape ``(frames, joints, 3)``. For
|
|
||||||
``per_frame`` mode the frame counts must match; for
|
|
||||||
``per_sequence`` mode they must also match (the correspondence
|
|
||||||
runs frame-by-frame and joint-by-joint). Use
|
|
||||||
:func:`pad_sequences` first if your sequences have different
|
|
||||||
lengths.
|
|
||||||
mode
|
|
||||||
``"per_sequence"`` (default) fits a single rigid transform over
|
|
||||||
the whole sequence — good when the recording geometry is
|
|
||||||
stable across frames. ``"per_frame"`` fits an independent
|
|
||||||
transform per frame — good for matching pose shape while
|
|
||||||
discarding global trajectory.
|
|
||||||
scale
|
|
||||||
If ``True``, also fit a uniform scale factor. Useful for
|
|
||||||
cross-subject comparisons where the reference skeleton has a
|
|
||||||
different overall size.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
aligned_source
|
|
||||||
``source`` transformed to align with ``target``, same shape as
|
|
||||||
the input.
|
|
||||||
target
|
|
||||||
The ``target`` array, unchanged.
|
|
||||||
diagnostics
|
|
||||||
:class:`AlignmentDiagnostics` summarising the fitted transform.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If ``source`` and ``target`` have different shapes or the
|
|
||||||
trailing axis is not of size 3.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
The Kabsch algorithm (Kabsch 1976, "A solution for the best
|
|
||||||
rotation to relate two sets of vectors") is a closed-form SVD
|
|
||||||
solution and does not iterate. Reflection is explicitly prevented
|
|
||||||
via a sign correction on the smallest singular value; the fitted
|
|
||||||
matrix is always a proper rotation (det = +1).
|
|
||||||
|
|
||||||
In ``per_frame`` mode, rotation, translation, and scale
|
|
||||||
diagnostics are reported as means across frames, with
|
|
||||||
:attr:`AlignmentDiagnostics.rotation_deg_max` and
|
|
||||||
:attr:`AlignmentDiagnostics.translation_max` exposing the worst
|
|
||||||
frame for anomaly detection.
|
|
||||||
"""
|
|
||||||
if source.ndim != 3 or source.shape[-1] != 3:
|
|
||||||
raise ValueError(f"expected (frames, joints, 3); got source shape {source.shape}")
|
|
||||||
if source.shape != target.shape:
|
|
||||||
raise ValueError(
|
|
||||||
f"source and target must have the same shape; got {source.shape} and {target.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
source = source.astype(float, copy=False)
|
|
||||||
target = target.astype(float, copy=False)
|
|
||||||
num_frames = source.shape[0]
|
|
||||||
|
|
||||||
if mode == "per_sequence":
|
|
||||||
flat_source = source.reshape(-1, 3)
|
|
||||||
flat_target = target.reshape(-1, 3)
|
|
||||||
aligned_flat, rotation, scale_factor, translation = _kabsch_single(
|
|
||||||
flat_source, flat_target, scale=scale
|
|
||||||
)
|
|
||||||
aligned = aligned_flat.reshape(source.shape)
|
|
||||||
rotation_deg = _rotation_magnitude_deg(rotation)
|
|
||||||
translation_mag = float(np.linalg.norm(translation))
|
|
||||||
diagnostics = AlignmentDiagnostics(
|
|
||||||
mode="per_sequence",
|
|
||||||
rotation_deg=rotation_deg,
|
|
||||||
rotation_deg_max=rotation_deg,
|
|
||||||
translation=translation_mag,
|
|
||||||
translation_max=translation_mag,
|
|
||||||
scale=scale_factor,
|
|
||||||
)
|
|
||||||
return aligned, target, diagnostics
|
|
||||||
|
|
||||||
if mode == "per_frame":
|
|
||||||
aligned = np.empty_like(source)
|
|
||||||
rotation_degs = np.empty(num_frames, dtype=float)
|
|
||||||
translations = np.empty(num_frames, dtype=float)
|
|
||||||
scales = np.empty(num_frames, dtype=float)
|
|
||||||
for frame_idx in range(num_frames):
|
|
||||||
aligned_frame, rotation, scale_factor, translation = _kabsch_single(
|
|
||||||
source[frame_idx], target[frame_idx], scale=scale
|
|
||||||
)
|
|
||||||
aligned[frame_idx] = aligned_frame
|
|
||||||
rotation_degs[frame_idx] = _rotation_magnitude_deg(rotation)
|
|
||||||
translations[frame_idx] = float(np.linalg.norm(translation))
|
|
||||||
scales[frame_idx] = scale_factor
|
|
||||||
diagnostics = AlignmentDiagnostics(
|
|
||||||
mode="per_frame",
|
|
||||||
rotation_deg=float(rotation_degs.mean()) if num_frames else 0.0,
|
|
||||||
rotation_deg_max=float(rotation_degs.max()) if num_frames else 0.0,
|
|
||||||
translation=float(translations.mean()) if num_frames else 0.0,
|
|
||||||
translation_max=float(translations.max()) if num_frames else 0.0,
|
|
||||||
scale=float(scales.mean()) if num_frames else 1.0,
|
|
||||||
)
|
|
||||||
return aligned, target, diagnostics
|
|
||||||
|
|
||||||
raise ValueError(f"unknown mode {mode!r}; expected 'per_frame' or 'per_sequence'")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Joint angles
|
# Joint angles
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -1,932 +0,0 @@
|
||||||
"""YAML-configurable analysis pipeline.
|
|
||||||
|
|
||||||
This module unifies the analyzer's individual primitives (Procrustes
|
|
||||||
alignment, gait-cycle segmentation, DTW on coords or angles, feature
|
|
||||||
statistics) behind a single declarative configuration object so an
|
|
||||||
experiment can be reproduced from one file that lives in a git
|
|
||||||
repository, carries through to the :class:`~neuropose.io.Provenance`
|
|
||||||
envelope on the output artifact, and can be cited unambiguously in
|
|
||||||
accompanying papers.
|
|
||||||
|
|
||||||
Two top-level schemas live here:
|
|
||||||
|
|
||||||
- :class:`AnalysisConfig` — what the user writes in YAML. Describes
|
|
||||||
the full pipeline: inputs, preprocessing, optional segmentation,
|
|
||||||
required analysis stage, and output path.
|
|
||||||
- :class:`AnalysisReport` — what :func:`run_analysis` emits. Carries
|
|
||||||
the config, a :class:`~neuropose.io.Provenance` envelope (with the
|
|
||||||
config serialised into :attr:`~neuropose.io.Provenance.analysis_config`),
|
|
||||||
per-input summaries, segmentation results, and the analysis results
|
|
||||||
themselves.
|
|
||||||
|
|
||||||
Both schemas parse from (and serialise to) JSON via pydantic; the
|
|
||||||
config additionally parses from YAML via :func:`load_config`. Cross-field
|
|
||||||
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
|
||||||
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
|
||||||
fast rather than after an expensive multi-minute load.
|
|
||||||
|
|
||||||
Execution
|
|
||||||
---------
|
|
||||||
:func:`run_analysis` is the top-level executor: it loads the
|
|
||||||
predictions files named in the config, applies any configured
|
|
||||||
segmentation stage, dispatches to the configured analysis stage, and
|
|
||||||
returns a fully populated :class:`AnalysisReport`. The executor is
|
|
||||||
intended to be called from the ``neuropose analyze`` CLI but is
|
|
||||||
equally valid as a Python-level entry point for notebook-driven
|
|
||||||
exploration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, Any, Literal
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import yaml
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
||||||
|
|
||||||
from neuropose.analyzer.dtw import (
|
|
||||||
AlignMode,
|
|
||||||
DTWResult,
|
|
||||||
NanPolicy,
|
|
||||||
Representation,
|
|
||||||
dtw_all,
|
|
||||||
dtw_per_joint,
|
|
||||||
dtw_relation,
|
|
||||||
)
|
|
||||||
from neuropose.analyzer.features import (
|
|
||||||
extract_feature_statistics,
|
|
||||||
predictions_to_numpy,
|
|
||||||
)
|
|
||||||
from neuropose.analyzer.segment import (
|
|
||||||
AxisLetter,
|
|
||||||
extract_signal,
|
|
||||||
segment_gait_cycles,
|
|
||||||
segment_gait_cycles_bilateral,
|
|
||||||
segment_predictions,
|
|
||||||
)
|
|
||||||
from neuropose.io import (
|
|
||||||
ExtractorSpec,
|
|
||||||
Provenance,
|
|
||||||
Segmentation,
|
|
||||||
VideoPredictions,
|
|
||||||
load_video_predictions,
|
|
||||||
)
|
|
||||||
from neuropose.migrations import CURRENT_VERSION, migrate_analysis_report
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Inputs
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class InputsConfig(BaseModel):
|
|
||||||
"""Predictions files consumed by the pipeline.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
primary
|
|
||||||
Path to a :class:`~neuropose.io.VideoPredictions` JSON file.
|
|
||||||
Always required.
|
|
||||||
reference
|
|
||||||
Optional second predictions file. When provided,
|
|
||||||
:class:`DtwAnalysis` runs comparative DTW between primary and
|
|
||||||
reference; when absent, analysis stages that require a
|
|
||||||
reference (i.e. DTW) raise a validation error at parse time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
primary: Path
|
|
||||||
reference: Path | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Preprocessing
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseModel):
|
|
||||||
"""Per-input preprocessing applied before segmentation and analysis.
|
|
||||||
|
|
||||||
Minimal today — just picks which detected person to extract from
|
|
||||||
each frame. Left as a named stage so future extensions (coordinate
|
|
||||||
normalisation, smoothing) can land here without reshuffling the
|
|
||||||
config shape.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
person_index
|
|
||||||
Which detected person to extract per frame. Defaults to ``0``
|
|
||||||
(the first detected person), matching the single-subject
|
|
||||||
clinical case.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
person_index: int = Field(default=0, ge=0)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Segmentation stage
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class GaitCyclesSegmentation(BaseModel):
|
|
||||||
"""Single-heel gait-cycle segmentation via peak detection.
|
|
||||||
|
|
||||||
Produces one :class:`~neuropose.io.Segmentation` keyed under the
|
|
||||||
joint name (e.g. ``"rhee_cycles"``). See
|
|
||||||
:func:`~neuropose.analyzer.segment.segment_gait_cycles` for the
|
|
||||||
underlying implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["gait_cycles"]
|
|
||||||
joint: str = "rhee"
|
|
||||||
axis: AxisLetter = "y"
|
|
||||||
invert: bool = False
|
|
||||||
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
|
|
||||||
min_prominence: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class GaitCyclesBilateralSegmentation(BaseModel):
|
|
||||||
"""Bilateral (both heels) gait-cycle segmentation.
|
|
||||||
|
|
||||||
Produces two :class:`~neuropose.io.Segmentation` objects keyed as
|
|
||||||
``"left_heel_strikes"`` and ``"right_heel_strikes"``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["gait_cycles_bilateral"]
|
|
||||||
axis: AxisLetter = "y"
|
|
||||||
invert: bool = False
|
|
||||||
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
|
|
||||||
min_prominence: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExtractorSegmentation(BaseModel):
|
|
||||||
"""Generic extractor-driven segmentation.
|
|
||||||
|
|
||||||
Wraps :func:`~neuropose.analyzer.segment.segment_predictions` with
|
|
||||||
a caller-supplied :class:`~neuropose.io.ExtractorSpec`. Use this
|
|
||||||
when the signal of interest is not the vertical heel trace — e.g.
|
|
||||||
wrist-hip distance for a reach-and-grasp task, or elbow flexion
|
|
||||||
angle for a range-of-motion trial.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["extractor"]
|
|
||||||
extractor: ExtractorSpec
|
|
||||||
label: str = Field(
|
|
||||||
default="segmentation",
|
|
||||||
description="Key under which the resulting Segmentation is stored.",
|
|
||||||
)
|
|
||||||
person_index: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Overrides preprocessing.person_index for this stage. "
|
|
||||||
"None defers to the global preprocessing setting."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
min_distance_seconds: float | None = Field(default=None, ge=0.0)
|
|
||||||
min_prominence: float | None = None
|
|
||||||
min_height: float | None = None
|
|
||||||
pad_seconds: float = Field(default=0.0, ge=0.0)
|
|
||||||
|
|
||||||
|
|
||||||
SegmentationStage = Annotated[
|
|
||||||
GaitCyclesSegmentation | GaitCyclesBilateralSegmentation | ExtractorSegmentation,
|
|
||||||
Field(discriminator="kind"),
|
|
||||||
]
|
|
||||||
"""Discriminated-union alias for the three segmentation variants.
|
|
||||||
|
|
||||||
Pydantic dispatches on the ``kind`` field. A config without a
|
|
||||||
``segmentation`` key at all skips this stage entirely
|
|
||||||
(see :class:`AnalysisConfig.segmentation`).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Analysis stage
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class DtwAnalysis(BaseModel):
|
|
||||||
"""Dynamic Time Warping between the primary and reference inputs.
|
|
||||||
|
|
||||||
Dispatches to one of :func:`~neuropose.analyzer.dtw.dtw_all`,
|
|
||||||
:func:`~neuropose.analyzer.dtw.dtw_per_joint`, or
|
|
||||||
:func:`~neuropose.analyzer.dtw.dtw_relation` per the ``method``
|
|
||||||
field. Cross-field invariants — ``method="dtw_relation"`` requires
|
|
||||||
``joint_i`` and ``joint_j``, ``representation="angles"`` requires
|
|
||||||
a non-empty ``angle_triplets`` — are enforced at parse time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["dtw"]
|
|
||||||
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] = "dtw_all"
|
|
||||||
align: AlignMode = "none"
|
|
||||||
representation: Representation = "coords"
|
|
||||||
angle_triplets: list[tuple[int, int, int]] | None = None
|
|
||||||
joint_i: int | None = None
|
|
||||||
joint_j: int | None = None
|
|
||||||
nan_policy: NanPolicy = "propagate"
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def _check_method_fields(self) -> DtwAnalysis:
|
|
||||||
if self.method == "dtw_relation":
|
|
||||||
if self.joint_i is None or self.joint_j is None:
|
|
||||||
raise ValueError("method='dtw_relation' requires joint_i and joint_j")
|
|
||||||
if self.representation != "coords":
|
|
||||||
raise ValueError(
|
|
||||||
"method='dtw_relation' only supports representation='coords' "
|
|
||||||
"(a two-joint displacement is not a joint-angle signal)"
|
|
||||||
)
|
|
||||||
if self.representation == "angles":
|
|
||||||
if not self.angle_triplets:
|
|
||||||
raise ValueError("representation='angles' requires a non-empty angle_triplets list")
|
|
||||||
if self.method == "dtw_relation":
|
|
||||||
# Guarded by the earlier branch, but make the invariant explicit.
|
|
||||||
raise ValueError(
|
|
||||||
"representation='angles' is incompatible with method='dtw_relation'"
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class StatsAnalysis(BaseModel):
|
|
||||||
"""Summary statistics over a scalar signal extracted from the primary input.
|
|
||||||
|
|
||||||
Runs :func:`~neuropose.analyzer.segment.extract_signal` with the
|
|
||||||
caller-supplied :class:`~neuropose.io.ExtractorSpec`, then
|
|
||||||
computes :func:`~neuropose.analyzer.features.extract_feature_statistics`
|
|
||||||
on each segment (or on the full trial if no segmentation stage
|
|
||||||
runs).
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["stats"]
|
|
||||||
extractor: ExtractorSpec
|
|
||||||
|
|
||||||
|
|
||||||
class NoAnalysis(BaseModel):
|
|
||||||
"""Terminal stage placeholder; produces no per-segment results.
|
|
||||||
|
|
||||||
Useful when the pipeline's goal is just to segment the input and
|
|
||||||
persist the :class:`~neuropose.io.Segmentation` plus an
|
|
||||||
:class:`AnalysisReport` with provenance — the ``none`` analysis
|
|
||||||
kind makes that explicit rather than requiring the absence of the
|
|
||||||
stage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["none"]
|
|
||||||
|
|
||||||
|
|
||||||
AnalysisStage = Annotated[
|
|
||||||
DtwAnalysis | StatsAnalysis | NoAnalysis,
|
|
||||||
Field(discriminator="kind"),
|
|
||||||
]
|
|
||||||
"""Discriminated-union alias for the three analysis variants.
|
|
||||||
|
|
||||||
Pydantic dispatches on ``kind``. One of the three must always be
|
|
||||||
present in a valid config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Output
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class OutputConfig(BaseModel):
|
|
||||||
"""Where :func:`run_analysis` should write its :class:`AnalysisReport`.
|
|
||||||
|
|
||||||
Kept as a sub-object rather than a bare path so downstream
|
|
||||||
extensions (figure paths, supplementary distance-matrix files)
|
|
||||||
can land here without changing the config's top-level shape.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
report: Path
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Top-level config
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class AnalysisConfig(BaseModel):
|
|
||||||
"""Declarative description of a full analysis run.
|
|
||||||
|
|
||||||
Parsed from YAML (via :func:`load_config`) or JSON (via
|
|
||||||
:meth:`pydantic.BaseModel.model_validate`). Every field is
|
|
||||||
cross-validated at parse time so a typo in a nested sub-field
|
|
||||||
fails in milliseconds rather than after a multi-minute
|
|
||||||
predictions load.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
config_version
|
|
||||||
Schema version for the config itself. Only ``1`` is valid at
|
|
||||||
this release. Future config-format breaks bump this and a
|
|
||||||
sibling migration registry handles legacy YAML in place.
|
|
||||||
inputs
|
|
||||||
Predictions-file paths.
|
|
||||||
preprocessing
|
|
||||||
Per-input preprocessing (person-index selection today).
|
|
||||||
segmentation
|
|
||||||
Optional segmentation stage. ``None`` skips segmentation
|
|
||||||
entirely and analysis runs over each full trial as a single
|
|
||||||
"segment".
|
|
||||||
analysis
|
|
||||||
Required analysis stage. Exactly one of
|
|
||||||
:class:`DtwAnalysis` / :class:`StatsAnalysis` / :class:`NoAnalysis`.
|
|
||||||
output
|
|
||||||
Output paths.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
config_version: Literal[1] = 1
|
|
||||||
inputs: InputsConfig
|
|
||||||
preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
|
|
||||||
segmentation: SegmentationStage | None = None
|
|
||||||
analysis: AnalysisStage
|
|
||||||
output: OutputConfig
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def _check_cross_stage_invariants(self) -> AnalysisConfig:
|
|
||||||
# DTW is comparative — it needs a reference input.
|
|
||||||
if isinstance(self.analysis, DtwAnalysis) and self.inputs.reference is None:
|
|
||||||
raise ValueError("analysis.kind='dtw' requires inputs.reference to be set")
|
|
||||||
# Stats is non-comparative — a reference without a use is
|
|
||||||
# almost certainly an operator error.
|
|
||||||
if isinstance(self.analysis, StatsAnalysis) and self.inputs.reference is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"analysis.kind='stats' operates on inputs.primary only; "
|
|
||||||
"remove inputs.reference or switch analysis.kind to 'dtw'"
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Report pieces
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class InputSummary(BaseModel):
|
|
||||||
"""Capsule of an input predictions file's headline metadata.
|
|
||||||
|
|
||||||
Stored in the :class:`AnalysisReport` so a reader of the report
|
|
||||||
can tell at a glance what was analysed without having to load the
|
|
||||||
underlying predictions JSONs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
path: Path
|
|
||||||
frame_count: int = Field(ge=0)
|
|
||||||
fps: float = Field(ge=0.0)
|
|
||||||
provenance: Provenance | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureSummary(BaseModel):
|
|
||||||
"""Pydantic twin of :class:`~neuropose.analyzer.features.FeatureStatistics`.
|
|
||||||
|
|
||||||
The dataclass is used throughout the analyzer for ad-hoc Python
|
|
||||||
consumption; the report path needs a pydantic model for
|
|
||||||
round-tripping through JSON.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
mean: float
|
|
||||||
std: float
|
|
||||||
min: float
|
|
||||||
max: float
|
|
||||||
range: float
|
|
||||||
|
|
||||||
|
|
||||||
class DtwResults(BaseModel):
|
|
||||||
"""DTW results attached to an :class:`AnalysisReport`.
|
|
||||||
|
|
||||||
``distances`` is parallel to ``segment_labels``. For an
|
|
||||||
unsegmented run the lists have length 1 and the label is
|
|
||||||
``"full_trial"``. For a segmented run each label takes the form
|
|
||||||
``"<segmentation_key>[<index>]"`` (e.g. ``"left_heel_strikes[3]"``).
|
|
||||||
``per_joint_distances`` carries a per-unit breakdown for
|
|
||||||
``method="dtw_per_joint"`` only; its outer length matches
|
|
||||||
``distances``, inner length matches either ``num_joints`` (coords)
|
|
||||||
or ``len(angle_triplets)`` (angles).
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["dtw"]
|
|
||||||
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"]
|
|
||||||
distances: list[float]
|
|
||||||
paths: list[list[tuple[int, int]]]
|
|
||||||
per_joint_distances: list[list[float]] | None = None
|
|
||||||
segment_labels: list[str]
|
|
||||||
summary: dict[str, float]
|
|
||||||
|
|
||||||
|
|
||||||
class StatsResults(BaseModel):
|
|
||||||
"""Feature-statistics results attached to an :class:`AnalysisReport`.
|
|
||||||
|
|
||||||
``statistics`` is parallel to ``segment_labels``; see
|
|
||||||
:class:`DtwResults` for the labelling convention.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["stats"]
|
|
||||||
statistics: list[FeatureSummary]
|
|
||||||
segment_labels: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class NoResults(BaseModel):
|
|
||||||
"""Empty results payload for ``analysis.kind='none'`` runs."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
kind: Literal["none"]
|
|
||||||
|
|
||||||
|
|
||||||
AnalysisResults = Annotated[
|
|
||||||
DtwResults | StatsResults | NoResults,
|
|
||||||
Field(discriminator="kind"),
|
|
||||||
]
|
|
||||||
"""Discriminated-union alias for the three analysis-result shapes.
|
|
||||||
|
|
||||||
Mirrors :data:`AnalysisStage` one-for-one: ``DtwAnalysis`` produces
|
|
||||||
:class:`DtwResults`, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Top-level report
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class AnalysisReport(BaseModel):
|
|
||||||
"""Self-describing output artifact of :func:`run_analysis`.
|
|
||||||
|
|
||||||
Serialised to JSON on disk. Carries the originating config, the
|
|
||||||
:class:`~neuropose.io.Provenance` envelope (with the config
|
|
||||||
serialised into :attr:`~neuropose.io.Provenance.analysis_config`
|
|
||||||
so the report is self-describing even if the YAML is lost), each
|
|
||||||
input's headline metadata plus its own provenance if available,
|
|
||||||
any segmentations produced, and the analysis results themselves.
|
|
||||||
|
|
||||||
Lives in the schema-migration registry under ``"AnalysisReport"``
|
|
||||||
at ``CURRENT_VERSION``; see :mod:`neuropose.migrations`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
|
||||||
|
|
||||||
schema_version: int = Field(default=CURRENT_VERSION, ge=1)
|
|
||||||
config: AnalysisConfig
|
|
||||||
provenance: Provenance | None = None
|
|
||||||
primary: InputSummary
|
|
||||||
reference: InputSummary | None = None
|
|
||||||
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
|
|
||||||
results: AnalysisResults
|
|
||||||
|
|
||||||
|
|
||||||
def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
|
|
||||||
"""Serialise an :class:`AnalysisConfig` to a JSON-safe dict.
|
|
||||||
|
|
||||||
Returned shape is identical to what pydantic would produce via
|
|
||||||
:meth:`~pydantic.BaseModel.model_dump` in ``mode="json"`` — paths
|
|
||||||
become strings, tuples become lists, enums become their values.
|
|
||||||
Useful for stamping
|
|
||||||
:attr:`~neuropose.io.Provenance.analysis_config` on the
|
|
||||||
:class:`AnalysisReport`'s provenance envelope.
|
|
||||||
"""
|
|
||||||
return config.model_dump(mode="json")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Load / save
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(path: Path) -> AnalysisConfig:
|
|
||||||
"""Load and validate an :class:`AnalysisConfig` from a YAML file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path
|
|
||||||
Filesystem path to a YAML file conforming to the
|
|
||||||
:class:`AnalysisConfig` schema.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
AnalysisConfig
|
|
||||||
The fully validated config. Cross-field invariants have
|
|
||||||
already been checked.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
pydantic.ValidationError
|
|
||||||
On any schema violation — unknown keys, wrong types, or
|
|
||||||
failed cross-field invariants.
|
|
||||||
yaml.YAMLError
|
|
||||||
On malformed YAML.
|
|
||||||
"""
|
|
||||||
with path.open("r", encoding="utf-8") as f:
|
|
||||||
raw = yaml.safe_load(f)
|
|
||||||
if raw is None:
|
|
||||||
raw = {}
|
|
||||||
return AnalysisConfig.model_validate(raw)
|
|
||||||
|
|
||||||
|
|
||||||
def save_report(path: Path, report: AnalysisReport) -> None:
|
|
||||||
"""Serialise an :class:`AnalysisReport` to ``path`` atomically.
|
|
||||||
|
|
||||||
Writes to a sibling ``<path>.tmp`` first, then renames over
|
|
||||||
``path`` so a crash mid-write cannot leave behind a truncated
|
|
||||||
file. The parent directory is created if it does not exist.
|
|
||||||
"""
|
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
|
||||||
payload = report.model_dump(mode="json")
|
|
||||||
with tmp.open("w", encoding="utf-8") as f:
|
|
||||||
json.dump(payload, f, indent=2)
|
|
||||||
tmp.replace(path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_report(path: Path) -> AnalysisReport:
|
|
||||||
"""Load and validate an :class:`AnalysisReport` JSON file.
|
|
||||||
|
|
||||||
Runs the payload through :func:`~neuropose.migrations.migrate_analysis_report`
|
|
||||||
before pydantic validation so future schema bumps can upgrade
|
|
||||||
legacy reports transparently.
|
|
||||||
"""
|
|
||||||
with path.open("r", encoding="utf-8") as f:
|
|
||||||
data: Any = json.load(f)
|
|
||||||
if isinstance(data, dict):
|
|
||||||
data = migrate_analysis_report(data)
|
|
||||||
return AnalysisReport.model_validate(data)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Executor
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def run_analysis(config: AnalysisConfig) -> AnalysisReport:
|
|
||||||
"""Execute the pipeline described by ``config`` end-to-end.
|
|
||||||
|
|
||||||
Loads the predictions files named in ``config.inputs``, applies
|
|
||||||
the configured preprocessing + segmentation + analysis stages,
|
|
||||||
and returns an :class:`AnalysisReport` whose
|
|
||||||
:attr:`~AnalysisReport.provenance` inherits the inference-time
|
|
||||||
provenance of the primary input with
|
|
||||||
:attr:`~neuropose.io.Provenance.analysis_config` populated so the
|
|
||||||
report is self-describing even if the YAML config is later lost.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config
|
|
||||||
The pre-validated pipeline configuration.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
AnalysisReport
|
|
||||||
Fully populated report. Not yet written to disk — the caller
|
|
||||||
passes it to :func:`save_report` (or inspects it directly).
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
For DTW runs with a segmentation stage, segments are paired
|
|
||||||
one-to-one by index across primary and reference, truncating to
|
|
||||||
``min(len_primary, len_reference)``. Bilateral segmentations
|
|
||||||
produce distances for each side independently, labelled under
|
|
||||||
their segmentation key (e.g. ``"left_heel_strikes[3]"``).
|
|
||||||
"""
|
|
||||||
primary_preds = load_video_predictions(config.inputs.primary)
|
|
||||||
reference_preds: VideoPredictions | None = None
|
|
||||||
if config.inputs.reference is not None:
|
|
||||||
reference_preds = load_video_predictions(config.inputs.reference)
|
|
||||||
|
|
||||||
person_index = config.preprocessing.person_index
|
|
||||||
|
|
||||||
primary_seq = predictions_to_numpy(primary_preds, person_index=person_index)
|
|
||||||
reference_seq: np.ndarray | None = None
|
|
||||||
if reference_preds is not None:
|
|
||||||
reference_seq = predictions_to_numpy(reference_preds, person_index=person_index)
|
|
||||||
|
|
||||||
primary_segmentations: dict[str, Segmentation] = {}
|
|
||||||
reference_segmentations: dict[str, Segmentation] = {}
|
|
||||||
if config.segmentation is not None:
|
|
||||||
primary_segmentations = _run_segmentation(primary_preds, config.segmentation, person_index)
|
|
||||||
if reference_preds is not None:
|
|
||||||
reference_segmentations = _run_segmentation(
|
|
||||||
reference_preds, config.segmentation, person_index
|
|
||||||
)
|
|
||||||
|
|
||||||
results = _run_analysis_stage(
|
|
||||||
config.analysis,
|
|
||||||
primary_seq=primary_seq,
|
|
||||||
reference_seq=reference_seq,
|
|
||||||
primary_segmentations=primary_segmentations,
|
|
||||||
reference_segmentations=reference_segmentations,
|
|
||||||
)
|
|
||||||
|
|
||||||
analysis_config_dump = analysis_config_to_dict(config)
|
|
||||||
report_provenance: Provenance | None = None
|
|
||||||
if primary_preds.provenance is not None:
|
|
||||||
report_provenance = primary_preds.provenance.model_copy(
|
|
||||||
update={"analysis_config": analysis_config_dump}
|
|
||||||
)
|
|
||||||
|
|
||||||
primary_summary = InputSummary(
|
|
||||||
path=config.inputs.primary,
|
|
||||||
frame_count=primary_preds.metadata.frame_count,
|
|
||||||
fps=primary_preds.metadata.fps,
|
|
||||||
provenance=primary_preds.provenance,
|
|
||||||
)
|
|
||||||
reference_summary: InputSummary | None = None
|
|
||||||
if reference_preds is not None and config.inputs.reference is not None:
|
|
||||||
reference_summary = InputSummary(
|
|
||||||
path=config.inputs.reference,
|
|
||||||
frame_count=reference_preds.metadata.frame_count,
|
|
||||||
fps=reference_preds.metadata.fps,
|
|
||||||
provenance=reference_preds.provenance,
|
|
||||||
)
|
|
||||||
|
|
||||||
return AnalysisReport(
|
|
||||||
config=config,
|
|
||||||
provenance=report_provenance,
|
|
||||||
primary=primary_summary,
|
|
||||||
reference=reference_summary,
|
|
||||||
segmentations=primary_segmentations,
|
|
||||||
results=results,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Internal dispatch helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _run_segmentation(
|
|
||||||
predictions: VideoPredictions,
|
|
||||||
stage: SegmentationStage, # type: ignore[valid-type]
|
|
||||||
person_index: int,
|
|
||||||
) -> dict[str, Segmentation]:
|
|
||||||
"""Apply a segmentation stage to a :class:`VideoPredictions`.
|
|
||||||
|
|
||||||
Returns a dict keyed by a stage-appropriate label: single-side
|
|
||||||
gait cycles use ``"<joint>_cycles"``, bilateral gait cycles use
|
|
||||||
``"left_heel_strikes"`` / ``"right_heel_strikes"``, and extractor
|
|
||||||
segmentation uses the caller-supplied
|
|
||||||
:attr:`~ExtractorSegmentation.label`.
|
|
||||||
"""
|
|
||||||
if isinstance(stage, GaitCyclesSegmentation):
|
|
||||||
seg = segment_gait_cycles(
|
|
||||||
predictions,
|
|
||||||
joint=stage.joint,
|
|
||||||
axis=stage.axis,
|
|
||||||
invert=stage.invert,
|
|
||||||
min_cycle_seconds=stage.min_cycle_seconds,
|
|
||||||
min_prominence=stage.min_prominence,
|
|
||||||
)
|
|
||||||
return {f"{stage.joint}_cycles": seg}
|
|
||||||
if isinstance(stage, GaitCyclesBilateralSegmentation):
|
|
||||||
return segment_gait_cycles_bilateral(
|
|
||||||
predictions,
|
|
||||||
axis=stage.axis,
|
|
||||||
invert=stage.invert,
|
|
||||||
min_cycle_seconds=stage.min_cycle_seconds,
|
|
||||||
min_prominence=stage.min_prominence,
|
|
||||||
)
|
|
||||||
if isinstance(stage, ExtractorSegmentation):
|
|
||||||
effective_person_index = (
|
|
||||||
stage.person_index if stage.person_index is not None else person_index
|
|
||||||
)
|
|
||||||
seg = segment_predictions(
|
|
||||||
predictions,
|
|
||||||
stage.extractor,
|
|
||||||
person_index=effective_person_index,
|
|
||||||
min_distance_seconds=stage.min_distance_seconds,
|
|
||||||
min_prominence=stage.min_prominence,
|
|
||||||
min_height=stage.min_height,
|
|
||||||
pad_seconds=stage.pad_seconds,
|
|
||||||
)
|
|
||||||
return {stage.label: seg}
|
|
||||||
raise TypeError(f"unknown segmentation stage: {type(stage).__name__}")
|
|
||||||
|
|
||||||
|
|
||||||
def _run_analysis_stage(
|
|
||||||
stage: AnalysisStage, # type: ignore[valid-type]
|
|
||||||
*,
|
|
||||||
primary_seq: np.ndarray,
|
|
||||||
reference_seq: np.ndarray | None,
|
|
||||||
primary_segmentations: dict[str, Segmentation],
|
|
||||||
reference_segmentations: dict[str, Segmentation],
|
|
||||||
) -> AnalysisResults: # type: ignore[valid-type]
|
|
||||||
"""Dispatch to the appropriate analysis executor per ``stage.kind``."""
|
|
||||||
if isinstance(stage, DtwAnalysis):
|
|
||||||
if reference_seq is None:
|
|
||||||
# AnalysisConfig's cross-stage validator should prevent
|
|
||||||
# this; duplicate the check here so a direct programmatic
|
|
||||||
# call can't slip through.
|
|
||||||
raise ValueError("DtwAnalysis requires a reference sequence")
|
|
||||||
return _run_dtw(
|
|
||||||
stage,
|
|
||||||
primary_seq=primary_seq,
|
|
||||||
reference_seq=reference_seq,
|
|
||||||
primary_segmentations=primary_segmentations,
|
|
||||||
reference_segmentations=reference_segmentations,
|
|
||||||
)
|
|
||||||
if isinstance(stage, StatsAnalysis):
|
|
||||||
return _run_stats(
|
|
||||||
stage,
|
|
||||||
primary_seq=primary_seq,
|
|
||||||
primary_segmentations=primary_segmentations,
|
|
||||||
)
|
|
||||||
if isinstance(stage, NoAnalysis):
|
|
||||||
return NoResults(kind="none")
|
|
||||||
raise TypeError(f"unknown analysis stage: {type(stage).__name__}")
|
|
||||||
|
|
||||||
|
|
||||||
def _run_dtw(
|
|
||||||
stage: DtwAnalysis,
|
|
||||||
*,
|
|
||||||
primary_seq: np.ndarray,
|
|
||||||
reference_seq: np.ndarray,
|
|
||||||
primary_segmentations: dict[str, Segmentation],
|
|
||||||
reference_segmentations: dict[str, Segmentation],
|
|
||||||
) -> DtwResults:
|
|
||||||
"""Execute a DTW analysis stage, returning :class:`DtwResults`."""
|
|
||||||
labels: list[str] = []
|
|
||||||
distances: list[float] = []
|
|
||||||
paths: list[list[tuple[int, int]]] = []
|
|
||||||
per_joint_distances: list[list[float]] | None = [] if stage.method == "dtw_per_joint" else None
|
|
||||||
|
|
||||||
pairs: list[tuple[str, np.ndarray, np.ndarray]] = []
|
|
||||||
if primary_segmentations:
|
|
||||||
for key, primary_seg in primary_segmentations.items():
|
|
||||||
reference_seg = reference_segmentations.get(key)
|
|
||||||
if reference_seg is None:
|
|
||||||
# Same config was applied to both, so this should not
|
|
||||||
# happen unless the segmentation depends on the input
|
|
||||||
# length in some unexpected way. Skip with a warning
|
|
||||||
# rather than crash the whole run.
|
|
||||||
continue
|
|
||||||
pair_count = min(len(primary_seg.segments), len(reference_seg.segments))
|
|
||||||
for i in range(pair_count):
|
|
||||||
p_seg = primary_seg.segments[i]
|
|
||||||
r_seg = reference_seg.segments[i]
|
|
||||||
pairs.append(
|
|
||||||
(
|
|
||||||
f"{key}[{i}]",
|
|
||||||
primary_seq[p_seg.start : p_seg.end],
|
|
||||||
reference_seq[r_seg.start : r_seg.end],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pairs.append(("full_trial", primary_seq, reference_seq))
|
|
||||||
|
|
||||||
for label, primary_slice, reference_slice in pairs:
|
|
||||||
labels.append(label)
|
|
||||||
if stage.method == "dtw_all":
|
|
||||||
result = dtw_all(
|
|
||||||
primary_slice,
|
|
||||||
reference_slice,
|
|
||||||
align=stage.align,
|
|
||||||
representation=stage.representation,
|
|
||||||
angle_triplets=stage.angle_triplets,
|
|
||||||
nan_policy=stage.nan_policy,
|
|
||||||
)
|
|
||||||
distances.append(result.distance)
|
|
||||||
paths.append(result.path)
|
|
||||||
elif stage.method == "dtw_per_joint":
|
|
||||||
assert per_joint_distances is not None
|
|
||||||
per_joint_results = dtw_per_joint(
|
|
||||||
primary_slice,
|
|
||||||
reference_slice,
|
|
||||||
align=stage.align,
|
|
||||||
representation=stage.representation,
|
|
||||||
angle_triplets=stage.angle_triplets,
|
|
||||||
nan_policy=stage.nan_policy,
|
|
||||||
)
|
|
||||||
# "distance" for a per-joint run is the sum across units;
|
|
||||||
# "per_joint_distances" carries the full breakdown.
|
|
||||||
per_unit = [r.distance for r in per_joint_results]
|
|
||||||
distances.append(float(sum(per_unit)))
|
|
||||||
per_joint_distances.append(per_unit)
|
|
||||||
# Store just the first joint's path as a representative —
|
|
||||||
# per-joint paths are a list of equal length, but
|
|
||||||
# reporting all of them on disk is almost always overkill.
|
|
||||||
paths.append(per_joint_results[0].path if per_joint_results else [])
|
|
||||||
else: # "dtw_relation"
|
|
||||||
assert stage.joint_i is not None
|
|
||||||
assert stage.joint_j is not None
|
|
||||||
result = _invoke_dtw_relation(
|
|
||||||
primary_slice,
|
|
||||||
reference_slice,
|
|
||||||
joint_i=stage.joint_i,
|
|
||||||
joint_j=stage.joint_j,
|
|
||||||
align=stage.align,
|
|
||||||
nan_policy=stage.nan_policy,
|
|
||||||
)
|
|
||||||
distances.append(result.distance)
|
|
||||||
paths.append(result.path)
|
|
||||||
|
|
||||||
return DtwResults(
|
|
||||||
kind="dtw",
|
|
||||||
method=stage.method,
|
|
||||||
distances=distances,
|
|
||||||
paths=paths,
|
|
||||||
per_joint_distances=per_joint_distances,
|
|
||||||
segment_labels=labels,
|
|
||||||
summary=_summarize_distances(distances),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _invoke_dtw_relation(
|
|
||||||
primary_slice: np.ndarray,
|
|
||||||
reference_slice: np.ndarray,
|
|
||||||
*,
|
|
||||||
joint_i: int,
|
|
||||||
joint_j: int,
|
|
||||||
align: AlignMode,
|
|
||||||
nan_policy: NanPolicy,
|
|
||||||
) -> DTWResult:
|
|
||||||
"""Isolating thin wrapper so test fakes can replace the call site cleanly."""
|
|
||||||
return dtw_relation(
|
|
||||||
primary_slice,
|
|
||||||
reference_slice,
|
|
||||||
joint_i,
|
|
||||||
joint_j,
|
|
||||||
align=align,
|
|
||||||
nan_policy=nan_policy,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_stats(
|
|
||||||
stage: StatsAnalysis,
|
|
||||||
*,
|
|
||||||
primary_seq: np.ndarray,
|
|
||||||
primary_segmentations: dict[str, Segmentation],
|
|
||||||
) -> StatsResults:
|
|
||||||
"""Execute a stats analysis stage, returning :class:`StatsResults`."""
|
|
||||||
labels: list[str] = []
|
|
||||||
stats: list[FeatureSummary] = []
|
|
||||||
|
|
||||||
if primary_segmentations:
|
|
||||||
for key, seg in primary_segmentations.items():
|
|
||||||
for i, segment in enumerate(seg.segments):
|
|
||||||
labels.append(f"{key}[{i}]")
|
|
||||||
signal = extract_signal(
|
|
||||||
primary_seq[segment.start : segment.end],
|
|
||||||
stage.extractor,
|
|
||||||
)
|
|
||||||
stats.append(_feature_summary(signal))
|
|
||||||
else:
|
|
||||||
labels.append("full_trial")
|
|
||||||
signal = extract_signal(primary_seq, stage.extractor)
|
|
||||||
stats.append(_feature_summary(signal))
|
|
||||||
|
|
||||||
return StatsResults(kind="stats", statistics=stats, segment_labels=labels)
|
|
||||||
|
|
||||||
|
|
||||||
def _feature_summary(signal: np.ndarray) -> FeatureSummary:
|
|
||||||
"""Wrap :func:`extract_feature_statistics` output in a pydantic model."""
|
|
||||||
raw = extract_feature_statistics(signal)
|
|
||||||
return FeatureSummary(
|
|
||||||
mean=raw.mean,
|
|
||||||
std=raw.std,
|
|
||||||
min=raw.min,
|
|
||||||
max=raw.max,
|
|
||||||
range=raw.range,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _summarize_distances(distances: list[float]) -> dict[str, float]:
|
|
||||||
"""Compute mean / p50 / p95 / p99 of a distance list.
|
|
||||||
|
|
||||||
Empty inputs return an empty dict so the report's ``summary``
|
|
||||||
field still round-trips through JSON without special cases.
|
|
||||||
"""
|
|
||||||
if not distances:
|
|
||||||
return {}
|
|
||||||
arr = np.asarray(distances, dtype=float)
|
|
||||||
return {
|
|
||||||
"mean": float(arr.mean()),
|
|
||||||
"p50": float(np.percentile(arr, 50)),
|
|
||||||
"p95": float(np.percentile(arr, 95)),
|
|
||||||
"p99": float(np.percentile(arr, 99)),
|
|
||||||
}
|
|
||||||
|
|
@ -30,12 +30,6 @@ Three layers of API are provided, in increasing order of convenience:
|
||||||
:class:`~neuropose.io.ExtractorSpec`, converts time-based parameters
|
:class:`~neuropose.io.ExtractorSpec`, converts time-based parameters
|
||||||
to frame counts using ``metadata.fps``, and returns a full
|
to frame counts using ``metadata.fps``, and returns a full
|
||||||
:class:`~neuropose.io.Segmentation` ready to attach to the predictions.
|
:class:`~neuropose.io.Segmentation` ready to attach to the predictions.
|
||||||
- :func:`segment_gait_cycles` and :func:`segment_gait_cycles_bilateral`
|
|
||||||
— clinical convenience wrappers over :func:`segment_predictions`
|
|
||||||
that pre-fill a :func:`joint_axis` extractor with gait-appropriate
|
|
||||||
defaults (heel joint, Y axis, 0.4 s minimum cycle). The bilateral
|
|
||||||
variant returns both sides under ``"left_heel_strikes"`` and
|
|
||||||
``"right_heel_strikes"`` keys.
|
|
||||||
- :func:`slice_predictions` — split a :class:`~neuropose.io.VideoPredictions`
|
- :func:`slice_predictions` — split a :class:`~neuropose.io.VideoPredictions`
|
||||||
into one per-repetition :class:`~neuropose.io.VideoPredictions`,
|
into one per-repetition :class:`~neuropose.io.VideoPredictions`,
|
||||||
useful when downstream code wants per-rep objects rather than windows
|
useful when downstream code wants per-rep objects rather than windows
|
||||||
|
|
@ -72,7 +66,6 @@ installed; a clear :class:`ImportError` surfaces at the first call to
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -90,11 +83,6 @@ from neuropose.io import (
|
||||||
VideoPredictions,
|
VideoPredictions,
|
||||||
)
|
)
|
||||||
|
|
||||||
AxisLetter = Literal["x", "y", "z"]
|
|
||||||
"""Axis selector used by gait-cycle segmentation helpers."""
|
|
||||||
|
|
||||||
_AXIS_INDICES: dict[AxisLetter, int] = {"x": 0, "y": 1, "z": 2}
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# berkeley_mhad_43 joint names
|
# berkeley_mhad_43 joint names
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -534,156 +522,6 @@ def segment_predictions(
|
||||||
return Segmentation(config=config, segments=segments)
|
return Segmentation(config=config, segments=segments)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Gait-cycle segmentation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def segment_gait_cycles(
|
|
||||||
predictions: VideoPredictions,
|
|
||||||
*,
|
|
||||||
joint: str = "rhee",
|
|
||||||
axis: AxisLetter = "y",
|
|
||||||
invert: bool = False,
|
|
||||||
min_cycle_seconds: float = 0.4,
|
|
||||||
min_prominence: float | None = None,
|
|
||||||
) -> Segmentation:
|
|
||||||
"""Segment gait cycles from a single heel's vertical trace.
|
|
||||||
|
|
||||||
Runs valley-to-valley peak detection (the same engine used by
|
|
||||||
:func:`segment_predictions`) on the chosen joint's coordinate along
|
|
||||||
the chosen spatial axis. By default, each detected peak corresponds
|
|
||||||
to one heel-strike — the frame where the heel reaches its lowest
|
|
||||||
point on the Y-down MeTRAbs world-coordinate convention — and the
|
|
||||||
returned :class:`~neuropose.io.Segment` windows span one full gait
|
|
||||||
cycle from the preceding toe-off valley to the following toe-off
|
|
||||||
valley.
|
|
||||||
|
|
||||||
The function is a **thin wrapper** over :func:`segment_predictions`
|
|
||||||
with a :func:`joint_axis` extractor; it exists to give clinical
|
|
||||||
callers a gait-specific entry point with meaningful defaults
|
|
||||||
(``joint="rhee"``, ``axis="y"``, ``min_cycle_seconds=0.4``)
|
|
||||||
rather than forcing them to construct the extractor by hand.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
predictions
|
|
||||||
Per-video predictions to segment. ``metadata.fps`` is used to
|
|
||||||
translate ``min_cycle_seconds`` into a sample-count distance
|
|
||||||
threshold.
|
|
||||||
joint
|
|
||||||
Joint name in the berkeley_mhad_43 skeleton — typically
|
|
||||||
``"rhee"`` (right heel) or ``"lhee"`` (left heel). Resolved
|
|
||||||
via :func:`joint_index`.
|
|
||||||
axis
|
|
||||||
Spatial axis to track, as ``"x"``, ``"y"``, or ``"z"``. The
|
|
||||||
default ``"y"`` matches the vertical axis in MeTRAbs's output
|
|
||||||
(Y-down world coordinates).
|
|
||||||
invert
|
|
||||||
If ``True``, negate the extracted signal so that minima
|
|
||||||
become peaks. Needed when the recording convention makes a
|
|
||||||
heel-strike appear as a *decrease* in the chosen coordinate
|
|
||||||
— for example, a camera orientation where the vertical axis
|
|
||||||
runs bottom-to-top instead of MeTRAbs's default top-to-bottom.
|
|
||||||
min_cycle_seconds
|
|
||||||
Minimum gait-cycle duration. Used as scipy's
|
|
||||||
``find_peaks(distance=...)`` parameter after conversion to
|
|
||||||
frame count via ``metadata.fps``. Defaults to ``0.4`` seconds,
|
|
||||||
which rejects noise peaks on even the fastest human gaits
|
|
||||||
(~120 strides/min) while retaining every real cadence.
|
|
||||||
min_prominence
|
|
||||||
Forwarded to :func:`segment_by_peaks` to filter out shallow
|
|
||||||
local maxima that aren't real heel-strikes. In MeTRAbs units
|
|
||||||
(millimetres) a threshold of 20 to 50 mm is typical for
|
|
||||||
able-bodied gait; leave ``None`` to accept every peak scipy
|
|
||||||
identifies.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Segmentation
|
|
||||||
A :class:`~neuropose.io.Segmentation` paired with the full
|
|
||||||
:class:`~neuropose.io.SegmentationConfig` that produced it, so
|
|
||||||
the output is self-describing when persisted. The segments
|
|
||||||
list is **empty** rather than an exception when no peaks are
|
|
||||||
detected — a common outcome for shuffling gaits or
|
|
||||||
walker-assisted trials.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If ``joint`` is not a known berkeley_mhad_43 joint name.
|
|
||||||
ValueError
|
|
||||||
If ``axis`` is not one of ``"x"``, ``"y"``, ``"z"``, or if
|
|
||||||
``predictions`` has zero frames, or if ``metadata.fps`` is
|
|
||||||
non-positive.
|
|
||||||
ImportError
|
|
||||||
If :mod:`scipy` is not installed.
|
|
||||||
"""
|
|
||||||
if axis not in _AXIS_INDICES:
|
|
||||||
raise ValueError(f"axis must be one of 'x', 'y', 'z'; got {axis!r}")
|
|
||||||
joint_idx = joint_index(joint)
|
|
||||||
axis_idx = _AXIS_INDICES[axis]
|
|
||||||
extractor = joint_axis(joint_idx, axis_idx, invert=invert)
|
|
||||||
return segment_predictions(
|
|
||||||
predictions,
|
|
||||||
extractor,
|
|
||||||
min_distance_seconds=min_cycle_seconds,
|
|
||||||
min_prominence=min_prominence,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def segment_gait_cycles_bilateral(
|
|
||||||
predictions: VideoPredictions,
|
|
||||||
*,
|
|
||||||
axis: AxisLetter = "y",
|
|
||||||
invert: bool = False,
|
|
||||||
min_cycle_seconds: float = 0.4,
|
|
||||||
min_prominence: float | None = None,
|
|
||||||
) -> dict[str, Segmentation]:
|
|
||||||
"""Segment gait cycles for both heels.
|
|
||||||
|
|
||||||
Runs :func:`segment_gait_cycles` twice — once with ``joint="lhee"``
|
|
||||||
and once with ``joint="rhee"`` — and returns the two results under
|
|
||||||
the keys ``"left_heel_strikes"`` and ``"right_heel_strikes"``. The
|
|
||||||
returned dict is shape-compatible with
|
|
||||||
:class:`~neuropose.io.VideoPredictions.segmentations` so it can be
|
|
||||||
merged directly into a predictions object and persisted to
|
|
||||||
``results.json`` via the usual save path.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
predictions, axis, invert, min_cycle_seconds, min_prominence
|
|
||||||
Forwarded to :func:`segment_gait_cycles`; see that function's
|
|
||||||
docstring for details.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict[str, Segmentation]
|
|
||||||
Two-keyed mapping with the left and right heel segmentations
|
|
||||||
under ``"left_heel_strikes"`` and ``"right_heel_strikes"``.
|
|
||||||
Either side may carry an empty segments list if its heel's
|
|
||||||
trace contained no detectable strikes.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"left_heel_strikes": segment_gait_cycles(
|
|
||||||
predictions,
|
|
||||||
joint="lhee",
|
|
||||||
axis=axis,
|
|
||||||
invert=invert,
|
|
||||||
min_cycle_seconds=min_cycle_seconds,
|
|
||||||
min_prominence=min_prominence,
|
|
||||||
),
|
|
||||||
"right_heel_strikes": segment_gait_cycles(
|
|
||||||
predictions,
|
|
||||||
joint="rhee",
|
|
||||||
axis=axis,
|
|
||||||
invert=invert,
|
|
||||||
min_cycle_seconds=min_cycle_seconds,
|
|
||||||
min_prominence=min_prominence,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Slicing: one VideoPredictions per segment
|
# Slicing: one VideoPredictions per segment
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -105,17 +105,9 @@ def run_benchmark(
|
||||||
|
|
||||||
passes: list[PerformanceMetrics] = []
|
passes: list[PerformanceMetrics] = []
|
||||||
reference_predictions: VideoPredictions | None = None
|
reference_predictions: VideoPredictions | None = None
|
||||||
# Provenance is identical across every pass of a single run (same
|
|
||||||
# estimator, same model, same environment), so we keep just the
|
|
||||||
# latest one we see. Doing this on every iteration is cheap — it's
|
|
||||||
# one attribute read — and means the benchmark result carries
|
|
||||||
# provenance even when ``capture_reference`` is off.
|
|
||||||
latest_provenance = None
|
|
||||||
for i in range(repeats):
|
for i in range(repeats):
|
||||||
result = estimator.process_video(video_path)
|
result = estimator.process_video(video_path)
|
||||||
passes.append(result.metrics)
|
passes.append(result.metrics)
|
||||||
if result.predictions.provenance is not None:
|
|
||||||
latest_provenance = result.predictions.provenance
|
|
||||||
# Only the *last* measured pass needs to be captured for
|
# Only the *last* measured pass needs to be captured for
|
||||||
# divergence comparison. Earlier passes would just be
|
# divergence comparison. Earlier passes would just be
|
||||||
# overwritten, so we avoid holding their frame dicts in memory.
|
# overwritten, so we avoid holding their frame dicts in memory.
|
||||||
|
|
@ -130,7 +122,6 @@ def run_benchmark(
|
||||||
warmup_pass=passes[0],
|
warmup_pass=passes[0],
|
||||||
measured_passes=passes[1:],
|
measured_passes=passes[1:],
|
||||||
aggregate=aggregate,
|
aggregate=aggregate,
|
||||||
provenance=latest_provenance,
|
|
||||||
)
|
)
|
||||||
return BenchmarkRunOutcome(
|
return BenchmarkRunOutcome(
|
||||||
result=benchmark_result,
|
result=benchmark_result,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""NeuroPose command-line interface.
|
"""NeuroPose command-line interface.
|
||||||
|
|
||||||
Eight subcommands:
|
Seven subcommands:
|
||||||
|
|
||||||
- ``neuropose watch`` — run the :class:`~neuropose.interfacer.Interfacer`
|
- ``neuropose watch`` — run the :class:`~neuropose.interfacer.Interfacer`
|
||||||
daemon against the configured input directory.
|
daemon against the configured input directory.
|
||||||
|
|
@ -12,10 +12,6 @@ Eight subcommands:
|
||||||
- ``neuropose serve`` — start the :mod:`~neuropose.monitor` localhost
|
- ``neuropose serve`` — start the :mod:`~neuropose.monitor` localhost
|
||||||
HTTP dashboard so collaborators can watch a run's progress in a
|
HTTP dashboard so collaborators can watch a run's progress in a
|
||||||
browser or via ``curl``.
|
browser or via ``curl``.
|
||||||
- ``neuropose reset`` — stop the daemon and monitor, then wipe pipeline
|
|
||||||
state (input queue, results, status file, lock file, ingest staging
|
|
||||||
dirs) for a clean restart. See :mod:`neuropose.reset` for the layered
|
|
||||||
implementation.
|
|
||||||
- ``neuropose segment <results>`` — post-hoc repetition segmentation of
|
- ``neuropose segment <results>`` — post-hoc repetition segmentation of
|
||||||
an existing predictions file. Attaches a named
|
an existing predictions file. Attaches a named
|
||||||
:class:`~neuropose.io.Segmentation` to every video it contains and
|
:class:`~neuropose.io.Segmentation` to every video it contains and
|
||||||
|
|
@ -25,11 +21,8 @@ Eight subcommands:
|
||||||
vs CPU numerical-divergence checks. Prints a human report to stdout
|
vs CPU numerical-divergence checks. Prints a human report to stdout
|
||||||
and (optionally) writes a structured :class:`~neuropose.io.BenchmarkResult`
|
and (optionally) writes a structured :class:`~neuropose.io.BenchmarkResult`
|
||||||
JSON to ``--output``.
|
JSON to ``--output``.
|
||||||
- ``neuropose analyze --config <yaml>`` — run the declarative analysis
|
- ``neuropose analyze <results>`` — stubbed placeholder pending the
|
||||||
pipeline described in a YAML config. Loads the named predictions
|
analyzer rewrite in commit 10.
|
||||||
files, applies segmentation + analysis, writes an
|
|
||||||
:class:`~neuropose.analyzer.pipeline.AnalysisReport` JSON. See
|
|
||||||
``examples/analysis/*.yaml`` for runnable references.
|
|
||||||
|
|
||||||
User-facing error handling
|
User-facing error handling
|
||||||
--------------------------
|
--------------------------
|
||||||
|
|
@ -60,7 +53,6 @@ from pathlib import Path
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
import yaml
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from neuropose import __version__
|
from neuropose import __version__
|
||||||
|
|
@ -434,151 +426,6 @@ def serve(
|
||||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
raise typer.Exit(code=EXIT_USAGE) from exc
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# reset
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def reset(
|
|
||||||
ctx: typer.Context,
|
|
||||||
yes: Annotated[
|
|
||||||
bool,
|
|
||||||
typer.Option(
|
|
||||||
"--yes",
|
|
||||||
"-y",
|
|
||||||
help="Skip the interactive confirmation prompt.",
|
|
||||||
),
|
|
||||||
] = False,
|
|
||||||
keep_failed: Annotated[
|
|
||||||
bool,
|
|
||||||
typer.Option(
|
|
||||||
"--keep-failed",
|
|
||||||
help=(
|
|
||||||
"Preserve $data_dir/failed/ for forensic review. By "
|
|
||||||
"default the failed-job quarantine is wiped along with "
|
|
||||||
"in/ and out/."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
] = False,
|
|
||||||
force_kill: Annotated[
|
|
||||||
bool,
|
|
||||||
typer.Option(
|
|
||||||
"--force-kill",
|
|
||||||
help=(
|
|
||||||
"Escalate to SIGKILL on any daemon or monitor still "
|
|
||||||
"alive after the SIGINT grace period. Necessary if the "
|
|
||||||
"daemon is mid-inference on a long video and you do "
|
|
||||||
"not want to wait for the current video to finish."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
] = False,
|
|
||||||
grace_seconds: Annotated[
|
|
||||||
float,
|
|
||||||
typer.Option(
|
|
||||||
"--grace-seconds",
|
|
||||||
min=0.0,
|
|
||||||
help=(
|
|
||||||
"Seconds to wait after SIGINT before declaring a "
|
|
||||||
"process a survivor (or escalating to SIGKILL when "
|
|
||||||
"--force-kill is set)."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
] = 10.0,
|
|
||||||
dry_run: Annotated[
|
|
||||||
bool,
|
|
||||||
typer.Option(
|
|
||||||
"--dry-run",
|
|
||||||
"-n",
|
|
||||||
help="Show what would be killed and removed without doing it.",
|
|
||||||
),
|
|
||||||
] = False,
|
|
||||||
) -> None:
|
|
||||||
"""Stop the daemon and monitor, then wipe pipeline state.
|
|
||||||
|
|
||||||
Discovers running ``neuropose watch`` and ``neuropose serve``
|
|
||||||
processes, sends SIGINT, waits ``--grace-seconds`` for graceful
|
|
||||||
shutdown (optionally escalating to SIGKILL with ``--force-kill``),
|
|
||||||
then removes the contents of ``$data_dir/in/``, ``$data_dir/out/``
|
|
||||||
(including ``status.json``), ``$data_dir/failed/`` (unless
|
|
||||||
``--keep-failed``), the daemon lock file, and any leftover
|
|
||||||
``.ingest_<uuid>/`` staging directories from interrupted ingests.
|
|
||||||
|
|
||||||
Refuses to wipe state if any process survives the termination
|
|
||||||
phase — wiping the data directory out from under an active daemon
|
|
||||||
would leave it writing into deleted directory entries. Re-run
|
|
||||||
with ``--force-kill`` or stop the survivor manually.
|
|
||||||
"""
|
|
||||||
# Deferred import so reset's psutil scan stays off the watch/process
|
|
||||||
# hot path. psutil is already a runtime dependency for benchmark
|
|
||||||
# metrics, so this import is free at install time.
|
|
||||||
from neuropose.reset import find_neuropose_processes, reset_pipeline, wipe_state
|
|
||||||
|
|
||||||
settings: Settings = ctx.obj
|
|
||||||
|
|
||||||
discovered = find_neuropose_processes()
|
|
||||||
preview = wipe_state(settings, keep_failed=keep_failed, dry_run=True)
|
|
||||||
|
|
||||||
typer.echo(f"data dir: {settings.data_dir}")
|
|
||||||
if discovered:
|
|
||||||
typer.echo(f"would stop: {len(discovered)} process(es)")
|
|
||||||
for rp in discovered:
|
|
||||||
typer.echo(f" pid {rp.pid:>7} {rp.role:<7} {rp.cmdline}")
|
|
||||||
else:
|
|
||||||
typer.echo("would stop: no daemon or monitor running")
|
|
||||||
if preview.removed_paths:
|
|
||||||
size_mb = preview.bytes_freed / (1024 * 1024)
|
|
||||||
typer.echo(f"would remove: {len(preview.removed_paths)} path(s) ({size_mb:.1f} MB)")
|
|
||||||
for path in preview.removed_paths:
|
|
||||||
typer.echo(f" {path}")
|
|
||||||
else:
|
|
||||||
typer.echo("would remove: nothing — data dir is already clean")
|
|
||||||
|
|
||||||
if dry_run:
|
|
||||||
typer.echo("(dry-run; no changes made)")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not discovered and not preview.removed_paths:
|
|
||||||
typer.echo("nothing to do.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not yes and not typer.confirm("\nproceed?"):
|
|
||||||
typer.echo("aborted.")
|
|
||||||
raise typer.Exit(code=EXIT_USAGE)
|
|
||||||
|
|
||||||
report = reset_pipeline(
|
|
||||||
settings,
|
|
||||||
grace_seconds=grace_seconds,
|
|
||||||
force_kill=force_kill,
|
|
||||||
keep_failed=keep_failed,
|
|
||||||
)
|
|
||||||
|
|
||||||
if report.termination.stopped:
|
|
||||||
typer.echo(f"stopped {len(report.termination.stopped)} process(es) via SIGINT")
|
|
||||||
if report.termination.force_killed:
|
|
||||||
typer.echo(
|
|
||||||
f"force-killed {len(report.termination.force_killed)} process(es) "
|
|
||||||
f"after {grace_seconds:.0f}s grace period"
|
|
||||||
)
|
|
||||||
if report.termination.survivors:
|
|
||||||
typer.echo(
|
|
||||||
f"error: {len(report.termination.survivors)} process(es) did not exit:",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
for rp in report.termination.survivors:
|
|
||||||
typer.echo(f" pid {rp.pid} ({rp.role})", err=True)
|
|
||||||
if report.wipe_skipped_due_to_survivors:
|
|
||||||
typer.echo(
|
|
||||||
" state on disk was NOT wiped — re-run with --force-kill, "
|
|
||||||
"or stop these processes manually first.",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
raise typer.Exit(code=EXIT_USAGE)
|
|
||||||
|
|
||||||
size_mb = report.wipe.bytes_freed / (1024 * 1024)
|
|
||||||
typer.echo(f"removed {len(report.wipe.removed_paths)} path(s) ({size_mb:.1f} MB freed)")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# segment
|
# segment
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -1200,94 +1047,26 @@ def benchmark(
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# analyze
|
# analyze (stub)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def analyze(
|
def analyze(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
config: Annotated[
|
results: Annotated[
|
||||||
Path,
|
Path,
|
||||||
typer.Option(
|
typer.Argument(help="Path to a results.json produced by watch or process."),
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help=(
|
|
||||||
"Path to a YAML AnalysisConfig file. See examples/analysis/ "
|
|
||||||
"for runnable references."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
output: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--output",
|
|
||||||
"-o",
|
|
||||||
help=(
|
|
||||||
"Override the report path declared in the config's "
|
|
||||||
"output.report field. Useful when running the same config "
|
|
||||||
"against multiple input pairs from a shell loop."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run the declarative analysis pipeline described by a YAML config.
|
"""Run the analyzer subpackage against a results.json (pending commit 10)."""
|
||||||
|
del ctx, results
|
||||||
Loads the config, parses it through
|
typer.echo(
|
||||||
:class:`~neuropose.analyzer.pipeline.AnalysisConfig` (so typos fail
|
"error: the analyzer subpackage is pending commit 10. "
|
||||||
immediately with a clear error), executes the pipeline via
|
"Until it lands, use neuropose.io to load results.json from Python.",
|
||||||
:func:`~neuropose.analyzer.pipeline.run_analysis`, and writes the
|
err=True,
|
||||||
resulting :class:`~neuropose.analyzer.pipeline.AnalysisReport` to
|
)
|
||||||
``--output`` (or to ``output.report`` declared in the config).
|
raise typer.Exit(code=EXIT_PENDING)
|
||||||
|
|
||||||
Cross-field invariants (for example,
|
|
||||||
``method='dtw_relation'`` requires ``joint_i`` / ``joint_j``) are
|
|
||||||
enforced at parse time, so a typo fails before any predictions
|
|
||||||
are loaded.
|
|
||||||
"""
|
|
||||||
del ctx
|
|
||||||
# Deferred import keeps the CLI module's top-level imports free of
|
|
||||||
# pipeline dependencies so ``watch`` / ``process`` startup stays
|
|
||||||
# cheap.
|
|
||||||
from neuropose.analyzer.pipeline import load_config, run_analysis, save_report
|
|
||||||
|
|
||||||
if not config.exists():
|
|
||||||
typer.echo(f"error: config file not found: {config}", err=True)
|
|
||||||
raise typer.Exit(code=EXIT_USAGE)
|
|
||||||
|
|
||||||
try:
|
|
||||||
analysis_config = load_config(config)
|
|
||||||
except ValidationError as exc:
|
|
||||||
typer.echo(f"error: invalid config {config}:\n{exc}", err=True)
|
|
||||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
|
||||||
except yaml.YAMLError as exc:
|
|
||||||
typer.echo(f"error: could not parse YAML {config}: {exc}", err=True)
|
|
||||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
|
||||||
|
|
||||||
report_path = output if output is not None else analysis_config.output.report
|
|
||||||
|
|
||||||
try:
|
|
||||||
report = run_analysis(analysis_config)
|
|
||||||
except (FileNotFoundError, ValueError) as exc:
|
|
||||||
typer.echo(f"error: analysis failed: {exc}", err=True)
|
|
||||||
raise typer.Exit(code=EXIT_USAGE) from exc
|
|
||||||
|
|
||||||
save_report(report_path, report)
|
|
||||||
|
|
||||||
typer.echo(f"wrote analysis report to {report_path}")
|
|
||||||
if report.segmentations:
|
|
||||||
seg_summary = ", ".join(
|
|
||||||
f"{name}={len(seg.segments)}" for name, seg in report.segmentations.items()
|
|
||||||
)
|
|
||||||
typer.echo(f"segmentations: {seg_summary}")
|
|
||||||
# Emit a one-line summary of the results regardless of kind.
|
|
||||||
typer.echo(f"analysis kind: {report.results.kind}")
|
|
||||||
if report.results.kind == "dtw":
|
|
||||||
n = len(report.results.distances)
|
|
||||||
mean = report.results.summary.get("mean", float("nan"))
|
|
||||||
typer.echo(f"distances computed: {n} (mean={mean:.4f})")
|
|
||||||
elif report.results.kind == "stats":
|
|
||||||
typer.echo(f"statistic blocks computed: {len(report.results.statistics)}")
|
|
||||||
|
|
||||||
|
|
||||||
def run() -> None:
|
def run() -> None:
|
||||||
|
|
|
||||||
|
|
@ -34,25 +34,19 @@ model is present raises :class:`ModelNotLoadedError`.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from importlib.metadata import PackageNotFoundError
|
|
||||||
from importlib.metadata import version as _pkg_version
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
from neuropose import __version__ as _neuropose_version
|
|
||||||
from neuropose._model import load_metrabs_model
|
from neuropose._model import load_metrabs_model
|
||||||
from neuropose.io import (
|
from neuropose.io import (
|
||||||
FramePrediction,
|
FramePrediction,
|
||||||
PerformanceMetrics,
|
PerformanceMetrics,
|
||||||
Provenance,
|
|
||||||
VideoMetadata,
|
VideoMetadata,
|
||||||
VideoPredictions,
|
VideoPredictions,
|
||||||
)
|
)
|
||||||
|
|
@ -164,12 +158,6 @@ class Estimator:
|
||||||
# successful ``load_model`` below so the next ``process_video`` can
|
# successful ``load_model`` below so the next ``process_video`` can
|
||||||
# pass the real number through into ``PerformanceMetrics``.
|
# pass the real number through into ``PerformanceMetrics``.
|
||||||
self._model_load_seconds: float | None = None
|
self._model_load_seconds: float | None = None
|
||||||
# MeTRAbs artifact identity, set only by ``load_model``. When the
|
|
||||||
# model was injected via the constructor we have no way to
|
|
||||||
# fingerprint it, so these remain ``None`` and ``process_video``
|
|
||||||
# leaves the output's ``provenance`` as ``None`` too.
|
|
||||||
self._model_sha256: str | None = None
|
|
||||||
self._model_filename: str | None = None
|
|
||||||
|
|
||||||
# -- model lifecycle ----------------------------------------------------
|
# -- model lifecycle ----------------------------------------------------
|
||||||
|
|
||||||
|
|
@ -188,21 +176,6 @@ class Estimator:
|
||||||
"""Return ``True`` if a model has been supplied or loaded."""
|
"""Return ``True`` if a model has been supplied or loaded."""
|
||||||
return self._model is not None
|
return self._model is not None
|
||||||
|
|
||||||
@property
|
|
||||||
def model_sha256(self) -> str | None:
|
|
||||||
"""Return the SHA-256 of the loaded MeTRAbs artifact, or ``None``.
|
|
||||||
|
|
||||||
``None`` when the model was injected via ``Estimator(model=...)``
|
|
||||||
rather than loaded via :meth:`load_model`. The value, when
|
|
||||||
present, is the module-pinned SHA from :mod:`neuropose._model`.
|
|
||||||
"""
|
|
||||||
return self._model_sha256
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_filename(self) -> str | None:
|
|
||||||
"""Return the basename of the MeTRAbs artifact, or ``None`` if injected."""
|
|
||||||
return self._model_filename
|
|
||||||
|
|
||||||
def load_model(self, cache_dir: Path | None = None) -> None:
|
def load_model(self, cache_dir: Path | None = None) -> None:
|
||||||
"""Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.
|
"""Load the MeTRAbs model via :func:`neuropose._model.load_metrabs_model`.
|
||||||
|
|
||||||
|
|
@ -223,16 +196,9 @@ class Estimator:
|
||||||
return
|
return
|
||||||
logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
|
logger.info("Loading MeTRAbs model (cache_dir=%s)", cache_dir)
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
loaded = load_metrabs_model(cache_dir=cache_dir)
|
self._model = load_metrabs_model(cache_dir=cache_dir)
|
||||||
self._model_load_seconds = time.perf_counter() - start
|
self._model_load_seconds = time.perf_counter() - start
|
||||||
self._model = loaded.model
|
logger.info("MeTRAbs model loaded in %.2f s", self._model_load_seconds)
|
||||||
self._model_sha256 = loaded.sha256
|
|
||||||
self._model_filename = loaded.filename
|
|
||||||
logger.info(
|
|
||||||
"MeTRAbs model loaded in %.2f s (sha256=%s)",
|
|
||||||
self._model_load_seconds,
|
|
||||||
loaded.sha256[:12],
|
|
||||||
)
|
|
||||||
|
|
||||||
# -- inference ----------------------------------------------------------
|
# -- inference ----------------------------------------------------------
|
||||||
|
|
||||||
|
|
@ -364,53 +330,11 @@ class Estimator:
|
||||||
metrics.active_device,
|
metrics.active_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
provenance = self._build_provenance(device_info=device_info)
|
predictions = VideoPredictions(metadata=metadata, frames=frames)
|
||||||
predictions = VideoPredictions(
|
|
||||||
metadata=metadata,
|
|
||||||
frames=frames,
|
|
||||||
provenance=provenance,
|
|
||||||
)
|
|
||||||
return ProcessVideoResult(predictions=predictions, metrics=metrics)
|
return ProcessVideoResult(predictions=predictions, metrics=metrics)
|
||||||
|
|
||||||
# -- internals ----------------------------------------------------------
|
# -- internals ----------------------------------------------------------
|
||||||
|
|
||||||
def _build_provenance(self, *, device_info: _ActiveDeviceInfo) -> Provenance | None:
|
|
||||||
"""Construct a :class:`~neuropose.io.Provenance` for the current run.
|
|
||||||
|
|
||||||
Returns ``None`` when the model was injected via the constructor
|
|
||||||
rather than loaded via :meth:`load_model` — in that case we
|
|
||||||
cannot fingerprint the artifact, and a partial provenance would
|
|
||||||
mislead readers into thinking we could.
|
|
||||||
|
|
||||||
The device-info bundle is shared with the :class:`PerformanceMetrics`
|
|
||||||
construction (one call to :func:`_detect_active_device` per
|
|
||||||
``process_video`` invocation) so that both artifacts see
|
|
||||||
identical TF and Metal state.
|
|
||||||
"""
|
|
||||||
if self._model_sha256 is None or self._model_filename is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
metal_version: str | None = None
|
|
||||||
if device_info.metal_active:
|
|
||||||
try:
|
|
||||||
metal_version = _pkg_version("tensorflow-metal")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
metal_version = None
|
|
||||||
|
|
||||||
python_version = (
|
|
||||||
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return Provenance(
|
|
||||||
model_sha256=self._model_sha256,
|
|
||||||
model_filename=self._model_filename,
|
|
||||||
tensorflow_version=device_info.tf_version,
|
|
||||||
tensorflow_metal_version=metal_version,
|
|
||||||
numpy_version=np.__version__,
|
|
||||||
neuropose_version=_neuropose_version,
|
|
||||||
python_version=python_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _infer_frame(
|
def _infer_frame(
|
||||||
self,
|
self,
|
||||||
model: Any,
|
model: Any,
|
||||||
|
|
|
||||||
|
|
@ -10,14 +10,6 @@ Atomicity: :func:`save_status`, :func:`save_job_results`, and
|
||||||
atomically rename, so a crash mid-write will not leave a partially-written
|
atomically rename, so a crash mid-write will not leave a partially-written
|
||||||
file behind. This matches the crash-resilience guarantee the interfacer
|
file behind. This matches the crash-resilience guarantee the interfacer
|
||||||
daemon makes to callers.
|
daemon makes to callers.
|
||||||
|
|
||||||
Schema versioning: :class:`VideoPredictions` and :class:`BenchmarkResult`
|
|
||||||
each carry a ``schema_version`` integer. On load, the raw JSON dict is
|
|
||||||
passed through :mod:`neuropose.migrations` before pydantic validation so
|
|
||||||
that files written by earlier versions upgrade transparently. :class:`JobResults`
|
|
||||||
is a ``RootModel`` with no envelope of its own, so its loader runs the
|
|
||||||
per-video migration on each entry of its mapping. See
|
|
||||||
:mod:`neuropose.migrations` for the migration-registration pattern.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -31,13 +23,6 @@ from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
|
||||||
|
|
||||||
from neuropose.migrations import (
|
|
||||||
CURRENT_VERSION,
|
|
||||||
migrate_benchmark_result,
|
|
||||||
migrate_job_results,
|
|
||||||
migrate_video_predictions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class JobStatus(StrEnum):
|
class JobStatus(StrEnum):
|
||||||
"""Lifecycle state of a single processing job."""
|
"""Lifecycle state of a single processing job."""
|
||||||
|
|
@ -172,104 +157,6 @@ class PerformanceMetrics(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Provenance(BaseModel):
|
|
||||||
"""Reproducibility-grade record of the environment that produced a payload.
|
|
||||||
|
|
||||||
Populated by the estimator on every inference run when the MeTRAbs
|
|
||||||
model was loaded through
|
|
||||||
:meth:`neuropose.estimator.Estimator.load_model` (the production
|
|
||||||
path). ``None`` when the model was injected directly via the
|
|
||||||
``Estimator(model=...)`` constructor (the test-fixture path), since
|
|
||||||
NeuroPose has no way to fingerprint a model it did not load itself.
|
|
||||||
|
|
||||||
Paper C's reproducibility story rests on this envelope: two runs
|
|
||||||
that produced equal ``Provenance`` objects against the same input
|
|
||||||
are expected to produce equal output (modulo non-determinism
|
|
||||||
controlled by ``deterministic``). Reviewers who want to re-derive a
|
|
||||||
figure from raw video need exactly these fields.
|
|
||||||
|
|
||||||
Frozen so a captured ``Provenance`` cannot be mutated after it has
|
|
||||||
been attached to a result; this matches the invariant that
|
|
||||||
provenance is a property of the run, not of the reader.
|
|
||||||
|
|
||||||
``protected_namespaces=()`` silences pydantic's ``model_*`` field
|
|
||||||
warning — the ``model_sha256`` / ``model_filename`` names refer to
|
|
||||||
the MeTRAbs model artifact, not to pydantic's internal
|
|
||||||
``model_validate`` / ``model_dump`` namespace, so the collision is
|
|
||||||
cosmetic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True, protected_namespaces=())
|
|
||||||
|
|
||||||
model_sha256: str = Field(
|
|
||||||
description=(
|
|
||||||
"SHA-256 of the MeTRAbs model tarball (hex-encoded, lowercase). "
|
|
||||||
"Pinned at build time in :mod:`neuropose._model` and verified on "
|
|
||||||
"first download. Identifies the exact model weights used."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model_filename: str = Field(
|
|
||||||
description=(
|
|
||||||
"Canonical basename of the MeTRAbs tarball, e.g. "
|
|
||||||
"``metrabs_eff2l_y4_384px_800k_28ds.tar.gz``. Human-readable "
|
|
||||||
"companion to ``model_sha256``."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
tensorflow_version: str = Field(
|
|
||||||
description="Value of ``tensorflow.__version__`` at the time of the run.",
|
|
||||||
)
|
|
||||||
tensorflow_metal_version: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Version of the ``tensorflow-metal`` PyPI package when installed; "
|
|
||||||
"``None`` on platforms without Metal GPU acceleration."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
numpy_version: str = Field(
|
|
||||||
description="Value of ``numpy.__version__`` at the time of the run.",
|
|
||||||
)
|
|
||||||
neuropose_version: str = Field(
|
|
||||||
description="Value of ``neuropose.__version__`` at the time of the run.",
|
|
||||||
)
|
|
||||||
python_version: str = Field(
|
|
||||||
description=(
|
|
||||||
"Python version as ``MAJOR.MINOR.MICRO``, e.g. ``3.11.14``. The "
|
|
||||||
"full ``sys.version`` string is intentionally not captured; the "
|
|
||||||
"three-component form is stable across patch builds and avoids "
|
|
||||||
"embedding compiler and build-date metadata."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
seed: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Random seed used for the run if one was set, else ``None``. "
|
|
||||||
"MeTRAbs inference is deterministic on a given device up to "
|
|
||||||
"floating-point associativity, so seeding mostly matters for "
|
|
||||||
"downstream analysis that introduces randomness (bootstraps, "
|
|
||||||
"learned metrics)."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
deterministic: bool = Field(
|
|
||||||
default=False,
|
|
||||||
description=(
|
|
||||||
"``True`` if ``tf.config.experimental.enable_op_determinism()`` "
|
|
||||||
"was active during the run. Track 2 deterministic-inference "
|
|
||||||
"mode; the field exists in Phase 0 so payloads can record "
|
|
||||||
"whether the run *was* deterministic without requiring a "
|
|
||||||
"schema change when the toggle lands."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
analysis_config: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Parsed YAML dict if this payload was produced by ``neuropose "
|
|
||||||
"analyze --config <file>``. ``None`` for direct-library or "
|
|
||||||
"``neuropose watch`` invocations. Reserved for the Phase 0 "
|
|
||||||
"YAML-configurable analysis pipeline."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkAggregate(BaseModel):
|
class BenchmarkAggregate(BaseModel):
|
||||||
"""Distributional statistics aggregated across benchmark passes.
|
"""Distributional statistics aggregated across benchmark passes.
|
||||||
|
|
||||||
|
|
@ -368,16 +255,6 @@ class BenchmarkResult(BaseModel):
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
schema_version: int = Field(
|
|
||||||
default=CURRENT_VERSION,
|
|
||||||
ge=1,
|
|
||||||
description=(
|
|
||||||
"Schema version of this BenchmarkResult payload. Fresh writes "
|
|
||||||
"stamp :data:`neuropose.migrations.CURRENT_VERSION`; older files "
|
|
||||||
"are migrated on load via :mod:`neuropose.migrations` before "
|
|
||||||
"pydantic validation."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
video_name: str = Field(
|
video_name: str = Field(
|
||||||
description="Basename of the benchmarked video (no directory components).",
|
description="Basename of the benchmarked video (no directory components).",
|
||||||
)
|
)
|
||||||
|
|
@ -403,14 +280,6 @@ class BenchmarkResult(BaseModel):
|
||||||
)
|
)
|
||||||
aggregate: BenchmarkAggregate
|
aggregate: BenchmarkAggregate
|
||||||
cpu_comparison: CpuComparisonResult | None = None
|
cpu_comparison: CpuComparisonResult | None = None
|
||||||
provenance: Provenance | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Reproducibility envelope from the benchmark run. ``None`` on "
|
|
||||||
"tests where the model was injected directly via "
|
|
||||||
"``Estimator(model=...)``."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class JointAxisExtractor(BaseModel):
|
class JointAxisExtractor(BaseModel):
|
||||||
|
|
@ -600,30 +469,9 @@ class VideoPredictions(BaseModel):
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid", frozen=True)
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
schema_version: int = Field(
|
|
||||||
default=CURRENT_VERSION,
|
|
||||||
ge=1,
|
|
||||||
description=(
|
|
||||||
"Schema version of this VideoPredictions payload. Fresh writes "
|
|
||||||
"stamp :data:`neuropose.migrations.CURRENT_VERSION`; files written "
|
|
||||||
"by older NeuroPose versions are migrated to the current version "
|
|
||||||
"by :mod:`neuropose.migrations` before pydantic validation."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
metadata: VideoMetadata
|
metadata: VideoMetadata
|
||||||
frames: dict[str, FramePrediction]
|
frames: dict[str, FramePrediction]
|
||||||
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
|
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
|
||||||
provenance: Provenance | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Reproducibility envelope populated by the estimator on runs "
|
|
||||||
"where the MeTRAbs model was loaded via "
|
|
||||||
":meth:`neuropose.estimator.Estimator.load_model`. ``None`` on "
|
|
||||||
"test paths where the model was injected via "
|
|
||||||
"``Estimator(model=...)``, because no model SHA is known in "
|
|
||||||
"that case."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def frame_names(self) -> list[str]:
|
def frame_names(self) -> list[str]:
|
||||||
"""Return frame identifiers in insertion order."""
|
"""Return frame identifiers in insertion order."""
|
||||||
|
|
@ -775,16 +623,9 @@ class StatusFile(RootModel[dict[str, JobStatusEntry]]):
|
||||||
|
|
||||||
|
|
||||||
def load_video_predictions(path: Path) -> VideoPredictions:
|
def load_video_predictions(path: Path) -> VideoPredictions:
|
||||||
"""Load and validate a per-video predictions JSON file.
|
"""Load and validate a per-video predictions JSON file."""
|
||||||
|
|
||||||
Runs the payload through :func:`neuropose.migrations.migrate_video_predictions`
|
|
||||||
before pydantic validation so files written by older NeuroPose versions
|
|
||||||
upgrade to the current schema transparently.
|
|
||||||
"""
|
|
||||||
with path.open("r", encoding="utf-8") as f:
|
with path.open("r", encoding="utf-8") as f:
|
||||||
data: Any = json.load(f)
|
data: Any = json.load(f)
|
||||||
if isinstance(data, dict):
|
|
||||||
data = migrate_video_predictions(data)
|
|
||||||
return VideoPredictions.model_validate(data)
|
return VideoPredictions.model_validate(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -795,17 +636,9 @@ def save_video_predictions(path: Path, predictions: VideoPredictions) -> None:
|
||||||
|
|
||||||
|
|
||||||
def load_job_results(path: Path) -> JobResults:
|
def load_job_results(path: Path) -> JobResults:
|
||||||
"""Load and validate an aggregated per-job results JSON file.
|
"""Load and validate an aggregated per-job results JSON file."""
|
||||||
|
|
||||||
Runs each video's payload through
|
|
||||||
:func:`neuropose.migrations.migrate_video_predictions` before pydantic
|
|
||||||
validation. :class:`JobResults` is a ``RootModel`` with no envelope of
|
|
||||||
its own, so migration happens per-entry rather than at the top level.
|
|
||||||
"""
|
|
||||||
with path.open("r", encoding="utf-8") as f:
|
with path.open("r", encoding="utf-8") as f:
|
||||||
data: Any = json.load(f)
|
data: Any = json.load(f)
|
||||||
if isinstance(data, dict):
|
|
||||||
data = migrate_job_results(data)
|
|
||||||
return JobResults.model_validate(data)
|
return JobResults.model_validate(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -816,16 +649,9 @@ def save_job_results(path: Path, results: JobResults) -> None:
|
||||||
|
|
||||||
|
|
||||||
def load_benchmark_result(path: Path) -> BenchmarkResult:
|
def load_benchmark_result(path: Path) -> BenchmarkResult:
|
||||||
"""Load and validate a benchmark-result JSON file.
|
"""Load and validate a benchmark-result JSON file."""
|
||||||
|
|
||||||
Runs the payload through :func:`neuropose.migrations.migrate_benchmark_result`
|
|
||||||
before pydantic validation so files written by older NeuroPose versions
|
|
||||||
upgrade transparently.
|
|
||||||
"""
|
|
||||||
with path.open("r", encoding="utf-8") as f:
|
with path.open("r", encoding="utf-8") as f:
|
||||||
data: Any = json.load(f)
|
data: Any = json.load(f)
|
||||||
if isinstance(data, dict):
|
|
||||||
data = migrate_benchmark_result(data)
|
|
||||||
return BenchmarkResult.model_validate(data)
|
return BenchmarkResult.model_validate(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,318 +0,0 @@
|
||||||
"""Schema migration infrastructure for serialised NeuroPose payloads.
|
|
||||||
|
|
||||||
Every top-level JSON schema that NeuroPose persists to disk
|
|
||||||
(:class:`~neuropose.io.VideoPredictions`,
|
|
||||||
:class:`~neuropose.io.JobResults`, and
|
|
||||||
:class:`~neuropose.io.BenchmarkResult`) carries a ``schema_version``
|
|
||||||
integer. When those files are read back, the raw dict is passed
|
|
||||||
through :func:`migrate_video_predictions` /
|
|
||||||
:func:`migrate_job_results` / :func:`migrate_benchmark_result`
|
|
||||||
*before* pydantic validation runs, so each schema version can be
|
|
||||||
brought up to the current one transparently.
|
|
||||||
|
|
||||||
The pattern is deliberately small: one integer version counter shared
|
|
||||||
across all top-level schemas, plus a per-schema registry of
|
|
||||||
``{from_version: migration_fn}``. Each migration is a pure function
|
|
||||||
``dict -> dict`` responsible for stamping the new ``schema_version``
|
|
||||||
on its output. The framework chains them.
|
|
||||||
|
|
||||||
This module is intentionally separate from :mod:`neuropose.io` so
|
|
||||||
that migration registrations cannot accidentally import the pydantic
|
|
||||||
models they migrate — migrations must operate on raw dicts to be
|
|
||||||
robust to schema drift (a field a migration references may not exist
|
|
||||||
on the pydantic model by the time CURRENT_VERSION has moved past
|
|
||||||
it).
|
|
||||||
|
|
||||||
Adding a new migration
|
|
||||||
----------------------
|
|
||||||
When a schema change lands:
|
|
||||||
|
|
||||||
1. Bump :data:`CURRENT_VERSION`.
|
|
||||||
2. Register a migration from the *previous* version to the new one
|
|
||||||
via :func:`register_video_predictions_migration` (or the sibling
|
|
||||||
for benchmark results). The function receives the raw dict at the
|
|
||||||
old version and must return a dict at the new version *including*
|
|
||||||
the updated ``schema_version`` stamp.
|
|
||||||
3. Update the pydantic model in :mod:`neuropose.io` to reflect the
|
|
||||||
new field set.
|
|
||||||
4. Add a unit test verifying that a fixture at the old version
|
|
||||||
round-trips through ``load_*`` to the expected new-version shape.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
CURRENT_VERSION = 2
|
|
||||||
"""The current schema version for all NeuroPose-persisted JSON payloads.
|
|
||||||
|
|
||||||
Shared across :class:`~neuropose.io.VideoPredictions`,
|
|
||||||
:class:`~neuropose.io.JobResults`, and
|
|
||||||
:class:`~neuropose.io.BenchmarkResult` so that coordinated schema
|
|
||||||
changes (for example, adding a ``provenance`` field to all three at
|
|
||||||
once) bump a single counter rather than three parallel ones.
|
|
||||||
|
|
||||||
Version history
|
|
||||||
---------------
|
|
||||||
- **v1:** initial schema, pre-Phase-0.
|
|
||||||
- **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions`
|
|
||||||
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope).
|
|
||||||
:class:`~neuropose.analyzer.pipeline.AnalysisReport` also enters the registry at v2
|
|
||||||
(no legacy v1 payloads ever existed for it, so no migration is registered)."""
|
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(Exception):
|
|
||||||
"""Base class for schema-migration failures."""
|
|
||||||
|
|
||||||
|
|
||||||
class FutureSchemaError(MigrationError):
|
|
||||||
"""Raised when a payload's ``schema_version`` exceeds :data:`CURRENT_VERSION`.
|
|
||||||
|
|
||||||
Produced when a newer NeuroPose version writes a file and an older
|
|
||||||
version tries to read it. The fix is upgrading NeuroPose; silently
|
|
||||||
stripping fields would corrupt the payload.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class MigrationNotFoundError(MigrationError):
|
|
||||||
"""Raised when no migration is registered for an intermediate version.
|
|
||||||
|
|
||||||
Indicates a bug — :data:`CURRENT_VERSION` was bumped past a version
|
|
||||||
for which no migration function was registered. Should only surface
|
|
||||||
in tests or on a corrupted install.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# Per-schema registries. Keys are the *source* version of the migration;
|
|
||||||
# the value is a callable that takes a dict at that version and returns a
|
|
||||||
# dict at ``source + 1``.
|
|
||||||
_VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
|
||||||
_BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
|
||||||
_ANALYSIS_REPORT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_video_predictions_migration(
|
|
||||||
from_version: int,
|
|
||||||
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
|
||||||
"""Register a :class:`~neuropose.io.VideoPredictions` migration.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
@register_video_predictions_migration(from_version=1)
|
|
||||||
def _v1_to_v2(payload: dict) -> dict:
|
|
||||||
payload = dict(payload)
|
|
||||||
payload["provenance"] = None
|
|
||||||
payload["schema_version"] = 2
|
|
||||||
return payload
|
|
||||||
|
|
||||||
The decorator registers the function into the per-schema migration
|
|
||||||
registry and returns it unchanged, so it can still be called
|
|
||||||
directly from tests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
|
||||||
if from_version in _VIDEO_PREDICTIONS_MIGRATIONS:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"video-predictions migration already registered from version {from_version}"
|
|
||||||
)
|
|
||||||
_VIDEO_PREDICTIONS_MIGRATIONS[from_version] = fn
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return wrap
|
|
||||||
|
|
||||||
|
|
||||||
def register_benchmark_result_migration(
|
|
||||||
from_version: int,
|
|
||||||
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
|
||||||
"""Register a :class:`~neuropose.io.BenchmarkResult` migration.
|
|
||||||
|
|
||||||
See :func:`register_video_predictions_migration` for usage — this
|
|
||||||
is the same pattern for the benchmark-result registry.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
|
||||||
if from_version in _BENCHMARK_RESULT_MIGRATIONS:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"benchmark-result migration already registered from version {from_version}"
|
|
||||||
)
|
|
||||||
_BENCHMARK_RESULT_MIGRATIONS[from_version] = fn
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return wrap
|
|
||||||
|
|
||||||
|
|
||||||
def register_analysis_report_migration(
|
|
||||||
from_version: int,
|
|
||||||
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
|
||||||
"""Register a :class:`~neuropose.analyzer.pipeline.AnalysisReport` migration.
|
|
||||||
|
|
||||||
See :func:`register_video_predictions_migration` for usage — this
|
|
||||||
is the same pattern for the analysis-report registry. Unlike the
|
|
||||||
other two schemas, :class:`AnalysisReport` first appeared at
|
|
||||||
:data:`CURRENT_VERSION = 2`, so no ``from_version=1`` migration
|
|
||||||
exists (and none is expected).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
|
||||||
if from_version in _ANALYSIS_REPORT_MIGRATIONS:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"analysis-report migration already registered from version {from_version}"
|
|
||||||
)
|
|
||||||
_ANALYSIS_REPORT_MIGRATIONS[from_version] = fn
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return wrap
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_video_predictions(payload: dict) -> dict:
|
|
||||||
"""Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
payload
|
|
||||||
Raw JSON-loaded dict. Must not yet have been through pydantic
|
|
||||||
validation. A missing ``schema_version`` key is interpreted as
|
|
||||||
version ``1`` (the earliest tracked version, shipped before
|
|
||||||
the migration infrastructure existed).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
The payload at :data:`CURRENT_VERSION`. Ready to be passed to
|
|
||||||
``VideoPredictions.model_validate``.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FutureSchemaError
|
|
||||||
If the payload declares a ``schema_version`` higher than
|
|
||||||
:data:`CURRENT_VERSION`.
|
|
||||||
MigrationNotFoundError
|
|
||||||
If an intermediate migration is missing from the registry.
|
|
||||||
"""
|
|
||||||
return _migrate(payload, _VIDEO_PREDICTIONS_MIGRATIONS, schema_name="VideoPredictions")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_benchmark_result(payload: dict) -> dict:
|
|
||||||
"""Migrate a raw :class:`~neuropose.io.BenchmarkResult` dict to current.
|
|
||||||
|
|
||||||
See :func:`migrate_video_predictions` for semantics. This is the
|
|
||||||
sibling function for benchmark-result payloads.
|
|
||||||
"""
|
|
||||||
return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_analysis_report(payload: dict) -> dict:
|
|
||||||
"""Migrate a raw :class:`~neuropose.analyzer.pipeline.AnalysisReport` dict.
|
|
||||||
|
|
||||||
See :func:`migrate_video_predictions` for semantics. Because
|
|
||||||
:class:`AnalysisReport` first shipped at schema_version 2, a
|
|
||||||
payload missing the key still defaults to 1 (and would require a
|
|
||||||
not-yet-registered v1 → v2 migration); this is only reachable for
|
|
||||||
deliberately malformed inputs.
|
|
||||||
"""
|
|
||||||
return _migrate(payload, _ANALYSIS_REPORT_MIGRATIONS, schema_name="AnalysisReport")
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_job_results(payload: dict) -> dict:
|
|
||||||
"""Migrate a :class:`~neuropose.io.JobResults` root dict to current.
|
|
||||||
|
|
||||||
``JobResults`` is a ``RootModel`` whose root is a mapping of video
|
|
||||||
name to :class:`~neuropose.io.VideoPredictions` payload. It has no
|
|
||||||
envelope of its own, so the migration is "run
|
|
||||||
:func:`migrate_video_predictions` on every value in the mapping."
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
payload
|
|
||||||
Raw JSON-loaded dict of ``{video_name: VideoPredictions-shaped dict}``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
The same mapping with each video payload migrated to the
|
|
||||||
current schema version.
|
|
||||||
"""
|
|
||||||
return {name: migrate_video_predictions(video) for name, video in payload.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def _migrate(
|
|
||||||
payload: dict,
|
|
||||||
migrations: dict[int, Callable[[dict], dict]],
|
|
||||||
*,
|
|
||||||
schema_name: str,
|
|
||||||
) -> dict:
|
|
||||||
"""Walk the migration chain to :data:`CURRENT_VERSION` and return the migrated payload.
|
|
||||||
|
|
||||||
Shared driver for :func:`migrate_video_predictions` and
|
|
||||||
:func:`migrate_benchmark_result`. Looks up the incoming
|
|
||||||
``schema_version`` (defaulting to 1 when absent), walks the migration
|
|
||||||
chain until reaching :data:`CURRENT_VERSION`, and returns the
|
|
||||||
migrated payload. Logs at INFO each time it actually advances a
|
|
||||||
version so operators see the upgrade happen.
|
|
||||||
"""
|
|
||||||
version = payload.get("schema_version", 1)
|
|
||||||
if not isinstance(version, int) or version < 1:
|
|
||||||
raise MigrationError(
|
|
||||||
f"{schema_name} payload has invalid schema_version {version!r}; must be an integer >= 1"
|
|
||||||
)
|
|
||||||
if version > CURRENT_VERSION:
|
|
||||||
raise FutureSchemaError(
|
|
||||||
f"{schema_name} payload declares schema_version {version}, which is newer "
|
|
||||||
f"than this build's CURRENT_VERSION ({CURRENT_VERSION}). Upgrade NeuroPose."
|
|
||||||
)
|
|
||||||
while version < CURRENT_VERSION:
|
|
||||||
if version not in migrations:
|
|
||||||
raise MigrationNotFoundError(
|
|
||||||
f"no {schema_name} migration registered from schema_version {version}"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Migrating %s payload from schema_version %d to %d",
|
|
||||||
schema_name,
|
|
||||||
version,
|
|
||||||
version + 1,
|
|
||||||
)
|
|
||||||
payload = migrations[version](payload)
|
|
||||||
version += 1
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Registered migrations
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
#
|
|
||||||
# Keep registrations *below* the driver so the module's public API surfaces
|
|
||||||
# at the top and the version-specific diffs live together at the bottom where
|
|
||||||
# they are easiest to audit chronologically.
|
|
||||||
|
|
||||||
|
|
||||||
@register_video_predictions_migration(from_version=1)
|
|
||||||
def _video_predictions_v1_to_v2(payload: dict) -> dict:
|
|
||||||
"""v1 → v2: add the optional ``provenance`` field (Phase 0).
|
|
||||||
|
|
||||||
Phase 0 introduces the :class:`~neuropose.io.Provenance` envelope
|
|
||||||
for Paper C reproducibility. v1 files predate it, so we stamp
|
|
||||||
``provenance = None`` on load — the field is optional on the
|
|
||||||
pydantic model and ``None`` correctly indicates "we don't have
|
|
||||||
provenance metadata for this payload."
|
|
||||||
"""
|
|
||||||
payload = dict(payload)
|
|
||||||
payload.setdefault("provenance", None)
|
|
||||||
payload["schema_version"] = 2
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
@register_benchmark_result_migration(from_version=1)
|
|
||||||
def _benchmark_result_v1_to_v2(payload: dict) -> dict:
|
|
||||||
"""v1 → v2: add the optional ``provenance`` field (Phase 0).
|
|
||||||
|
|
||||||
Sibling of :func:`_video_predictions_v1_to_v2` for benchmark
|
|
||||||
payloads; same rationale.
|
|
||||||
"""
|
|
||||||
payload = dict(payload)
|
|
||||||
payload.setdefault("provenance", None)
|
|
||||||
payload["schema_version"] = 2
|
|
||||||
return payload
|
|
||||||
|
|
@ -1,388 +0,0 @@
|
||||||
"""Pipeline-wide reset utility.
|
|
||||||
|
|
||||||
Tear down a running NeuroPose deployment back to a clean state: stop
|
|
||||||
the watch daemon and the monitor server, then wipe the job queue,
|
|
||||||
results, status file, and ingest staging directories. Intended for
|
|
||||||
the rapid iteration loop that comes up during benchmark and validation
|
|
||||||
work, where you want an empty ``$data_dir/in`` without manually
|
|
||||||
``rm -rf``-ing five separate paths and pkill'ing two processes.
|
|
||||||
|
|
||||||
The module is split into three independently-callable layers so each
|
|
||||||
piece is testable in isolation and reusable from non-CLI contexts:
|
|
||||||
|
|
||||||
- :func:`find_neuropose_processes` enumerates running ``neuropose
|
|
||||||
watch`` and ``neuropose serve`` processes by scanning the OS process
|
|
||||||
table. Pure read; no side effects.
|
|
||||||
- :func:`terminate_processes` signals the discovered processes (SIGINT
|
|
||||||
first, optionally SIGKILL after a grace period) and reports
|
|
||||||
survivors.
|
|
||||||
- :func:`wipe_state` removes the data-directory paths that the daemon
|
|
||||||
and monitor produce. Idempotent; safe against a fresh install.
|
|
||||||
|
|
||||||
The top-level :func:`reset_pipeline` orchestrates all three and
|
|
||||||
returns a :class:`ResetReport` summarizing what happened.
|
|
||||||
|
|
||||||
Safety
|
|
||||||
------
|
|
||||||
:func:`reset_pipeline` refuses to wipe state while *any* discovered
|
|
||||||
process is still alive after the termination phase. Wiping
|
|
||||||
``$data_dir`` out from under an active daemon would leave the daemon
|
|
||||||
writing into deleted directory entries — a guaranteed mess. Callers
|
|
||||||
that hit a survivor must either raise the grace period, opt into
|
|
||||||
``force_kill=True`` (SIGKILL), or kill the survivor manually before
|
|
||||||
re-running.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import signal
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
|
|
||||||
from neuropose.config import Settings
|
|
||||||
from neuropose.interfacer import LOCK_FILENAME
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Cmdline substrings that identify the daemon and monitor.
|
|
||||||
# Matched against the joined argv so ``uv run neuropose watch -v`` and
|
|
||||||
# ``python -m neuropose watch`` both hit. The substring choice
|
|
||||||
# deliberately includes the subcommand name so a generic
|
|
||||||
# ``neuropose --help`` shell doesn't get caught.
|
|
||||||
_DAEMON_MARKER = "neuropose watch"
|
|
||||||
_MONITOR_MARKER = "neuropose serve"
|
|
||||||
|
|
||||||
DEFAULT_GRACE_SECONDS = 10.0
|
|
||||||
"""How long :func:`terminate_processes` waits after SIGINT before
|
|
||||||
declaring a process a survivor (or escalating to SIGKILL when
|
|
||||||
``force_kill`` is set). Long enough for an idle daemon to finish its
|
|
||||||
current poll iteration; short enough that an interactive ``reset``
|
|
||||||
invocation doesn't feel hung. Override per-call when waiting on a
|
|
||||||
multi-minute inference."""
|
|
||||||
|
|
||||||
_POLL_INTERVAL_SECONDS = 0.2
|
|
||||||
|
|
||||||
ProcessRole = Literal["daemon", "monitor"]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class RunningProcess:
|
|
||||||
"""A neuropose process discovered in the OS process table."""
|
|
||||||
|
|
||||||
pid: int
|
|
||||||
role: ProcessRole
|
|
||||||
cmdline: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TerminationReport:
|
|
||||||
"""Outcome of trying to stop a set of running processes."""
|
|
||||||
|
|
||||||
stopped: list[RunningProcess] = field(default_factory=list)
|
|
||||||
survivors: list[RunningProcess] = field(default_factory=list)
|
|
||||||
force_killed: list[RunningProcess] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WipeReport:
|
|
||||||
"""Outcome of wiping data-directory state."""
|
|
||||||
|
|
||||||
removed_paths: list[Path] = field(default_factory=list)
|
|
||||||
bytes_freed: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ResetReport:
|
|
||||||
"""Aggregate report from a full pipeline reset."""
|
|
||||||
|
|
||||||
discovered: list[RunningProcess]
|
|
||||||
termination: TerminationReport
|
|
||||||
wipe: WipeReport
|
|
||||||
dry_run: bool
|
|
||||||
wipe_skipped_due_to_survivors: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def find_neuropose_processes(*, exclude_self: bool = True) -> list[RunningProcess]:
|
|
||||||
"""Scan the process table for ``neuropose watch`` / ``neuropose serve``.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
exclude_self
|
|
||||||
Skip the current process. The ``neuropose reset`` command
|
|
||||||
itself has ``"neuropose"`` in its argv and would otherwise see
|
|
||||||
itself in the result. Set to ``False`` only in tests where the
|
|
||||||
caller has constructed a process table that should match
|
|
||||||
verbatim.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[RunningProcess]
|
|
||||||
Processes whose joined argv contains either marker substring.
|
|
||||||
Empty list when nothing matches.
|
|
||||||
"""
|
|
||||||
self_pid = os.getpid()
|
|
||||||
found: list[RunningProcess] = []
|
|
||||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
|
||||||
try:
|
|
||||||
pid = proc.info["pid"]
|
|
||||||
cmdline = proc.info["cmdline"] or []
|
|
||||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
||||||
continue
|
|
||||||
if exclude_self and pid == self_pid:
|
|
||||||
continue
|
|
||||||
joined = " ".join(cmdline)
|
|
||||||
# Daemon check first because "neuropose serve" and
|
|
||||||
# "neuropose watch" cannot both appear in a single process.
|
|
||||||
if _DAEMON_MARKER in joined:
|
|
||||||
found.append(RunningProcess(pid=pid, role="daemon", cmdline=joined))
|
|
||||||
elif _MONITOR_MARKER in joined:
|
|
||||||
found.append(RunningProcess(pid=pid, role="monitor", cmdline=joined))
|
|
||||||
return found
|
|
||||||
|
|
||||||
|
|
||||||
def terminate_processes(
|
|
||||||
processes: list[RunningProcess],
|
|
||||||
*,
|
|
||||||
grace_seconds: float = DEFAULT_GRACE_SECONDS,
|
|
||||||
force_kill: bool = False,
|
|
||||||
) -> TerminationReport:
|
|
||||||
"""Send SIGINT to each process; optionally escalate to SIGKILL.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
processes
|
|
||||||
The processes to stop. Pass an empty list for a no-op.
|
|
||||||
grace_seconds
|
|
||||||
Maximum time to wait for processes to exit after SIGINT.
|
|
||||||
force_kill
|
|
||||||
When ``True``, any process still alive after ``grace_seconds``
|
|
||||||
is sent SIGKILL. When ``False``, survivors are reported back
|
|
||||||
to the caller untouched.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
SIGTERM is *not* used as an intermediate escalation step. The
|
|
||||||
interfacer's signal handler treats SIGINT and SIGTERM identically
|
|
||||||
(both call :meth:`Interfacer.stop`), so SIGTERM accomplishes
|
|
||||||
nothing that SIGINT did not already attempt. The only escalation
|
|
||||||
that actually forces a stuck daemon to exit is SIGKILL, which
|
|
||||||
bypasses the handler entirely.
|
|
||||||
"""
|
|
||||||
report = TerminationReport()
|
|
||||||
if not processes:
|
|
||||||
return report
|
|
||||||
|
|
||||||
for rp in processes:
|
|
||||||
with contextlib.suppress(ProcessLookupError, PermissionError):
|
|
||||||
os.kill(rp.pid, signal.SIGINT)
|
|
||||||
logger.info("sent SIGINT to pid %d (%s)", rp.pid, rp.role)
|
|
||||||
|
|
||||||
survivors = _wait_for_exit(processes, grace_seconds)
|
|
||||||
stopped = [p for p in processes if p not in survivors]
|
|
||||||
report.stopped.extend(stopped)
|
|
||||||
|
|
||||||
if not survivors:
|
|
||||||
return report
|
|
||||||
|
|
||||||
if not force_kill:
|
|
||||||
report.survivors.extend(survivors)
|
|
||||||
return report
|
|
||||||
|
|
||||||
for rp in survivors:
|
|
||||||
with contextlib.suppress(ProcessLookupError, PermissionError):
|
|
||||||
os.kill(rp.pid, signal.SIGKILL)
|
|
||||||
logger.warning("escalated to SIGKILL for pid %d (%s)", rp.pid, rp.role)
|
|
||||||
|
|
||||||
# SIGKILL is delivered synchronously enough that a short final
|
|
||||||
# poll is sufficient — any remaining "survivor" at this point is
|
|
||||||
# a permission error or a kernel-side hang, not graceful shutdown.
|
|
||||||
final_survivors = _wait_for_exit(survivors, grace_seconds=2.0)
|
|
||||||
killed = [p for p in survivors if p not in final_survivors]
|
|
||||||
report.force_killed.extend(killed)
|
|
||||||
report.survivors.extend(final_survivors)
|
|
||||||
return report
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_exit(
|
|
||||||
processes: list[RunningProcess],
|
|
||||||
grace_seconds: float,
|
|
||||||
) -> list[RunningProcess]:
|
|
||||||
"""Poll until every process exits or the deadline passes."""
|
|
||||||
deadline = time.monotonic() + grace_seconds
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
survivors = [p for p in processes if _is_alive(p.pid)]
|
|
||||||
if not survivors:
|
|
||||||
return []
|
|
||||||
time.sleep(_POLL_INTERVAL_SECONDS)
|
|
||||||
return [p for p in processes if _is_alive(p.pid)]
|
|
||||||
|
|
||||||
|
|
||||||
def _is_alive(pid: int) -> bool:
|
|
||||||
"""Return ``True`` if ``pid`` is still running and not a zombie."""
|
|
||||||
try:
|
|
||||||
proc = psutil.Process(pid)
|
|
||||||
except psutil.NoSuchProcess:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
return proc.status() != psutil.STATUS_ZOMBIE
|
|
||||||
except psutil.NoSuchProcess:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def wipe_state(
|
|
||||||
settings: Settings,
|
|
||||||
*,
|
|
||||||
keep_failed: bool = False,
|
|
||||||
dry_run: bool = False,
|
|
||||||
) -> WipeReport:
|
|
||||||
"""Remove data-directory paths produced by the daemon and monitor.
|
|
||||||
|
|
||||||
Removes the *contents* of ``in/``, ``out/``, and (unless
|
|
||||||
``keep_failed`` is set) ``failed/``, plus the daemon lock file and
|
|
||||||
any leftover ``.ingest_<uuid>/`` staging directories. The
|
|
||||||
container directories themselves are preserved so the daemon does
|
|
||||||
not need to recreate them on next startup.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
settings
|
|
||||||
Resolved :class:`~neuropose.config.Settings`. Determines all
|
|
||||||
target paths via the ``input_dir`` / ``output_dir`` /
|
|
||||||
``failed_dir`` properties.
|
|
||||||
keep_failed
|
|
||||||
Preserve ``$data_dir/failed/`` for forensic review of past
|
|
||||||
crashes. Default removes it along with the rest of the
|
|
||||||
pipeline state.
|
|
||||||
dry_run
|
|
||||||
Compute the report without actually deleting anything. Useful
|
|
||||||
for previewing the blast radius before confirming a reset.
|
|
||||||
"""
|
|
||||||
report = WipeReport()
|
|
||||||
|
|
||||||
targets: list[Path] = []
|
|
||||||
if settings.input_dir.exists():
|
|
||||||
targets.extend(settings.input_dir.iterdir())
|
|
||||||
if settings.output_dir.exists():
|
|
||||||
targets.extend(settings.output_dir.iterdir())
|
|
||||||
if not keep_failed and settings.failed_dir.exists():
|
|
||||||
targets.extend(settings.failed_dir.iterdir())
|
|
||||||
|
|
||||||
lock_path = settings.data_dir / LOCK_FILENAME
|
|
||||||
if lock_path.exists():
|
|
||||||
targets.append(lock_path)
|
|
||||||
|
|
||||||
if settings.data_dir.exists():
|
|
||||||
targets.extend(settings.data_dir.glob(".ingest_*"))
|
|
||||||
|
|
||||||
for target in targets:
|
|
||||||
size = _path_size(target)
|
|
||||||
if not dry_run:
|
|
||||||
_remove(target)
|
|
||||||
report.removed_paths.append(target)
|
|
||||||
report.bytes_freed += size
|
|
||||||
|
|
||||||
return report
|
|
||||||
|
|
||||||
|
|
||||||
def _path_size(path: Path) -> int:
|
|
||||||
"""Return the cumulative size of ``path``, recursing into directories."""
|
|
||||||
if path.is_symlink() or path.is_file():
|
|
||||||
try:
|
|
||||||
return path.stat().st_size
|
|
||||||
except OSError:
|
|
||||||
return 0
|
|
||||||
total = 0
|
|
||||||
for sub in path.rglob("*"):
|
|
||||||
try:
|
|
||||||
if sub.is_file() and not sub.is_symlink():
|
|
||||||
total += sub.stat().st_size
|
|
||||||
except OSError:
|
|
||||||
continue
|
|
||||||
return total
|
|
||||||
|
|
||||||
|
|
||||||
def _remove(path: Path) -> None:
|
|
||||||
"""Remove ``path`` whether file, symlink, or directory."""
|
|
||||||
if path.is_dir() and not path.is_symlink():
|
|
||||||
shutil.rmtree(path)
|
|
||||||
else:
|
|
||||||
path.unlink()
|
|
||||||
|
|
||||||
|
|
||||||
def reset_pipeline(
|
|
||||||
settings: Settings,
|
|
||||||
*,
|
|
||||||
grace_seconds: float = DEFAULT_GRACE_SECONDS,
|
|
||||||
force_kill: bool = False,
|
|
||||||
keep_failed: bool = False,
|
|
||||||
dry_run: bool = False,
|
|
||||||
) -> ResetReport:
|
|
||||||
"""Stop daemon + monitor, then wipe pipeline state.
|
|
||||||
|
|
||||||
Composes :func:`find_neuropose_processes`,
|
|
||||||
:func:`terminate_processes`, and :func:`wipe_state` into a single
|
|
||||||
operation. The wipe phase is *skipped* if any process survives
|
|
||||||
termination — wiping ``$data_dir`` out from under an active
|
|
||||||
daemon would corrupt its in-flight writes. The returned
|
|
||||||
:class:`ResetReport` flags this case via
|
|
||||||
``wipe_skipped_due_to_survivors``.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
settings
|
|
||||||
Resolved :class:`~neuropose.config.Settings`.
|
|
||||||
grace_seconds
|
|
||||||
Maximum time to wait for SIGINT to take effect.
|
|
||||||
force_kill
|
|
||||||
Escalate to SIGKILL on any process still alive after
|
|
||||||
``grace_seconds``. Necessary when the daemon is mid-inference
|
|
||||||
on a long video and you do not want to wait for the current
|
|
||||||
video to finish.
|
|
||||||
keep_failed
|
|
||||||
Preserve ``$data_dir/failed/`` during the wipe.
|
|
||||||
dry_run
|
|
||||||
Discover and report without killing anything or removing any
|
|
||||||
paths. Termination phase is skipped entirely.
|
|
||||||
"""
|
|
||||||
discovered = find_neuropose_processes()
|
|
||||||
|
|
||||||
if dry_run:
|
|
||||||
wipe = wipe_state(settings, keep_failed=keep_failed, dry_run=True)
|
|
||||||
return ResetReport(
|
|
||||||
discovered=discovered,
|
|
||||||
termination=TerminationReport(),
|
|
||||||
wipe=wipe,
|
|
||||||
dry_run=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
termination = terminate_processes(
|
|
||||||
discovered,
|
|
||||||
grace_seconds=grace_seconds,
|
|
||||||
force_kill=force_kill,
|
|
||||||
)
|
|
||||||
|
|
||||||
if termination.survivors:
|
|
||||||
return ResetReport(
|
|
||||||
discovered=discovered,
|
|
||||||
termination=termination,
|
|
||||||
wipe=WipeReport(),
|
|
||||||
dry_run=False,
|
|
||||||
wipe_skipped_due_to_survivors=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
wipe = wipe_state(settings, keep_failed=keep_failed, dry_run=False)
|
|
||||||
return ResetReport(
|
|
||||||
discovered=discovered,
|
|
||||||
termination=termination,
|
|
||||||
wipe=wipe,
|
|
||||||
dry_run=False,
|
|
||||||
)
|
|
||||||
|
|
@ -1,205 +0,0 @@
|
||||||
"""Example-config sanity integration tests.
|
|
||||||
|
|
||||||
Every YAML config under ``examples/analysis/`` is exercised here in
|
|
||||||
two ways:
|
|
||||||
|
|
||||||
1. ``test_<name>_parses`` — :func:`load_config` accepts the YAML,
|
|
||||||
i.e. the file matches the current :class:`AnalysisConfig` schema
|
|
||||||
including cross-field invariants. Catches silent drift between the
|
|
||||||
example configs and the schema they claim to exercise.
|
|
||||||
2. ``test_<name>_runs`` — the example's pipeline runs end-to-end
|
|
||||||
against synthetic predictions. Paths in the YAML are overwritten
|
|
||||||
with test fixtures before :func:`run_analysis` is invoked;
|
|
||||||
everything else (stages, thresholds, triplets) is used verbatim.
|
|
||||||
|
|
||||||
These tests are deliberately not marked ``slow`` — they use synthetic
|
|
||||||
fixtures and do not touch the MeTRAbs SavedModel, so they run in the
|
|
||||||
default unit-test suite.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import math
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from neuropose.analyzer.pipeline import (
|
|
||||||
AnalysisConfig,
|
|
||||||
AnalysisReport,
|
|
||||||
DtwResults,
|
|
||||||
load_config,
|
|
||||||
run_analysis,
|
|
||||||
)
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
from neuropose.io import VideoPredictions, save_video_predictions
|
|
||||||
|
|
||||||
EXAMPLES_DIR = Path(__file__).resolve().parents[2] / "examples" / "analysis"
|
|
||||||
|
|
||||||
NUM_JOINTS = 43
|
|
||||||
|
|
||||||
|
|
||||||
def _sinusoid(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
|
|
||||||
total = num_cycles * frames_per_cycle
|
|
||||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
|
||||||
return (np.sin(t) * amplitude + amplitude).astype(float)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_trial(
|
|
||||||
path: Path,
|
|
||||||
*,
|
|
||||||
num_cycles: int = 4,
|
|
||||||
frames_per_cycle: int = 30,
|
|
||||||
seed: int = 0,
|
|
||||||
) -> Path:
|
|
||||||
"""Write a synthetic VideoPredictions with every joint oscillating.
|
|
||||||
|
|
||||||
Joint ``0`` gets a reproducible RNG-driven trace so Procrustes has
|
|
||||||
something non-degenerate to align. All other joints get their own
|
|
||||||
phase-shifted sinusoid so joint-angle triplets and per-joint DTW
|
|
||||||
have signal to act on.
|
|
||||||
"""
|
|
||||||
rng = np.random.default_rng(seed)
|
|
||||||
base = _sinusoid(num_cycles, frames_per_cycle)
|
|
||||||
total = base.shape[0]
|
|
||||||
frames: dict[str, dict] = {}
|
|
||||||
for frame_idx in range(total):
|
|
||||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
|
||||||
for j in range(NUM_JOINTS):
|
|
||||||
# Unique per-joint position so no triplet is degenerate.
|
|
||||||
phase = rng.uniform(0.0, 2.0 * math.pi)
|
|
||||||
amplitude = 30.0 + 10.0 * (j % 5)
|
|
||||||
offset = float(j) * 15.0
|
|
||||||
poses[0][j][0] = offset + amplitude * math.cos(
|
|
||||||
2.0 * math.pi * frame_idx / frames_per_cycle + phase
|
|
||||||
)
|
|
||||||
poses[0][j][1] = offset * 0.5 + base[frame_idx] + 5.0 * j
|
|
||||||
poses[0][j][2] = 3.0 * j
|
|
||||||
frames[f"frame_{frame_idx:06d}"] = {
|
|
||||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
|
||||||
"poses3d": poses,
|
|
||||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
|
||||||
}
|
|
||||||
preds = VideoPredictions.model_validate(
|
|
||||||
{
|
|
||||||
"metadata": {
|
|
||||||
"frame_count": total,
|
|
||||||
"fps": float(frames_per_cycle),
|
|
||||||
"width": 640,
|
|
||||||
"height": 480,
|
|
||||||
},
|
|
||||||
"frames": frames,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
save_video_predictions(path, preds)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def example_fixtures(tmp_path: Path) -> tuple[Path, Path, Path]:
|
|
||||||
"""Return (primary_path, reference_path, report_path) under ``tmp_path``."""
|
|
||||||
primary = _write_trial(tmp_path / "primary.json", seed=1)
|
|
||||||
reference = _write_trial(tmp_path / "reference.json", seed=2)
|
|
||||||
report = tmp_path / "report.json"
|
|
||||||
return primary, reference, report
|
|
||||||
|
|
||||||
|
|
||||||
def _rewrite_paths(
|
|
||||||
example_path: Path, primary: Path, reference: Path, report: Path
|
|
||||||
) -> AnalysisConfig:
|
|
||||||
"""Load an example YAML and rewrite inputs/output paths to fixtures.
|
|
||||||
|
|
||||||
Tests run against synthetic predictions in ``tmp_path``; the
|
|
||||||
example YAML's hardcoded ``data/*.json`` paths would never resolve
|
|
||||||
otherwise.
|
|
||||||
"""
|
|
||||||
config = load_config(example_path)
|
|
||||||
update: dict = {
|
|
||||||
"inputs": config.inputs.model_copy(update={"primary": primary, "reference": reference}),
|
|
||||||
"output": config.output.model_copy(update={"report": report}),
|
|
||||||
}
|
|
||||||
return config.model_copy(update=update)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMinimalExample:
|
|
||||||
def test_minimal_parses(self) -> None:
|
|
||||||
config = load_config(EXAMPLES_DIR / "minimal.yaml")
|
|
||||||
assert isinstance(config, AnalysisConfig)
|
|
||||||
assert config.analysis.kind == "dtw"
|
|
||||||
assert config.segmentation is None
|
|
||||||
|
|
||||||
def test_minimal_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
|
|
||||||
primary, reference, report = example_fixtures
|
|
||||||
config = _rewrite_paths(EXAMPLES_DIR / "minimal.yaml", primary, reference, report)
|
|
||||||
result = run_analysis(config)
|
|
||||||
assert isinstance(result, AnalysisReport)
|
|
||||||
assert isinstance(result.results, DtwResults)
|
|
||||||
# Unsegmented → one distance.
|
|
||||||
assert result.results.segment_labels == ["full_trial"]
|
|
||||||
assert len(result.results.distances) == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestPaperCExample:
|
|
||||||
def test_paper_c_parses(self) -> None:
|
|
||||||
config = load_config(EXAMPLES_DIR / "paper_c_headline.yaml")
|
|
||||||
assert config.analysis.kind == "dtw"
|
|
||||||
assert config.segmentation is not None
|
|
||||||
assert config.segmentation.kind == "gait_cycles_bilateral"
|
|
||||||
# Joint triplets must be in range for berkeley_mhad_43.
|
|
||||||
assert config.analysis.kind == "dtw"
|
|
||||||
angle_triplets = config.analysis.angle_triplets # type: ignore[union-attr]
|
|
||||||
assert angle_triplets is not None
|
|
||||||
for a, b, c in angle_triplets:
|
|
||||||
for idx in (a, b, c):
|
|
||||||
assert 0 <= idx < NUM_JOINTS
|
|
||||||
|
|
||||||
def test_paper_c_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
|
|
||||||
primary, reference, report = example_fixtures
|
|
||||||
config = _rewrite_paths(EXAMPLES_DIR / "paper_c_headline.yaml", primary, reference, report)
|
|
||||||
result = run_analysis(config)
|
|
||||||
assert isinstance(result.results, DtwResults)
|
|
||||||
# Bilateral segmentation → distances labelled per side.
|
|
||||||
assert any(lbl.startswith("left_heel_strikes") for lbl in result.results.segment_labels)
|
|
||||||
assert any(lbl.startswith("right_heel_strikes") for lbl in result.results.segment_labels)
|
|
||||||
|
|
||||||
def test_paper_c_uses_documented_knee_triplets(self) -> None:
|
|
||||||
"""The Paper C config must target knee-flexion joint triplets.
|
|
||||||
|
|
||||||
Safety net: if someone edits the YAML and breaks the joint
|
|
||||||
references, this test catches it before the example silently
|
|
||||||
starts computing the wrong angles.
|
|
||||||
"""
|
|
||||||
config = load_config(EXAMPLES_DIR / "paper_c_headline.yaml")
|
|
||||||
assert config.analysis.kind == "dtw"
|
|
||||||
triplets = config.analysis.angle_triplets # type: ignore[union-attr]
|
|
||||||
assert triplets is not None
|
|
||||||
# Left knee flex = hip → knee → ankle.
|
|
||||||
assert (
|
|
||||||
JOINT_INDEX["lhipb"],
|
|
||||||
JOINT_INDEX["lkne"],
|
|
||||||
JOINT_INDEX["lank"],
|
|
||||||
) in triplets or (
|
|
||||||
JOINT_INDEX["lhipf"],
|
|
||||||
JOINT_INDEX["lkne"],
|
|
||||||
JOINT_INDEX["lank"],
|
|
||||||
) in triplets
|
|
||||||
|
|
||||||
|
|
||||||
class TestPerJointDebugExample:
|
|
||||||
def test_per_joint_debug_parses(self) -> None:
|
|
||||||
config = load_config(EXAMPLES_DIR / "per_joint_debug.yaml")
|
|
||||||
assert config.analysis.kind == "dtw"
|
|
||||||
assert config.analysis.method == "dtw_per_joint" # type: ignore[union-attr]
|
|
||||||
|
|
||||||
def test_per_joint_debug_runs(self, example_fixtures: tuple[Path, Path, Path]) -> None:
|
|
||||||
primary, reference, report = example_fixtures
|
|
||||||
config = _rewrite_paths(EXAMPLES_DIR / "per_joint_debug.yaml", primary, reference, report)
|
|
||||||
result = run_analysis(config)
|
|
||||||
assert isinstance(result.results, DtwResults)
|
|
||||||
# dtw_per_joint → per_joint_distances populated.
|
|
||||||
assert result.results.per_joint_distances is not None
|
|
||||||
# Inner length must match num_joints for the coords
|
|
||||||
# representation.
|
|
||||||
for per_seg in result.results.per_joint_distances:
|
|
||||||
assert len(per_seg) == NUM_JOINTS
|
|
||||||
|
|
@ -81,29 +81,26 @@ class TestMetrabsLoader:
|
||||||
"""Exercises the loader's download → verify → extract → load path."""
|
"""Exercises the loader's download → verify → extract → load path."""
|
||||||
|
|
||||||
def test_download_and_load(self, shared_model_cache_dir: Path) -> None:
|
def test_download_and_load(self, shared_model_cache_dir: Path) -> None:
|
||||||
loaded = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
model = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||||
assert loaded.model is not None
|
assert model is not None
|
||||||
assert loaded.sha256
|
|
||||||
assert loaded.filename
|
|
||||||
for attr in ("detect_poses", "per_skeleton_joint_names", "per_skeleton_joint_edges"):
|
for attr in ("detect_poses", "per_skeleton_joint_names", "per_skeleton_joint_edges"):
|
||||||
assert hasattr(loaded.model, attr), f"loaded model is missing {attr}"
|
assert hasattr(model, attr), f"loaded model is missing {attr}"
|
||||||
|
|
||||||
def test_second_call_uses_cache(self, shared_model_cache_dir: Path) -> None:
|
def test_second_call_uses_cache(self, shared_model_cache_dir: Path) -> None:
|
||||||
"""Idempotent: second call should return the cached model cheaply."""
|
"""Idempotent: second call should return the cached model cheaply."""
|
||||||
loaded_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
model_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||||
loaded_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
model_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||||
# tf.saved_model.load returns a new Python object each call, so
|
# tf.saved_model.load returns a new Python object each call, so
|
||||||
# identity comparison doesn't work — but both should still
|
# identity comparison doesn't work — but both should still
|
||||||
# expose the MeTRAbs interface, and the SHA should match.
|
# expose the MeTRAbs interface.
|
||||||
assert hasattr(loaded_a.model, "detect_poses")
|
assert hasattr(model_a, "detect_poses")
|
||||||
assert hasattr(loaded_b.model, "detect_poses")
|
assert hasattr(model_b, "detect_poses")
|
||||||
assert loaded_a.sha256 == loaded_b.sha256
|
|
||||||
|
|
||||||
def test_berkeley_mhad_skeleton_is_present(self, shared_model_cache_dir: Path) -> None:
|
def test_berkeley_mhad_skeleton_is_present(self, shared_model_cache_dir: Path) -> None:
|
||||||
"""The estimator pins skeleton='berkeley_mhad_43'; verify it exists."""
|
"""The estimator pins skeleton='berkeley_mhad_43'; verify it exists."""
|
||||||
loaded = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
model = load_metrabs_model(cache_dir=shared_model_cache_dir)
|
||||||
joint_names = loaded.model.per_skeleton_joint_names["berkeley_mhad_43"]
|
joint_names = model.per_skeleton_joint_names["berkeley_mhad_43"]
|
||||||
joint_edges = loaded.model.per_skeleton_joint_edges["berkeley_mhad_43"]
|
joint_edges = model.per_skeleton_joint_edges["berkeley_mhad_43"]
|
||||||
# MeTRAbs exposes these as tf.Tensor objects; just verify we
|
# MeTRAbs exposes these as tf.Tensor objects; just verify we
|
||||||
# can pull a shape out.
|
# can pull a shape out.
|
||||||
assert joint_names.shape[0] == 43
|
assert joint_names.shape[0] == 43
|
||||||
|
|
|
||||||
|
|
@ -50,8 +50,8 @@ def test_joint_names_match_pinned_model(metrabs_model_cache_dir: Path) -> None:
|
||||||
commit that bumps the model pin in :mod:`neuropose._model`.
|
commit that bumps the model pin in :mod:`neuropose._model`.
|
||||||
2. Cross-check any CLI or docs that embed hardcoded joint names.
|
2. Cross-check any CLI or docs that embed hardcoded joint names.
|
||||||
"""
|
"""
|
||||||
loaded = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
|
model = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
|
||||||
tensor = loaded.model.per_skeleton_joint_names["berkeley_mhad_43"]
|
tensor = model.per_skeleton_joint_names["berkeley_mhad_43"]
|
||||||
model_names = tuple(tensor.numpy().astype(str).tolist())
|
model_names = tuple(tensor.numpy().astype(str).tolist())
|
||||||
assert model_names == JOINT_NAMES, (
|
assert model_names == JOINT_NAMES, (
|
||||||
"JOINT_NAMES drift detected — the hardcoded tuple in "
|
"JOINT_NAMES drift detected — the hardcoded tuple in "
|
||||||
|
|
|
||||||
|
|
@ -131,250 +131,3 @@ class TestDtwRelation:
|
||||||
b = np.zeros((3, 2, 3))
|
b = np.zeros((3, 2, 3))
|
||||||
with pytest.raises(ValueError, match="joint count"):
|
with pytest.raises(ValueError, match="joint count"):
|
||||||
dtw_relation(a, b, joint_i=0, joint_j=1)
|
dtw_relation(a, b, joint_i=0, joint_j=1)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# representation="angles"
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(angle_rad), np.sin(angle_rad)
|
|
||||||
return np.array(
|
|
||||||
[
|
|
||||||
[c, -s, 0.0],
|
|
||||||
[s, c, 0.0],
|
|
||||||
[0.0, 0.0, 1.0],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _three_joint_arm(num_frames: int = 6) -> np.ndarray:
|
|
||||||
"""A three-joint arm opening from a right angle to straight.
|
|
||||||
|
|
||||||
Joints laid out as [shoulder, elbow, wrist], forming an angle at
|
|
||||||
the elbow that linearly opens from pi/2 to pi across ``num_frames``.
|
|
||||||
"""
|
|
||||||
sequence = np.zeros((num_frames, 3, 3))
|
|
||||||
angles = np.linspace(np.pi / 2, np.pi, num_frames)
|
|
||||||
for i, theta in enumerate(angles):
|
|
||||||
sequence[i, 0] = [-1.0, 0.0, 0.0] # shoulder
|
|
||||||
sequence[i, 1] = [0.0, 0.0, 0.0] # elbow
|
|
||||||
sequence[i, 2] = [np.cos(theta - np.pi), np.sin(theta - np.pi), 0.0] # wrist
|
|
||||||
return sequence
|
|
||||||
|
|
||||||
|
|
||||||
class TestDtwAllAngles:
|
|
||||||
def test_angles_identical_sequences_distance_zero(self) -> None:
|
|
||||||
seq = _three_joint_arm()
|
|
||||||
result = dtw_all(
|
|
||||||
seq,
|
|
||||||
seq,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
def test_angles_invariant_to_global_rotation(self) -> None:
|
|
||||||
"""Angle-space DTW must not change under a global rotation."""
|
|
||||||
seq = _three_joint_arm()
|
|
||||||
rotated = seq @ _rotation_matrix_z(np.deg2rad(40.0)).T
|
|
||||||
baseline = dtw_all(seq, seq, representation="angles", angle_triplets=[(0, 1, 2)])
|
|
||||||
under_rotation = dtw_all(
|
|
||||||
seq,
|
|
||||||
rotated,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
assert baseline.distance == pytest.approx(under_rotation.distance, abs=1e-6)
|
|
||||||
|
|
||||||
def test_angles_translation_invariant(self) -> None:
|
|
||||||
seq = _three_joint_arm()
|
|
||||||
translated = seq + np.array([10.0, -5.0, 2.0])
|
|
||||||
result = dtw_all(
|
|
||||||
seq,
|
|
||||||
translated,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
def test_angles_detects_different_motion(self) -> None:
|
|
||||||
# A sequence whose angle is constant vs. one that opens.
|
|
||||||
constant = np.zeros((6, 3, 3))
|
|
||||||
constant[:, 0] = [-1.0, 0.0, 0.0]
|
|
||||||
constant[:, 1] = [0.0, 0.0, 0.0]
|
|
||||||
constant[:, 2] = [0.0, 1.0, 0.0] # right angle throughout
|
|
||||||
opening = _three_joint_arm()
|
|
||||||
result = dtw_all(
|
|
||||||
constant,
|
|
||||||
opening,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
assert result.distance > 0.0
|
|
||||||
|
|
||||||
def test_angles_without_triplets_rejected(self) -> None:
|
|
||||||
seq = _three_joint_arm()
|
|
||||||
with pytest.raises(ValueError, match="angle_triplets"):
|
|
||||||
dtw_all(seq, seq, representation="angles")
|
|
||||||
|
|
||||||
|
|
||||||
class TestDtwPerJointAngles:
|
|
||||||
def test_returns_one_result_per_triplet(self) -> None:
|
|
||||||
seq = _three_joint_arm()
|
|
||||||
triplets = [(0, 1, 2), (0, 1, 2)] # duplicate triplet on purpose
|
|
||||||
results = dtw_per_joint(
|
|
||||||
seq,
|
|
||||||
seq,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=triplets,
|
|
||||||
)
|
|
||||||
assert len(results) == 2
|
|
||||||
for result in results:
|
|
||||||
assert result.distance == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
def test_per_triplet_distinct_paths(self) -> None:
|
|
||||||
# Two triplets covering different angles; with different motion
|
|
||||||
# per triplet, the per-unit results should differ.
|
|
||||||
seq_a = np.zeros((5, 4, 3))
|
|
||||||
seq_b = np.zeros((5, 4, 3))
|
|
||||||
# joint 0: pivot, joint 1/2/3: arm endpoints
|
|
||||||
for i in range(5):
|
|
||||||
seq_a[i, 0] = [0.0, 0.0, 0.0]
|
|
||||||
seq_a[i, 1] = [1.0, 0.0, 0.0]
|
|
||||||
seq_a[i, 2] = [0.0, 1.0, 0.0]
|
|
||||||
seq_a[i, 3] = [0.0, 0.0, 1.0]
|
|
||||||
seq_b[i, 0] = [0.0, 0.0, 0.0]
|
|
||||||
seq_b[i, 1] = [1.0, 0.0, 0.0]
|
|
||||||
seq_b[i, 2] = [np.cos(i * 0.3), np.sin(i * 0.3), 0.0] # rotating
|
|
||||||
seq_b[i, 3] = [0.0, 0.0, 1.0]
|
|
||||||
results = dtw_per_joint(
|
|
||||||
seq_a,
|
|
||||||
seq_b,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(1, 0, 2), (1, 0, 3)],
|
|
||||||
)
|
|
||||||
assert len(results) == 2
|
|
||||||
# First triplet tracks the rotation, second is stationary.
|
|
||||||
assert results[0].distance > 0.0
|
|
||||||
assert results[1].distance == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# nan_policy
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _collinear_sequence(num_frames: int = 4) -> np.ndarray:
|
|
||||||
"""Three collinear joints — the angle at the middle joint is degenerate."""
|
|
||||||
seq = np.zeros((num_frames, 3, 3))
|
|
||||||
seq[:, 0] = [-1.0, 0.0, 0.0]
|
|
||||||
# Middle joint at (0,0,0); but because the outer joints are collinear
|
|
||||||
# through the origin, we need one joint overlapping with the middle
|
|
||||||
# to force a zero-length vector. Place joint 2 AT joint 1 to trigger
|
|
||||||
# the degenerate case in extract_joint_angles.
|
|
||||||
seq[:, 1] = [0.0, 0.0, 0.0]
|
|
||||||
seq[:, 2] = [0.0, 0.0, 0.0]
|
|
||||||
return seq
|
|
||||||
|
|
||||||
|
|
||||||
class TestNanPolicy:
|
|
||||||
def test_propagate_surfaces_error(self) -> None:
|
|
||||||
# Degenerate triplet produces NaN angles for every frame.
|
|
||||||
# With nan_policy="propagate" the NaN reaches fastdtw, which
|
|
||||||
# validates via numpy.asarray_chkfinite and raises ValueError —
|
|
||||||
# the intended behaviour ("make the problem visible").
|
|
||||||
seq = _collinear_sequence(num_frames=4)
|
|
||||||
other = _three_joint_arm(num_frames=4)
|
|
||||||
with pytest.raises(ValueError, match="infs or NaNs"):
|
|
||||||
dtw_all(
|
|
||||||
seq,
|
|
||||||
other,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
nan_policy="propagate",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_interpolate_fills_isolated_nan(self) -> None:
|
|
||||||
# One bad frame in a 5-frame sequence — the other four are
|
|
||||||
# finite anchors to interpolate between.
|
|
||||||
good = _three_joint_arm(num_frames=5)
|
|
||||||
# Inject a degenerate middle frame.
|
|
||||||
good[2, 2] = good[2, 1] # force zero-length vector → NaN angle
|
|
||||||
# Reference is the same arm without injection.
|
|
||||||
reference = _three_joint_arm(num_frames=5)
|
|
||||||
result = dtw_all(
|
|
||||||
good,
|
|
||||||
reference,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
nan_policy="interpolate",
|
|
||||||
)
|
|
||||||
assert not np.isnan(result.distance)
|
|
||||||
|
|
||||||
def test_interpolate_all_nan_column_rejected(self) -> None:
|
|
||||||
seq = _collinear_sequence(num_frames=5)
|
|
||||||
other = _three_joint_arm(num_frames=5)
|
|
||||||
with pytest.raises(ValueError, match="all values are NaN"):
|
|
||||||
dtw_all(
|
|
||||||
seq,
|
|
||||||
other,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
nan_policy="interpolate",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_drop_removes_nan_frames(self) -> None:
|
|
||||||
good = _three_joint_arm(num_frames=6)
|
|
||||||
good[2, 2] = good[2, 1] # inject NaN at frame 2
|
|
||||||
good[4, 2] = good[4, 1] # inject NaN at frame 4
|
|
||||||
reference = _three_joint_arm(num_frames=6)
|
|
||||||
result = dtw_all(
|
|
||||||
good,
|
|
||||||
reference,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
nan_policy="drop",
|
|
||||||
)
|
|
||||||
# The 4 remaining finite frames should align cleanly with
|
|
||||||
# their counterparts in the reference.
|
|
||||||
assert not np.isnan(result.distance)
|
|
||||||
|
|
||||||
def test_drop_empties_sequence_rejected(self) -> None:
|
|
||||||
seq = _collinear_sequence(num_frames=5)
|
|
||||||
other = _three_joint_arm(num_frames=5)
|
|
||||||
with pytest.raises(ValueError, match="every frame"):
|
|
||||||
dtw_all(
|
|
||||||
seq,
|
|
||||||
other,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
nan_policy="drop",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# align + representation composition
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestAlignWithAngles:
|
|
||||||
def test_procrustes_before_angles_is_no_op_on_invariant_representation(self) -> None:
|
|
||||||
"""Procrustes on angle-space DTW should be redundant but safe."""
|
|
||||||
seq = _three_joint_arm()
|
|
||||||
rotated = seq @ _rotation_matrix_z(np.deg2rad(20.0)).T
|
|
||||||
with_align = dtw_all(
|
|
||||||
seq,
|
|
||||||
rotated,
|
|
||||||
align="procrustes_per_sequence",
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
without_align = dtw_all(
|
|
||||||
seq,
|
|
||||||
rotated,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
assert with_align.distance == pytest.approx(without_align.distance, abs=1e-6)
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from neuropose.analyzer.features import (
|
from neuropose.analyzer.features import (
|
||||||
AlignmentDiagnostics,
|
|
||||||
FeatureStatistics,
|
FeatureStatistics,
|
||||||
extract_feature_statistics,
|
extract_feature_statistics,
|
||||||
extract_joint_angles,
|
extract_joint_angles,
|
||||||
|
|
@ -16,7 +15,6 @@ from neuropose.analyzer.features import (
|
||||||
normalize_pose_sequence,
|
normalize_pose_sequence,
|
||||||
pad_sequences,
|
pad_sequences,
|
||||||
predictions_to_numpy,
|
predictions_to_numpy,
|
||||||
procrustes_align,
|
|
||||||
)
|
)
|
||||||
from neuropose.io import VideoPredictions
|
from neuropose.io import VideoPredictions
|
||||||
|
|
||||||
|
|
@ -299,177 +297,3 @@ class TestFindPeaks:
|
||||||
def test_rejects_2d_input(self) -> None:
|
def test_rejects_2d_input(self) -> None:
|
||||||
with pytest.raises(ValueError, match="1D"):
|
with pytest.raises(ValueError, match="1D"):
|
||||||
find_peaks(np.zeros((5, 5)))
|
find_peaks(np.zeros((5, 5)))
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# procrustes_align
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _rotation_matrix_z(angle_rad: float) -> np.ndarray:
|
|
||||||
"""Rotation matrix about the Z axis."""
|
|
||||||
c, s = np.cos(angle_rad), np.sin(angle_rad)
|
|
||||||
return np.array(
|
|
||||||
[
|
|
||||||
[c, -s, 0.0],
|
|
||||||
[s, c, 0.0],
|
|
||||||
[0.0, 0.0, 1.0],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _skeleton(num_joints: int = 8, seed: int = 0) -> np.ndarray:
|
|
||||||
"""A deterministic, non-degenerate single-frame skeleton."""
|
|
||||||
rng = np.random.default_rng(seed)
|
|
||||||
return rng.standard_normal((num_joints, 3))
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcrustesAlignPerSequence:
|
|
||||||
def test_identical_sequences_yield_identity_transform(self) -> None:
|
|
||||||
sequence = _skeleton()[np.newaxis, :, :].repeat(3, axis=0) # (3, 8, 3)
|
|
||||||
aligned, target, diag = procrustes_align(sequence, sequence, mode="per_sequence")
|
|
||||||
np.testing.assert_allclose(aligned, sequence, atol=1e-10)
|
|
||||||
np.testing.assert_array_equal(target, sequence)
|
|
||||||
assert diag.mode == "per_sequence"
|
|
||||||
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-6)
|
|
||||||
assert diag.translation == pytest.approx(0.0, abs=1e-9)
|
|
||||||
assert diag.scale == pytest.approx(1.0)
|
|
||||||
|
|
||||||
def test_recovers_known_rotation(self) -> None:
|
|
||||||
# Build a reference sequence; construct the source by rotating it
|
|
||||||
# about Z, then verify alignment returns the reference up to
|
|
||||||
# floating-point error.
|
|
||||||
rotation = _rotation_matrix_z(np.deg2rad(37.0))
|
|
||||||
reference = _skeleton(num_joints=10)[np.newaxis, :, :].repeat(4, axis=0)
|
|
||||||
source = reference @ rotation.T
|
|
||||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
|
|
||||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
|
||||||
# The recovered rotation's magnitude should be the original 37°.
|
|
||||||
assert diag.rotation_deg == pytest.approx(37.0, abs=1e-4)
|
|
||||||
|
|
||||||
def test_recovers_known_translation(self) -> None:
|
|
||||||
reference = _skeleton()[np.newaxis, :, :].repeat(5, axis=0)
|
|
||||||
translation = np.array([10.0, -4.5, 2.25])
|
|
||||||
source = reference + translation
|
|
||||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
|
|
||||||
np.testing.assert_allclose(aligned, reference, atol=1e-9)
|
|
||||||
# rotation_deg may be numerically tiny but not exactly 0.
|
|
||||||
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-4)
|
|
||||||
assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-6)
|
|
||||||
|
|
||||||
def test_recovers_combined_rotation_and_translation(self) -> None:
|
|
||||||
rotation = _rotation_matrix_z(np.deg2rad(-12.0))
|
|
||||||
translation = np.array([1.0, 2.0, 3.0])
|
|
||||||
reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(3, axis=0)
|
|
||||||
source = reference @ rotation.T + translation
|
|
||||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence")
|
|
||||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
|
||||||
assert diag.rotation_deg == pytest.approx(12.0, abs=1e-4)
|
|
||||||
assert diag.translation == pytest.approx(np.linalg.norm(translation), rel=1e-4)
|
|
||||||
|
|
||||||
def test_scale_flag_recovers_known_scale(self) -> None:
|
|
||||||
reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
|
|
||||||
source = reference * 0.5
|
|
||||||
aligned, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=True)
|
|
||||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
|
||||||
assert diag.scale == pytest.approx(2.0, rel=1e-6)
|
|
||||||
|
|
||||||
def test_scale_flag_off_leaves_scale_at_one(self) -> None:
|
|
||||||
reference = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
|
|
||||||
source = reference * 0.5
|
|
||||||
_, _, diag = procrustes_align(source, reference, mode="per_sequence", scale=False)
|
|
||||||
assert diag.scale == pytest.approx(1.0)
|
|
||||||
|
|
||||||
def test_rejects_mismatched_shapes(self) -> None:
|
|
||||||
a = np.zeros((4, 8, 3))
|
|
||||||
b = np.zeros((4, 7, 3))
|
|
||||||
with pytest.raises(ValueError, match="same shape"):
|
|
||||||
procrustes_align(a, b)
|
|
||||||
|
|
||||||
def test_rejects_wrong_trailing_axis(self) -> None:
|
|
||||||
a = np.zeros((4, 8, 2))
|
|
||||||
b = np.zeros((4, 8, 2))
|
|
||||||
with pytest.raises(ValueError, match="joints, 3"):
|
|
||||||
procrustes_align(a, b)
|
|
||||||
|
|
||||||
def test_rejects_unknown_mode(self) -> None:
|
|
||||||
a = np.zeros((2, 4, 3))
|
|
||||||
with pytest.raises(ValueError, match="unknown mode"):
|
|
||||||
procrustes_align(a, a, mode="nope") # type: ignore[arg-type]
|
|
||||||
|
|
||||||
def test_does_not_mutate_inputs(self) -> None:
|
|
||||||
source = _skeleton()[np.newaxis, :, :].repeat(3, axis=0).copy()
|
|
||||||
target = (source @ _rotation_matrix_z(np.deg2rad(10.0)).T).copy()
|
|
||||||
source_before = source.copy()
|
|
||||||
target_before = target.copy()
|
|
||||||
procrustes_align(source, target, mode="per_sequence")
|
|
||||||
np.testing.assert_array_equal(source, source_before)
|
|
||||||
np.testing.assert_array_equal(target, target_before)
|
|
||||||
|
|
||||||
def test_returns_alignment_diagnostics_dataclass(self) -> None:
|
|
||||||
a = _skeleton()[np.newaxis, :, :].repeat(2, axis=0)
|
|
||||||
_, _, diag = procrustes_align(a, a)
|
|
||||||
assert isinstance(diag, AlignmentDiagnostics)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcrustesAlignPerFrame:
|
|
||||||
def test_per_frame_recovers_varying_rotations(self) -> None:
|
|
||||||
# Each frame is rotated by a different angle; per_frame alignment
|
|
||||||
# should recover each frame independently.
|
|
||||||
num_frames = 4
|
|
||||||
reference_frame = _skeleton(num_joints=6)
|
|
||||||
angles = np.deg2rad([5.0, -10.0, 20.0, 45.0])
|
|
||||||
reference = np.stack([reference_frame for _ in range(num_frames)], axis=0)
|
|
||||||
source = np.stack([reference_frame @ _rotation_matrix_z(a).T for a in angles], axis=0)
|
|
||||||
aligned, _, diag = procrustes_align(source, reference, mode="per_frame")
|
|
||||||
np.testing.assert_allclose(aligned, reference, atol=1e-8)
|
|
||||||
assert diag.mode == "per_frame"
|
|
||||||
# The max rotation across frames should be 45°.
|
|
||||||
assert diag.rotation_deg_max == pytest.approx(45.0, abs=1e-4)
|
|
||||||
# The mean rotation across frames should be 20°.
|
|
||||||
assert diag.rotation_deg == pytest.approx(20.0, abs=1e-4)
|
|
||||||
|
|
||||||
def test_per_frame_with_identical_sequences_yields_zero(self) -> None:
|
|
||||||
sequence = _skeleton(num_joints=5)[np.newaxis, :, :].repeat(3, axis=0)
|
|
||||||
aligned, _, diag = procrustes_align(sequence, sequence, mode="per_frame")
|
|
||||||
np.testing.assert_allclose(aligned, sequence, atol=1e-10)
|
|
||||||
# Per-frame SVD on a symmetric covariance is numerically ambiguous
|
|
||||||
# in axis selection, so the fitted rotation can be a few micro-
|
|
||||||
# degrees off zero; the residual positions are still exact.
|
|
||||||
assert diag.rotation_deg == pytest.approx(0.0, abs=1e-3)
|
|
||||||
assert diag.rotation_deg_max == pytest.approx(0.0, abs=1e-3)
|
|
||||||
assert diag.translation == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# DTW with align= (integration)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestDtwAlignIntegration:
|
|
||||||
"""Smoke tests: align= routes through procrustes_align correctly.
|
|
||||||
|
|
||||||
Depth tests of the DTW path itself live in test_analyzer_dtw.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_dtw_all_with_alignment_cancels_rigid_offset(self) -> None:
|
|
||||||
pytest.importorskip("fastdtw")
|
|
||||||
from neuropose.analyzer.dtw import dtw_all
|
|
||||||
|
|
||||||
rotation = _rotation_matrix_z(np.deg2rad(30.0))
|
|
||||||
translation = np.array([5.0, -2.0, 1.0])
|
|
||||||
reference = _skeleton(num_joints=6)[np.newaxis, :, :].repeat(4, axis=0)
|
|
||||||
source = reference @ rotation.T + translation
|
|
||||||
baseline = dtw_all(source, reference, align="none")
|
|
||||||
aligned_result = dtw_all(source, reference, align="procrustes_per_sequence")
|
|
||||||
assert baseline.distance > 0.0
|
|
||||||
assert aligned_result.distance == pytest.approx(0.0, abs=1e-6)
|
|
||||||
|
|
||||||
def test_dtw_align_rejects_mismatched_frame_counts(self) -> None:
|
|
||||||
pytest.importorskip("fastdtw")
|
|
||||||
from neuropose.analyzer.dtw import dtw_all
|
|
||||||
|
|
||||||
a = np.zeros((5, 3, 3))
|
|
||||||
b = np.zeros((6, 3, 3))
|
|
||||||
with pytest.raises(ValueError, match="matching frame counts"):
|
|
||||||
dtw_all(a, b, align="procrustes_per_sequence")
|
|
||||||
|
|
|
||||||
|
|
@ -1,881 +0,0 @@
|
||||||
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
|
||||||
|
|
||||||
Covers both halves of the pipeline:
|
|
||||||
|
|
||||||
- **Schemas** — :class:`AnalysisConfig` parsing (discriminated unions
|
|
||||||
for segmentation and analysis stages, cross-field invariants),
|
|
||||||
:class:`AnalysisReport` construction + JSON round-trip (including
|
|
||||||
the migration hook on ``schema_version``), and
|
|
||||||
:func:`analysis_config_to_dict` JSON-safety.
|
|
||||||
- **Executor** — :func:`run_analysis` dispatches to each analysis kind
|
|
||||||
(dtw / stats / none) with and without segmentation; provenance is
|
|
||||||
inherited from the primary input with ``analysis_config``
|
|
||||||
populated; :func:`load_config`, :func:`save_report`, and
|
|
||||||
:func:`load_report` round-trip.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from neuropose.analyzer.pipeline import (
|
|
||||||
AnalysisConfig,
|
|
||||||
AnalysisReport,
|
|
||||||
DtwAnalysis,
|
|
||||||
DtwResults,
|
|
||||||
ExtractorSegmentation,
|
|
||||||
FeatureSummary,
|
|
||||||
GaitCyclesBilateralSegmentation,
|
|
||||||
GaitCyclesSegmentation,
|
|
||||||
InputsConfig,
|
|
||||||
InputSummary,
|
|
||||||
NoAnalysis,
|
|
||||||
NoResults,
|
|
||||||
OutputConfig,
|
|
||||||
PreprocessingConfig,
|
|
||||||
StatsAnalysis,
|
|
||||||
StatsResults,
|
|
||||||
analysis_config_to_dict,
|
|
||||||
load_config,
|
|
||||||
load_report,
|
|
||||||
run_analysis,
|
|
||||||
save_report,
|
|
||||||
)
|
|
||||||
from neuropose.io import Provenance, VideoPredictions, save_video_predictions
|
|
||||||
from neuropose.migrations import CURRENT_VERSION
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
|
|
||||||
"""A minimal AnalysisConfig dict with dtw_all + reference."""
|
|
||||||
return {
|
|
||||||
"inputs": {
|
|
||||||
"primary": str(tmp_path / "primary.json"),
|
|
||||||
"reference": str(tmp_path / "reference.json"),
|
|
||||||
},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "report.json")},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"inputs": {"primary": str(tmp_path / "primary.json")},
|
|
||||||
"analysis": {
|
|
||||||
"kind": "stats",
|
|
||||||
"extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False},
|
|
||||||
},
|
|
||||||
"output": {"report": str(tmp_path / "report.json")},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# InputsConfig / PreprocessingConfig / OutputConfig
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestInputsConfig:
|
|
||||||
def test_primary_only(self, tmp_path: Path) -> None:
|
|
||||||
cfg = InputsConfig(primary=tmp_path / "a.json")
|
|
||||||
assert cfg.reference is None
|
|
||||||
|
|
||||||
def test_primary_and_reference(self, tmp_path: Path) -> None:
|
|
||||||
cfg = InputsConfig(
|
|
||||||
primary=tmp_path / "a.json",
|
|
||||||
reference=tmp_path / "b.json",
|
|
||||||
)
|
|
||||||
assert cfg.reference == tmp_path / "b.json"
|
|
||||||
|
|
||||||
def test_extra_field_rejected(self, tmp_path: Path) -> None:
|
|
||||||
with pytest.raises(ValidationError, match="Extra inputs"):
|
|
||||||
InputsConfig.model_validate(
|
|
||||||
{
|
|
||||||
"primary": str(tmp_path / "a.json"),
|
|
||||||
"extra": "field",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_frozen(self, tmp_path: Path) -> None:
|
|
||||||
cfg = InputsConfig(primary=tmp_path / "a.json")
|
|
||||||
with pytest.raises(ValidationError, match="frozen"):
|
|
||||||
cfg.primary = tmp_path / "b.json" # type: ignore[misc]
|
|
||||||
|
|
||||||
|
|
||||||
class TestPreprocessingConfig:
|
|
||||||
def test_default_person_index_zero(self) -> None:
|
|
||||||
cfg = PreprocessingConfig()
|
|
||||||
assert cfg.person_index == 0
|
|
||||||
|
|
||||||
def test_negative_person_index_rejected(self) -> None:
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
PreprocessingConfig(person_index=-1)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOutputConfig:
|
|
||||||
def test_report_path(self, tmp_path: Path) -> None:
|
|
||||||
cfg = OutputConfig(report=tmp_path / "out.json")
|
|
||||||
assert cfg.report == tmp_path / "out.json"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Segmentation stage discriminated union
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestSegmentationStage:
|
|
||||||
def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["segmentation"] = {
|
|
||||||
"kind": "gait_cycles",
|
|
||||||
"joint": "lhee",
|
|
||||||
"axis": "y",
|
|
||||||
"min_cycle_seconds": 0.5,
|
|
||||||
}
|
|
||||||
cfg = AnalysisConfig.model_validate(config_dict)
|
|
||||||
assert isinstance(cfg.segmentation, GaitCyclesSegmentation)
|
|
||||||
assert cfg.segmentation.joint == "lhee"
|
|
||||||
assert cfg.segmentation.min_cycle_seconds == 0.5
|
|
||||||
|
|
||||||
def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["segmentation"] = {"kind": "gait_cycles_bilateral"}
|
|
||||||
cfg = AnalysisConfig.model_validate(config_dict)
|
|
||||||
assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation)
|
|
||||||
|
|
||||||
def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["segmentation"] = {
|
|
||||||
"kind": "extractor",
|
|
||||||
"extractor": {
|
|
||||||
"kind": "joint_axis",
|
|
||||||
"joint": 15,
|
|
||||||
"axis": 1,
|
|
||||||
"invert": False,
|
|
||||||
},
|
|
||||||
"label": "wrist_cycles",
|
|
||||||
"min_distance_seconds": 0.5,
|
|
||||||
}
|
|
||||||
cfg = AnalysisConfig.model_validate(config_dict)
|
|
||||||
assert isinstance(cfg.segmentation, ExtractorSegmentation)
|
|
||||||
assert cfg.segmentation.label == "wrist_cycles"
|
|
||||||
|
|
||||||
def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["segmentation"] = {"kind": "unknown_method"}
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AnalysisConfig.model_validate(config_dict)
|
|
||||||
|
|
||||||
def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
|
|
||||||
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
|
||||||
assert cfg.segmentation is None
|
|
||||||
|
|
||||||
def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["segmentation"] = {
|
|
||||||
"kind": "gait_cycles",
|
|
||||||
"min_cycle_seconds": 0.0, # must be > 0
|
|
||||||
}
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AnalysisConfig.model_validate(config_dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Analysis stage discriminated union + cross-field invariants
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestDtwAnalysisValidation:
|
|
||||||
def test_dtw_relation_requires_joints(self) -> None:
|
|
||||||
with pytest.raises(ValidationError, match="joint_i and joint_j"):
|
|
||||||
DtwAnalysis(kind="dtw", method="dtw_relation")
|
|
||||||
|
|
||||||
def test_dtw_relation_rejects_angles_representation(self) -> None:
|
|
||||||
with pytest.raises(ValidationError, match="only supports representation='coords'"):
|
|
||||||
DtwAnalysis(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_relation",
|
|
||||||
joint_i=0,
|
|
||||||
joint_j=1,
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2)],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_angles_requires_triplets(self) -> None:
|
|
||||||
with pytest.raises(ValidationError, match="angle_triplets"):
|
|
||||||
DtwAnalysis(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_all",
|
|
||||||
representation="angles",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_angles_with_empty_triplets_rejected(self) -> None:
|
|
||||||
with pytest.raises(ValidationError, match="angle_triplets"):
|
|
||||||
DtwAnalysis(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_all",
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_happy_path_dtw_all_coords(self) -> None:
|
|
||||||
analysis = DtwAnalysis(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_all",
|
|
||||||
align="procrustes_per_sequence",
|
|
||||||
nan_policy="interpolate",
|
|
||||||
)
|
|
||||||
assert analysis.align == "procrustes_per_sequence"
|
|
||||||
|
|
||||||
def test_happy_path_dtw_all_angles(self) -> None:
|
|
||||||
analysis = DtwAnalysis(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_all",
|
|
||||||
representation="angles",
|
|
||||||
angle_triplets=[(0, 1, 2), (3, 4, 5)],
|
|
||||||
)
|
|
||||||
assert len(analysis.angle_triplets or []) == 2
|
|
||||||
|
|
||||||
def test_happy_path_dtw_relation(self) -> None:
|
|
||||||
analysis = DtwAnalysis(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_relation",
|
|
||||||
joint_i=15,
|
|
||||||
joint_j=23,
|
|
||||||
)
|
|
||||||
assert analysis.joint_i == 15
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalysisCrossStage:
|
|
||||||
def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = {
|
|
||||||
"inputs": {"primary": str(tmp_path / "a.json")},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "out.json")},
|
|
||||||
}
|
|
||||||
with pytest.raises(ValidationError, match=r"inputs\.reference"):
|
|
||||||
AnalysisConfig.model_validate(config_dict)
|
|
||||||
|
|
||||||
def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_stats_config(tmp_path)
|
|
||||||
config_dict["inputs"]["reference"] = str(tmp_path / "b.json")
|
|
||||||
with pytest.raises(ValidationError, match="primary only"):
|
|
||||||
AnalysisConfig.model_validate(config_dict)
|
|
||||||
|
|
||||||
def test_none_analysis_requires_no_reference(self, tmp_path: Path) -> None:
|
|
||||||
# NoAnalysis is fine with either reference present or absent.
|
|
||||||
config_dict = {
|
|
||||||
"inputs": {"primary": str(tmp_path / "a.json")},
|
|
||||||
"analysis": {"kind": "none"},
|
|
||||||
"output": {"report": str(tmp_path / "out.json")},
|
|
||||||
}
|
|
||||||
cfg = AnalysisConfig.model_validate(config_dict)
|
|
||||||
assert isinstance(cfg.analysis, NoAnalysis)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Top-level AnalysisConfig
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalysisConfig:
|
|
||||||
def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
|
|
||||||
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
|
||||||
assert cfg.config_version == 1
|
|
||||||
assert isinstance(cfg.analysis, DtwAnalysis)
|
|
||||||
assert cfg.preprocessing.person_index == 0 # default
|
|
||||||
|
|
||||||
def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
|
|
||||||
cfg = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
|
|
||||||
assert isinstance(cfg.analysis, StatsAnalysis)
|
|
||||||
|
|
||||||
def test_config_version_must_be_1(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["config_version"] = 99
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AnalysisConfig.model_validate(config_dict)
|
|
||||||
|
|
||||||
def test_round_trip_json(self, tmp_path: Path) -> None:
|
|
||||||
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
|
||||||
serialised = original.model_dump_json()
|
|
||||||
restored = AnalysisConfig.model_validate_json(serialised)
|
|
||||||
assert restored == original
|
|
||||||
|
|
||||||
def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = _minimal_dtw_config(tmp_path)
|
|
||||||
config_dict["unknown_key"] = "typo"
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AnalysisConfig.model_validate(config_dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# analysis_config_to_dict
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalysisConfigToDict:
|
|
||||||
def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
|
|
||||||
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
|
||||||
dumped = analysis_config_to_dict(cfg)
|
|
||||||
# Paths must have become strings.
|
|
||||||
assert isinstance(dumped["inputs"]["primary"], str)
|
|
||||||
assert isinstance(dumped["output"]["report"], str)
|
|
||||||
|
|
||||||
def test_round_trips_through_dict(self, tmp_path: Path) -> None:
|
|
||||||
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
|
||||||
dumped = analysis_config_to_dict(original)
|
|
||||||
restored = AnalysisConfig.model_validate(dumped)
|
|
||||||
assert restored == original
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Result sub-schemas
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestDtwResults:
|
|
||||||
def test_minimal_construction(self) -> None:
|
|
||||||
res = DtwResults(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_all",
|
|
||||||
distances=[0.5],
|
|
||||||
paths=[[(0, 0), (1, 1)]],
|
|
||||||
segment_labels=["full_trial"],
|
|
||||||
summary={"mean": 0.5},
|
|
||||||
)
|
|
||||||
assert res.kind == "dtw"
|
|
||||||
|
|
||||||
def test_per_joint_distances_shape_is_free(self) -> None:
|
|
||||||
# No validator enforces that per_joint_distances outer length
|
|
||||||
# matches distances — that's run-time semantics of the
|
|
||||||
# executor. Still, verify the field round-trips.
|
|
||||||
res = DtwResults(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_per_joint",
|
|
||||||
distances=[0.1, 0.2],
|
|
||||||
paths=[[(0, 0)], [(0, 0)]],
|
|
||||||
per_joint_distances=[[0.05, 0.05], [0.1, 0.1]],
|
|
||||||
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
|
|
||||||
summary={"mean": 0.15},
|
|
||||||
)
|
|
||||||
assert res.per_joint_distances is not None
|
|
||||||
|
|
||||||
|
|
||||||
class TestStatsResults:
|
|
||||||
def test_round_trip(self) -> None:
|
|
||||||
res = StatsResults(
|
|
||||||
kind="stats",
|
|
||||||
statistics=[
|
|
||||||
FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
|
|
||||||
FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
|
|
||||||
],
|
|
||||||
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
|
|
||||||
)
|
|
||||||
dumped = res.model_dump_json()
|
|
||||||
restored = StatsResults.model_validate_json(dumped)
|
|
||||||
assert restored == res
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoResults:
|
|
||||||
def test_construction(self) -> None:
|
|
||||||
res = NoResults(kind="none")
|
|
||||||
assert res.kind == "none"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# AnalysisReport
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_report(tmp_path: Path) -> AnalysisReport:
|
|
||||||
config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
|
||||||
return AnalysisReport(
|
|
||||||
config=config,
|
|
||||||
primary=InputSummary(
|
|
||||||
path=tmp_path / "primary.json",
|
|
||||||
frame_count=300,
|
|
||||||
fps=30.0,
|
|
||||||
),
|
|
||||||
reference=InputSummary(
|
|
||||||
path=tmp_path / "reference.json",
|
|
||||||
frame_count=300,
|
|
||||||
fps=30.0,
|
|
||||||
),
|
|
||||||
results=DtwResults(
|
|
||||||
kind="dtw",
|
|
||||||
method="dtw_all",
|
|
||||||
distances=[0.42],
|
|
||||||
paths=[[(0, 0), (1, 1)]],
|
|
||||||
segment_labels=["full_trial"],
|
|
||||||
summary={"mean": 0.42, "p50": 0.42},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalysisReport:
|
|
||||||
def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
|
|
||||||
report = _make_report(tmp_path)
|
|
||||||
assert report.schema_version == CURRENT_VERSION
|
|
||||||
|
|
||||||
def test_round_trip_json(self, tmp_path: Path) -> None:
|
|
||||||
report = _make_report(tmp_path)
|
|
||||||
serialised = report.model_dump_json()
|
|
||||||
restored = AnalysisReport.model_validate_json(serialised)
|
|
||||||
assert restored == report
|
|
||||||
|
|
||||||
def test_empty_segmentations_default(self, tmp_path: Path) -> None:
|
|
||||||
report = _make_report(tmp_path)
|
|
||||||
assert report.segmentations == {}
|
|
||||||
|
|
||||||
def test_extra_field_rejected(self, tmp_path: Path) -> None:
|
|
||||||
report = _make_report(tmp_path)
|
|
||||||
dumped = report.model_dump(mode="json")
|
|
||||||
dumped["mystery_field"] = 1
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AnalysisReport.model_validate(dumped)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Executor helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
NUM_JOINTS = 43
|
|
||||||
|
|
||||||
|
|
||||||
def _heel_signal(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
|
|
||||||
"""Clean sinusoid stand-in for a heel's vertical trace."""
|
|
||||||
import math
|
|
||||||
|
|
||||||
total = num_cycles * frames_per_cycle
|
|
||||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
|
||||||
return (np.sin(t) * amplitude + amplitude).astype(float)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_predictions(
|
|
||||||
signal: np.ndarray,
|
|
||||||
joint: int,
|
|
||||||
*,
|
|
||||||
axis: int = 1,
|
|
||||||
fps: float = 30.0,
|
|
||||||
provenance: Provenance | None = None,
|
|
||||||
) -> VideoPredictions:
|
|
||||||
"""Build a VideoPredictions whose ``joint``'s ``axis`` follows ``signal``."""
|
|
||||||
frames = {}
|
|
||||||
for i, value in enumerate(signal):
|
|
||||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
|
||||||
poses[0][joint][axis] = float(value)
|
|
||||||
frames[f"frame_{i:06d}"] = {
|
|
||||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
|
||||||
"poses3d": poses,
|
|
||||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
|
||||||
}
|
|
||||||
return VideoPredictions.model_validate(
|
|
||||||
{
|
|
||||||
"metadata": {
|
|
||||||
"frame_count": len(signal),
|
|
||||||
"fps": fps,
|
|
||||||
"width": 640,
|
|
||||||
"height": 480,
|
|
||||||
},
|
|
||||||
"frames": frames,
|
|
||||||
"provenance": provenance.model_dump() if provenance is not None else None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_heel_trial(
|
|
||||||
tmp_path: Path,
|
|
||||||
filename: str,
|
|
||||||
*,
|
|
||||||
joint: int,
|
|
||||||
num_cycles: int = 4,
|
|
||||||
frames_per_cycle: int = 30,
|
|
||||||
amplitude: float = 100.0,
|
|
||||||
provenance: Provenance | None = None,
|
|
||||||
) -> Path:
|
|
||||||
"""Write a heel-trace VideoPredictions JSON and return its path."""
|
|
||||||
signal = _heel_signal(num_cycles, frames_per_cycle, amplitude=amplitude)
|
|
||||||
preds = _build_predictions(signal, joint=joint, provenance=provenance)
|
|
||||||
path = tmp_path / filename
|
|
||||||
save_video_predictions(path, preds)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_provenance(sha: str = "a" * 64) -> Provenance:
|
|
||||||
return Provenance(
|
|
||||||
model_sha256=sha,
|
|
||||||
model_filename="fake_model.tar.gz",
|
|
||||||
tensorflow_version="2.18.0",
|
|
||||||
numpy_version="1.26.0",
|
|
||||||
neuropose_version="0.0.0",
|
|
||||||
python_version="3.11.0",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Executor: run_analysis
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunAnalysisDtwFullTrial:
|
|
||||||
def test_dtw_all_unsegmented_yields_one_distance(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
report_path = tmp_path / "report.json"
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(report_path)},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, DtwResults)
|
|
||||||
assert report.results.segment_labels == ["full_trial"]
|
|
||||||
assert len(report.results.distances) == 1
|
|
||||||
# Identical inputs → distance 0.
|
|
||||||
assert report.results.distances[0] == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
def test_dtw_all_different_trials_positive_distance(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], amplitude=100.0)
|
|
||||||
reference = _write_heel_trial(
|
|
||||||
tmp_path, "b.json", joint=JOINT_INDEX["rhee"], amplitude=200.0
|
|
||||||
)
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, DtwResults)
|
|
||||||
assert report.results.distances[0] > 0.0
|
|
||||||
assert "mean" in report.results.summary
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunAnalysisDtwSegmented:
|
|
||||||
def test_dtw_with_gait_cycles_produces_per_segment_distances(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, DtwResults)
|
|
||||||
# 4 cycles detected on both → 4 paired distances.
|
|
||||||
assert len(report.results.distances) == 4
|
|
||||||
assert all(label.startswith("rhee_cycles[") for label in report.results.segment_labels)
|
|
||||||
|
|
||||||
def test_dtw_bilateral_produces_distances_per_side(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
# Put the heel trace on both lhee and rhee.
|
|
||||||
rng_signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
|
||||||
frames = {}
|
|
||||||
for i, value in enumerate(rng_signal):
|
|
||||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
|
||||||
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
|
|
||||||
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
|
|
||||||
frames[f"frame_{i:06d}"] = {
|
|
||||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
|
||||||
"poses3d": poses,
|
|
||||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
|
||||||
}
|
|
||||||
preds = VideoPredictions.model_validate(
|
|
||||||
{
|
|
||||||
"metadata": {
|
|
||||||
"frame_count": len(rng_signal),
|
|
||||||
"fps": 30.0,
|
|
||||||
"width": 640,
|
|
||||||
"height": 480,
|
|
||||||
},
|
|
||||||
"frames": frames,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
primary = tmp_path / "a.json"
|
|
||||||
reference = tmp_path / "b.json"
|
|
||||||
save_video_predictions(primary, preds)
|
|
||||||
save_video_predictions(reference, preds)
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"segmentation": {"kind": "gait_cycles_bilateral"},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, DtwResults)
|
|
||||||
# 3 cycles * 2 sides.
|
|
||||||
assert len(report.results.distances) == 6
|
|
||||||
left = [lbl for lbl in report.results.segment_labels if lbl.startswith("left_heel_strikes")]
|
|
||||||
right = [
|
|
||||||
lbl for lbl in report.results.segment_labels if lbl.startswith("right_heel_strikes")
|
|
||||||
]
|
|
||||||
assert len(left) == 3
|
|
||||||
assert len(right) == 3
|
|
||||||
# Identical primary and reference → all distances zero.
|
|
||||||
for d in report.results.distances:
|
|
||||||
assert d == pytest.approx(0.0, abs=1e-9)
|
|
||||||
|
|
||||||
def test_dtw_per_joint_populates_per_joint_distances(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_per_joint"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, DtwResults)
|
|
||||||
assert report.results.per_joint_distances is not None
|
|
||||||
assert len(report.results.per_joint_distances) == 1 # unsegmented → one pair
|
|
||||||
assert len(report.results.per_joint_distances[0]) == NUM_JOINTS
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunAnalysisStats:
|
|
||||||
def test_stats_unsegmented_single_block(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary)},
|
|
||||||
"analysis": {
|
|
||||||
"kind": "stats",
|
|
||||||
"extractor": {
|
|
||||||
"kind": "joint_axis",
|
|
||||||
"joint": JOINT_INDEX["rhee"],
|
|
||||||
"axis": 1,
|
|
||||||
"invert": False,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, StatsResults)
|
|
||||||
assert report.results.segment_labels == ["full_trial"]
|
|
||||||
assert len(report.results.statistics) == 1
|
|
||||||
stat = report.results.statistics[0]
|
|
||||||
assert isinstance(stat, FeatureSummary)
|
|
||||||
assert stat.max > stat.min # Signal oscillates.
|
|
||||||
|
|
||||||
def test_stats_with_segmentation_emits_per_segment(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary)},
|
|
||||||
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
|
||||||
"analysis": {
|
|
||||||
"kind": "stats",
|
|
||||||
"extractor": {
|
|
||||||
"kind": "joint_axis",
|
|
||||||
"joint": JOINT_INDEX["rhee"],
|
|
||||||
"axis": 1,
|
|
||||||
"invert": False,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, StatsResults)
|
|
||||||
assert len(report.results.statistics) == 3
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunAnalysisNone:
|
|
||||||
def test_none_analysis_returns_no_results(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary)},
|
|
||||||
"analysis": {"kind": "none"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, NoResults)
|
|
||||||
|
|
||||||
def test_none_with_segmentation_still_emits_segmentations(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary)},
|
|
||||||
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
|
||||||
"analysis": {"kind": "none"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert isinstance(report.results, NoResults)
|
|
||||||
assert "rhee_cycles" in report.segmentations
|
|
||||||
assert len(report.segmentations["rhee_cycles"].segments) > 0
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Provenance inheritance
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunAnalysisProvenance:
|
|
||||||
def test_inherits_primary_provenance_and_stamps_config(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
provenance = _fake_provenance()
|
|
||||||
primary = _write_heel_trial(
|
|
||||||
tmp_path, "a.json", joint=JOINT_INDEX["rhee"], provenance=provenance
|
|
||||||
)
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert report.provenance is not None
|
|
||||||
# Model SHA inherited from primary.
|
|
||||||
assert report.provenance.model_sha256 == provenance.model_sha256
|
|
||||||
# analysis_config populated with the serialised config.
|
|
||||||
assert report.provenance.analysis_config is not None
|
|
||||||
assert report.provenance.analysis_config["config_version"] == 1
|
|
||||||
|
|
||||||
def test_no_primary_provenance_yields_none_report_provenance(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert report.provenance is None
|
|
||||||
|
|
||||||
def test_input_summaries_track_paths_and_metadata(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=5)
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
assert report.primary.path == primary
|
|
||||||
assert report.primary.frame_count == 5 * 30
|
|
||||||
assert report.reference is not None
|
|
||||||
assert report.reference.frame_count == 3 * 30
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# load_config / save_report / load_report
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadSave:
|
|
||||||
def test_load_config_parses_yaml(self, tmp_path: Path) -> None:
|
|
||||||
config_dict = {
|
|
||||||
"inputs": {
|
|
||||||
"primary": str(tmp_path / "a.json"),
|
|
||||||
"reference": str(tmp_path / "b.json"),
|
|
||||||
},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
yaml_path = tmp_path / "exp.yaml"
|
|
||||||
yaml_path.write_text(yaml.safe_dump(config_dict))
|
|
||||||
loaded = load_config(yaml_path)
|
|
||||||
assert isinstance(loaded.analysis, DtwAnalysis)
|
|
||||||
|
|
||||||
def test_load_config_empty_file_fails_cleanly(self, tmp_path: Path) -> None:
|
|
||||||
empty = tmp_path / "empty.yaml"
|
|
||||||
empty.write_text("")
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
load_config(empty)
|
|
||||||
|
|
||||||
def test_load_config_rejects_malformed_yaml(self, tmp_path: Path) -> None:
|
|
||||||
bad = tmp_path / "bad.yaml"
|
|
||||||
# Unclosed flow-style mapping — yaml.safe_load raises here.
|
|
||||||
bad.write_text("inputs: {primary: foo\n")
|
|
||||||
with pytest.raises(yaml.YAMLError):
|
|
||||||
load_config(bad)
|
|
||||||
|
|
||||||
def test_save_report_round_trip(self, tmp_path: Path) -> None:
|
|
||||||
from neuropose.analyzer.segment import JOINT_INDEX
|
|
||||||
|
|
||||||
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
|
||||||
config = AnalysisConfig.model_validate(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "report.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
report = run_analysis(config)
|
|
||||||
report_path = tmp_path / "report.json"
|
|
||||||
save_report(report_path, report)
|
|
||||||
assert report_path.exists()
|
|
||||||
|
|
||||||
restored = load_report(report_path)
|
|
||||||
assert restored == report
|
|
||||||
|
|
||||||
def test_save_report_is_atomic(self, tmp_path: Path) -> None:
|
|
||||||
"""The saver writes via a sibling .tmp path and renames."""
|
|
||||||
report = _make_report(tmp_path)
|
|
||||||
report_path = tmp_path / "subdir" / "report.json"
|
|
||||||
save_report(report_path, report)
|
|
||||||
# Parent directory was created.
|
|
||||||
assert report_path.exists()
|
|
||||||
# No .tmp sibling left behind.
|
|
||||||
assert not (report_path.with_suffix(report_path.suffix + ".tmp")).exists()
|
|
||||||
|
|
||||||
def test_load_report_rejects_future_schema(self, tmp_path: Path) -> None:
|
|
||||||
"""Future schema_version surfaces as a migration error."""
|
|
||||||
from neuropose.migrations import FutureSchemaError
|
|
||||||
|
|
||||||
future = {"schema_version": CURRENT_VERSION + 1}
|
|
||||||
path = tmp_path / "future.json"
|
|
||||||
path.write_text(json.dumps(future))
|
|
||||||
with pytest.raises(FutureSchemaError):
|
|
||||||
load_report(path)
|
|
||||||
|
|
@ -33,8 +33,6 @@ from neuropose.analyzer.segment import (
|
||||||
joint_pair_distance,
|
joint_pair_distance,
|
||||||
joint_speed,
|
joint_speed,
|
||||||
segment_by_peaks,
|
segment_by_peaks,
|
||||||
segment_gait_cycles,
|
|
||||||
segment_gait_cycles_bilateral,
|
|
||||||
segment_predictions,
|
segment_predictions,
|
||||||
slice_predictions,
|
slice_predictions,
|
||||||
)
|
)
|
||||||
|
|
@ -431,139 +429,3 @@ class TestSlicePredictions:
|
||||||
sliced = slice_predictions(preds, segments)[0]
|
sliced = slice_predictions(preds, segments)[0]
|
||||||
# frame_000000 of the slice must equal frame_000050 of the source
|
# frame_000000 of the slice must equal frame_000050 of the source
|
||||||
assert sliced["frame_000000"].poses3d == preds["frame_000050"].poses3d
|
assert sliced["frame_000000"].poses3d == preds["frame_000050"].poses3d
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# segment_gait_cycles / segment_gait_cycles_bilateral
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _heel_signal(num_cycles: int, frames_per_cycle: int) -> np.ndarray:
|
|
||||||
"""A clean sinusoid standing in for a heel's vertical trace."""
|
|
||||||
total = num_cycles * frames_per_cycle
|
|
||||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
|
||||||
# Amplitude chosen so min_prominence tests have a non-trivial range.
|
|
||||||
return (np.sin(t) * 100.0 + 100.0).astype(float)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSegmentGaitCycles:
|
|
||||||
def test_detects_expected_number_of_cycles(self) -> None:
|
|
||||||
# 5 cycles at 30 fps, 30 frames per cycle = 1.0 s per stride.
|
|
||||||
# Well inside the default min_cycle_seconds=0.4 gate.
|
|
||||||
signal = _heel_signal(num_cycles=5, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
seg = segment_gait_cycles(preds, joint="rhee", axis="y")
|
|
||||||
assert len(seg.segments) == 5
|
|
||||||
|
|
||||||
def test_config_records_inputs(self) -> None:
|
|
||||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
seg = segment_gait_cycles(
|
|
||||||
preds,
|
|
||||||
joint="rhee",
|
|
||||||
axis="y",
|
|
||||||
min_cycle_seconds=0.5,
|
|
||||||
min_prominence=10.0,
|
|
||||||
)
|
|
||||||
assert isinstance(seg, Segmentation)
|
|
||||||
assert isinstance(seg.config.extractor, JointAxisExtractor)
|
|
||||||
assert seg.config.extractor.joint == JOINT_INDEX["rhee"]
|
|
||||||
assert seg.config.extractor.axis == 1 # "y" → 1
|
|
||||||
assert seg.config.extractor.invert is False
|
|
||||||
assert seg.config.min_distance_seconds == 0.5
|
|
||||||
assert seg.config.min_prominence == 10.0
|
|
||||||
|
|
||||||
def test_axis_selection(self) -> None:
|
|
||||||
# Put the signal on the X axis instead of Y.
|
|
||||||
signal = _heel_signal(num_cycles=4, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"], axis=0)
|
|
||||||
seg_y = segment_gait_cycles(preds, joint="rhee", axis="y")
|
|
||||||
seg_x = segment_gait_cycles(preds, joint="rhee", axis="x")
|
|
||||||
# Y is all-zeros (flat → no peaks), X carries the signal.
|
|
||||||
assert len(seg_y.segments) == 0
|
|
||||||
assert len(seg_x.segments) == 4
|
|
||||||
|
|
||||||
def test_invert_flips_peaks_and_valleys(self) -> None:
|
|
||||||
# Invert the heel trace; with invert=True, the original valleys
|
|
||||||
# become the peaks detected as heel-strikes.
|
|
||||||
signal = _heel_signal(num_cycles=4, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
seg_plain = segment_gait_cycles(preds, joint="rhee", axis="y", invert=False)
|
|
||||||
seg_inverted = segment_gait_cycles(preds, joint="rhee", axis="y", invert=True)
|
|
||||||
# Both detect four distinct events (peaks in either the signal
|
|
||||||
# or its negation). Peaks differ by roughly half a cycle.
|
|
||||||
assert len(seg_plain.segments) == 4
|
|
||||||
assert len(seg_inverted.segments) == 4
|
|
||||||
plain_peaks = [s.peak for s in seg_plain.segments]
|
|
||||||
inverted_peaks = [s.peak for s in seg_inverted.segments]
|
|
||||||
assert plain_peaks != inverted_peaks
|
|
||||||
|
|
||||||
def test_pathological_flat_signal_returns_empty(self) -> None:
|
|
||||||
# A subject whose heel never leaves the ground — no peaks.
|
|
||||||
signal = np.zeros(120)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
seg = segment_gait_cycles(preds, joint="rhee", axis="y")
|
|
||||||
assert seg.segments == []
|
|
||||||
|
|
||||||
def test_min_cycle_seconds_rejects_close_peaks(self) -> None:
|
|
||||||
# 10 cycles in 60 frames @ 30 fps = 0.2 s per cycle.
|
|
||||||
# min_cycle_seconds=0.4 should reject all but every-other peak.
|
|
||||||
signal = _heel_signal(num_cycles=10, frames_per_cycle=6)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
seg_permissive = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.0)
|
|
||||||
seg_strict = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.4)
|
|
||||||
# Strict mode drops peaks that are too close together.
|
|
||||||
assert len(seg_strict.segments) < len(seg_permissive.segments)
|
|
||||||
|
|
||||||
def test_unknown_joint_raises_key_error(self) -> None:
|
|
||||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
with pytest.raises(KeyError, match="unknown joint"):
|
|
||||||
segment_gait_cycles(preds, joint="left_heel") # wrong name
|
|
||||||
|
|
||||||
def test_invalid_axis_raises_value_error(self) -> None:
|
|
||||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
with pytest.raises(ValueError, match="axis must be one of"):
|
|
||||||
segment_gait_cycles(preds, joint="rhee", axis="w") # type: ignore[arg-type]
|
|
||||||
|
|
||||||
|
|
||||||
class TestSegmentGaitCyclesBilateral:
|
|
||||||
def test_returns_both_keys(self) -> None:
|
|
||||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
|
||||||
# Put the same signal on both heels so both sides find cycles.
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
# Rebuild predictions with lhee populated too.
|
|
||||||
frames = {}
|
|
||||||
for i, value in enumerate(signal):
|
|
||||||
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
|
||||||
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
|
|
||||||
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
|
|
||||||
frames[f"frame_{i:06d}"] = {
|
|
||||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
|
||||||
"poses3d": poses,
|
|
||||||
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
|
||||||
}
|
|
||||||
preds = VideoPredictions.model_validate(
|
|
||||||
{
|
|
||||||
"metadata": {
|
|
||||||
"frame_count": len(signal),
|
|
||||||
"fps": 30.0,
|
|
||||||
"width": 640,
|
|
||||||
"height": 480,
|
|
||||||
},
|
|
||||||
"frames": frames,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = segment_gait_cycles_bilateral(preds)
|
|
||||||
assert set(result.keys()) == {"left_heel_strikes", "right_heel_strikes"}
|
|
||||||
assert len(result["left_heel_strikes"].segments) == 3
|
|
||||||
assert len(result["right_heel_strikes"].segments) == 3
|
|
||||||
|
|
||||||
def test_pathological_one_side_returns_empty_for_that_side(self) -> None:
|
|
||||||
# Only the right heel carries a signal; left heel is flat.
|
|
||||||
signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
|
||||||
preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"])
|
|
||||||
result = segment_gait_cycles_bilateral(preds)
|
|
||||||
assert len(result["right_heel_strikes"].segments) == 3
|
|
||||||
assert result["left_heel_strikes"].segments == []
|
|
||||||
|
|
|
||||||
|
|
@ -683,15 +683,9 @@ def stub_estimator_with_metrics(monkeypatch: pytest.MonkeyPatch):
|
||||||
"poses2d": np.array([[[0.0, 0.0], [1.0, 1.0]]]),
|
"poses2d": np.array([[[0.0, 0.0], [1.0, 1.0]]]),
|
||||||
}
|
}
|
||||||
|
|
||||||
from neuropose._model import LoadedModel
|
def fake_loader(cache_dir: Path | None = None) -> object:
|
||||||
|
|
||||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
|
||||||
del cache_dir
|
del cache_dir
|
||||||
return LoadedModel(
|
return RecordingFake()
|
||||||
model=RecordingFake(),
|
|
||||||
sha256="smoke_sha",
|
|
||||||
filename="metrabs_smoke.tar.gz",
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||||
|
|
||||||
|
|
@ -776,142 +770,17 @@ class TestBenchmarkSubcommand:
|
||||||
|
|
||||||
|
|
||||||
class TestAnalyze:
|
class TestAnalyze:
|
||||||
"""Covers the ``neuropose analyze --config <yaml>`` subcommand.
|
def test_analyze_stub_exits_with_pending_message(
|
||||||
|
|
||||||
Execution happy path is exercised in detail in
|
|
||||||
:mod:`tests.unit.test_analyzer_pipeline` — this file focuses on
|
|
||||||
the CLI wiring: argument parsing, config-loading error modes, and
|
|
||||||
end-to-end smoke.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _make_predictions_file(self, tmp_path: Path, name: str, num_frames: int = 30) -> Path:
|
|
||||||
"""Write a trivial VideoPredictions file to disk for the CLI to load."""
|
|
||||||
import math
|
|
||||||
|
|
||||||
from neuropose.io import VideoPredictions, save_video_predictions
|
|
||||||
|
|
||||||
num_joints = 43
|
|
||||||
frames = {}
|
|
||||||
for i in range(num_frames):
|
|
||||||
poses = [[[0.0, 0.0, 0.0] for _ in range(num_joints)]]
|
|
||||||
poses[0][41][1] = float(math.sin(i * 0.3)) * 100.0 # rhee Y
|
|
||||||
frames[f"frame_{i:06d}"] = {
|
|
||||||
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
|
||||||
"poses3d": poses,
|
|
||||||
"poses2d": [[[0.0, 0.0]] * num_joints],
|
|
||||||
}
|
|
||||||
preds = VideoPredictions.model_validate(
|
|
||||||
{
|
|
||||||
"metadata": {
|
|
||||||
"frame_count": num_frames,
|
|
||||||
"fps": 30.0,
|
|
||||||
"width": 640,
|
|
||||||
"height": 480,
|
|
||||||
},
|
|
||||||
"frames": frames,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
path = tmp_path / name
|
|
||||||
save_video_predictions(path, preds)
|
|
||||||
return path
|
|
||||||
|
|
||||||
def _write_dtw_config(
|
|
||||||
self,
|
|
||||||
tmp_path: Path,
|
|
||||||
*,
|
|
||||||
primary: Path,
|
|
||||||
reference: Path,
|
|
||||||
report: Path,
|
|
||||||
) -> Path:
|
|
||||||
import yaml as _yaml
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.yaml"
|
|
||||||
config_path.write_text(
|
|
||||||
_yaml.safe_dump(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(primary), "reference": str(reference)},
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(report)},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return config_path
|
|
||||||
|
|
||||||
def test_missing_config_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
|
|
||||||
result = runner.invoke(app, ["analyze", "--config", str(tmp_path / "nope.yaml")])
|
|
||||||
assert result.exit_code == EXIT_USAGE
|
|
||||||
assert "config file not found" in result.output
|
|
||||||
|
|
||||||
def test_missing_config_flag_is_usage_error(self, runner: CliRunner) -> None:
|
|
||||||
result = runner.invoke(app, ["analyze"])
|
|
||||||
assert result.exit_code == EXIT_USAGE
|
|
||||||
|
|
||||||
def test_invalid_yaml_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
|
|
||||||
bad = tmp_path / "bad.yaml"
|
|
||||||
bad.write_text("inputs: {primary: foo\n") # unclosed flow mapping
|
|
||||||
result = runner.invoke(app, ["analyze", "--config", str(bad)])
|
|
||||||
assert result.exit_code == EXIT_USAGE
|
|
||||||
assert "could not parse YAML" in result.output
|
|
||||||
|
|
||||||
def test_schema_violation_is_usage_error(self, runner: CliRunner, tmp_path: Path) -> None:
|
|
||||||
import yaml as _yaml
|
|
||||||
|
|
||||||
bad = tmp_path / "schema.yaml"
|
|
||||||
bad.write_text(
|
|
||||||
_yaml.safe_dump(
|
|
||||||
{
|
|
||||||
"inputs": {"primary": str(tmp_path / "a.json")},
|
|
||||||
# dtw without reference — violates cross-field invariant.
|
|
||||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
|
||||||
"output": {"report": str(tmp_path / "r.json")},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = runner.invoke(app, ["analyze", "--config", str(bad)])
|
|
||||||
assert result.exit_code == EXIT_USAGE
|
|
||||||
assert "invalid config" in result.output
|
|
||||||
|
|
||||||
def test_happy_path_writes_report(self, runner: CliRunner, tmp_path: Path) -> None:
|
|
||||||
primary = self._make_predictions_file(tmp_path, "a.json")
|
|
||||||
reference = self._make_predictions_file(tmp_path, "b.json")
|
|
||||||
report_path = tmp_path / "report.json"
|
|
||||||
config = self._write_dtw_config(
|
|
||||||
tmp_path, primary=primary, reference=reference, report=report_path
|
|
||||||
)
|
|
||||||
result = runner.invoke(app, ["analyze", "--config", str(config)])
|
|
||||||
assert result.exit_code == EXIT_OK, result.output
|
|
||||||
assert report_path.exists()
|
|
||||||
assert "wrote analysis report" in result.output
|
|
||||||
assert "analysis kind: dtw" in result.output
|
|
||||||
|
|
||||||
def test_output_option_overrides_config_path(self, runner: CliRunner, tmp_path: Path) -> None:
|
|
||||||
primary = self._make_predictions_file(tmp_path, "a.json")
|
|
||||||
reference = self._make_predictions_file(tmp_path, "b.json")
|
|
||||||
# Config points at one report path ...
|
|
||||||
config = self._write_dtw_config(
|
|
||||||
tmp_path,
|
|
||||||
primary=primary,
|
|
||||||
reference=reference,
|
|
||||||
report=tmp_path / "declared.json",
|
|
||||||
)
|
|
||||||
# ... but --output overrides.
|
|
||||||
override = tmp_path / "override.json"
|
|
||||||
result = runner.invoke(app, ["analyze", "--config", str(config), "--output", str(override)])
|
|
||||||
assert result.exit_code == EXIT_OK, result.output
|
|
||||||
assert override.exists()
|
|
||||||
assert not (tmp_path / "declared.json").exists()
|
|
||||||
|
|
||||||
def test_missing_predictions_file_is_usage_error(
|
|
||||||
self, runner: CliRunner, tmp_path: Path
|
self, runner: CliRunner, tmp_path: Path
|
||||||
) -> None:
|
) -> None:
|
||||||
# Config points at a primary that does not exist.
|
results_path = tmp_path / "results.json"
|
||||||
config = self._write_dtw_config(
|
results_path.write_text("{}")
|
||||||
tmp_path,
|
result = runner.invoke(app, ["analyze", str(results_path)])
|
||||||
primary=tmp_path / "missing_primary.json",
|
assert result.exit_code == EXIT_PENDING
|
||||||
reference=tmp_path / "missing_reference.json",
|
assert "commit 10" in result.output
|
||||||
report=tmp_path / "report.json",
|
|
||||||
)
|
def test_analyze_requires_an_argument(self, runner: CliRunner) -> None:
|
||||||
result = runner.invoke(app, ["analyze", "--config", str(config)])
|
result = runner.invoke(app, ["analyze"])
|
||||||
assert result.exit_code == EXIT_USAGE
|
assert result.exit_code == EXIT_USAGE
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,21 +70,17 @@ class TestModelGuard:
|
||||||
network: the loader is monkeypatched to return a sentinel, and we
|
network: the loader is monkeypatched to return a sentinel, and we
|
||||||
assert it ends up as the estimator's model.
|
assert it ends up as the estimator's model.
|
||||||
"""
|
"""
|
||||||
from neuropose._model import LoadedModel
|
|
||||||
|
|
||||||
sentinel = object()
|
sentinel = object()
|
||||||
called_with: list[Path | None] = []
|
called_with: list[Path | None] = []
|
||||||
|
|
||||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
def fake_loader(cache_dir: Path | None = None) -> object:
|
||||||
called_with.append(cache_dir)
|
called_with.append(cache_dir)
|
||||||
return LoadedModel(model=sentinel, sha256="deadbeef", filename="fake.tar.gz")
|
return sentinel
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||||
estimator = Estimator()
|
estimator = Estimator()
|
||||||
estimator.load_model(cache_dir=Path("/tmp/fake-cache"))
|
estimator.load_model(cache_dir=Path("/tmp/fake-cache"))
|
||||||
assert estimator.model is sentinel
|
assert estimator.model is sentinel
|
||||||
assert estimator.model_sha256 == "deadbeef"
|
|
||||||
assert estimator.model_filename == "fake.tar.gz"
|
|
||||||
assert called_with == [Path("/tmp/fake-cache")]
|
assert called_with == [Path("/tmp/fake-cache")]
|
||||||
|
|
||||||
def test_load_model_is_idempotent_when_already_loaded(
|
def test_load_model_is_idempotent_when_already_loaded(
|
||||||
|
|
@ -282,15 +278,9 @@ class TestPerformanceMetrics:
|
||||||
"poses2d": np.array([[[0.0, 0.0]]]),
|
"poses2d": np.array([[[0.0, 0.0]]]),
|
||||||
}
|
}
|
||||||
|
|
||||||
from neuropose._model import LoadedModel
|
def fake_loader(cache_dir: Path | None = None) -> object:
|
||||||
|
|
||||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
|
||||||
del cache_dir
|
del cache_dir
|
||||||
return LoadedModel(
|
return Recorder()
|
||||||
model=Recorder(),
|
|
||||||
sha256="fake_sha",
|
|
||||||
filename="metrabs_fake.tar.gz",
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
||||||
estimator = Estimator()
|
estimator = Estimator()
|
||||||
|
|
@ -322,88 +312,6 @@ class TestPerformanceMetrics:
|
||||||
assert result.metrics.tensorflow_version not in {"", "unknown"}
|
assert result.metrics.tensorflow_version not in {"", "unknown"}
|
||||||
|
|
||||||
|
|
||||||
class TestProvenance:
|
|
||||||
"""Provenance attachment to VideoPredictions.
|
|
||||||
|
|
||||||
Covers the two relevant paths: the injected-model path (no SHA
|
|
||||||
known → ``provenance=None`` on output) and the ``load_model`` path
|
|
||||||
(SHA is known → full ``Provenance`` populated and attached).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_injected_model_produces_no_provenance(
|
|
||||||
self,
|
|
||||||
synthetic_video: Path,
|
|
||||||
fake_metrabs_model,
|
|
||||||
) -> None:
|
|
||||||
estimator = Estimator(model=fake_metrabs_model)
|
|
||||||
result = estimator.process_video(synthetic_video)
|
|
||||||
assert result.predictions.provenance is None
|
|
||||||
assert estimator.model_sha256 is None
|
|
||||||
assert estimator.model_filename is None
|
|
||||||
|
|
||||||
def test_loaded_model_populates_provenance(
|
|
||||||
self,
|
|
||||||
synthetic_video: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from neuropose._model import LoadedModel
|
|
||||||
|
|
||||||
class Recorder:
|
|
||||||
def detect_poses(self, image, **kwargs):
|
|
||||||
del image, kwargs
|
|
||||||
return {
|
|
||||||
"boxes": np.array([[0.0, 0.0, 1.0, 1.0, 0.9]]),
|
|
||||||
"poses3d": np.array([[[0.0, 0.0, 0.0]]]),
|
|
||||||
"poses2d": np.array([[[0.0, 0.0]]]),
|
|
||||||
}
|
|
||||||
|
|
||||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
|
||||||
del cache_dir
|
|
||||||
return LoadedModel(
|
|
||||||
model=Recorder(),
|
|
||||||
sha256="e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
|
|
||||||
filename="metrabs_stub.tar.gz",
|
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
|
||||||
estimator = Estimator()
|
|
||||||
estimator.load_model()
|
|
||||||
result = estimator.process_video(synthetic_video)
|
|
||||||
|
|
||||||
prov = result.predictions.provenance
|
|
||||||
assert prov is not None
|
|
||||||
assert prov.model_sha256.startswith("e3b0c44")
|
|
||||||
assert prov.model_filename == "metrabs_stub.tar.gz"
|
|
||||||
assert prov.numpy_version == np.__version__
|
|
||||||
assert prov.python_version.count(".") == 2 # MAJOR.MINOR.MICRO
|
|
||||||
# neuropose_version should match the package's __version__
|
|
||||||
from neuropose import __version__ as pkg_version
|
|
||||||
|
|
||||||
assert prov.neuropose_version == pkg_version
|
|
||||||
# tensorflow_version should also be real (TF is in dev deps).
|
|
||||||
assert prov.tensorflow_version not in {"", "unknown"}
|
|
||||||
|
|
||||||
def test_model_sha256_and_filename_properties_after_load(
|
|
||||||
self,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
from neuropose._model import LoadedModel
|
|
||||||
|
|
||||||
def fake_loader(cache_dir: Path | None = None) -> LoadedModel:
|
|
||||||
del cache_dir
|
|
||||||
return LoadedModel(model=object(), sha256="abcd", filename="x.tar.gz")
|
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.estimator.load_metrabs_model", fake_loader)
|
|
||||||
estimator = Estimator()
|
|
||||||
assert estimator.model_sha256 is None
|
|
||||||
assert estimator.model_filename is None
|
|
||||||
estimator.load_model()
|
|
||||||
assert estimator.model_sha256 == "abcd"
|
|
||||||
assert estimator.model_filename == "x.tar.gz"
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrors:
|
class TestErrors:
|
||||||
def test_missing_video(
|
def test_missing_video(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ from neuropose.io import (
|
||||||
JointPairDistanceExtractor,
|
JointPairDistanceExtractor,
|
||||||
JointSpeedExtractor,
|
JointSpeedExtractor,
|
||||||
PerformanceMetrics,
|
PerformanceMetrics,
|
||||||
Provenance,
|
|
||||||
Segment,
|
Segment,
|
||||||
Segmentation,
|
Segmentation,
|
||||||
SegmentationConfig,
|
SegmentationConfig,
|
||||||
|
|
@ -279,102 +278,6 @@ class TestPerformanceMetricsModel:
|
||||||
m.total_seconds = 2.0
|
m.total_seconds = 2.0
|
||||||
|
|
||||||
|
|
||||||
def _minimal_provenance() -> Provenance:
|
|
||||||
return Provenance(
|
|
||||||
model_sha256="a" * 64,
|
|
||||||
model_filename="metrabs_fake.tar.gz",
|
|
||||||
tensorflow_version="2.18.1",
|
|
||||||
numpy_version="2.0.2",
|
|
||||||
neuropose_version="0.1.0.dev0",
|
|
||||||
python_version="3.11.14",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProvenanceModel:
|
|
||||||
"""Schema-level behaviour of :class:`neuropose.io.Provenance`."""
|
|
||||||
|
|
||||||
def test_roundtrip_through_json(self) -> None:
|
|
||||||
p = Provenance(
|
|
||||||
model_sha256="a" * 64,
|
|
||||||
model_filename="metrabs_fake.tar.gz",
|
|
||||||
tensorflow_version="2.18.1",
|
|
||||||
tensorflow_metal_version="1.2.0",
|
|
||||||
numpy_version="2.0.2",
|
|
||||||
neuropose_version="0.1.0.dev0",
|
|
||||||
python_version="3.11.14",
|
|
||||||
seed=42,
|
|
||||||
deterministic=True,
|
|
||||||
analysis_config={"step": "dtw", "nan_policy": "propagate"},
|
|
||||||
)
|
|
||||||
rehydrated = Provenance.model_validate(p.model_dump(mode="json"))
|
|
||||||
assert rehydrated == p
|
|
||||||
|
|
||||||
def test_optional_fields_default_to_none_and_false(self) -> None:
|
|
||||||
p = _minimal_provenance()
|
|
||||||
assert p.tensorflow_metal_version is None
|
|
||||||
assert p.seed is None
|
|
||||||
assert p.deterministic is False
|
|
||||||
assert p.analysis_config is None
|
|
||||||
|
|
||||||
def test_is_frozen(self) -> None:
|
|
||||||
p = _minimal_provenance()
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
p.model_sha256 = "different"
|
|
||||||
|
|
||||||
def test_extra_fields_forbidden(self) -> None:
|
|
||||||
# Construct via model_validate so pyright doesn't have to prove the
|
|
||||||
# keyword doesn't exist on the class at static-type time.
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
Provenance.model_validate(
|
|
||||||
{
|
|
||||||
"model_sha256": "x" * 64,
|
|
||||||
"model_filename": "x.tar.gz",
|
|
||||||
"tensorflow_version": "2.18",
|
|
||||||
"numpy_version": "2.0",
|
|
||||||
"neuropose_version": "0.1",
|
|
||||||
"python_version": "3.11.14",
|
|
||||||
"unknown_field": "bogus",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestVideoPredictionsProvenance:
|
|
||||||
"""``provenance`` field on :class:`VideoPredictions` round-trips."""
|
|
||||||
|
|
||||||
def test_default_is_none(self) -> None:
|
|
||||||
vp = VideoPredictions(
|
|
||||||
metadata=VideoMetadata(frame_count=0, fps=30.0, width=32, height=32),
|
|
||||||
frames={},
|
|
||||||
)
|
|
||||||
assert vp.provenance is None
|
|
||||||
|
|
||||||
def test_roundtrip_with_provenance(self, tmp_path: Path) -> None:
|
|
||||||
prov = Provenance(
|
|
||||||
model_sha256="f" * 64,
|
|
||||||
model_filename="metrabs.tar.gz",
|
|
||||||
tensorflow_version="2.18.1",
|
|
||||||
numpy_version="2.0.2",
|
|
||||||
neuropose_version="0.1.0.dev0",
|
|
||||||
python_version="3.11.14",
|
|
||||||
)
|
|
||||||
vp = VideoPredictions(
|
|
||||||
metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32),
|
|
||||||
frames={
|
|
||||||
"frame_000000": FramePrediction(
|
|
||||||
boxes=[[0.0, 0.0, 32.0, 32.0, 0.9]],
|
|
||||||
poses3d=[[[1.0, 2.0, 3.0]]],
|
|
||||||
poses2d=[[[10.0, 20.0]]],
|
|
||||||
)
|
|
||||||
},
|
|
||||||
provenance=prov,
|
|
||||||
)
|
|
||||||
path = tmp_path / "vp.json"
|
|
||||||
save_video_predictions(path, vp)
|
|
||||||
loaded = load_video_predictions(path)
|
|
||||||
assert loaded == vp
|
|
||||||
assert loaded.provenance == prov
|
|
||||||
|
|
||||||
|
|
||||||
class TestBenchmarkResultPersistence:
|
class TestBenchmarkResultPersistence:
|
||||||
def test_roundtrip_to_disk(self, tmp_path: Path) -> None:
|
def test_roundtrip_to_disk(self, tmp_path: Path) -> None:
|
||||||
result = BenchmarkResult(
|
result = BenchmarkResult(
|
||||||
|
|
|
||||||
|
|
@ -1,488 +0,0 @@
|
||||||
"""Tests for :mod:`neuropose.migrations`.
|
|
||||||
|
|
||||||
Covers both the low-level migration driver (version walking, future/missing
|
|
||||||
errors, INFO logging) and its integration through the
|
|
||||||
:mod:`neuropose.io` load helpers (legacy payloads round-trip; future
|
|
||||||
payloads fail with a clear message).
|
|
||||||
|
|
||||||
The migration driver is tested by monkey-patching ``CURRENT_VERSION`` and
|
|
||||||
the per-schema migration registries, so the tests exercise the full
|
|
||||||
chain-walking machinery without needing the codebase to actually be on a
|
|
||||||
non-initial schema version.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from neuropose import migrations
|
|
||||||
from neuropose.io import (
|
|
||||||
BenchmarkResult,
|
|
||||||
FramePrediction,
|
|
||||||
JobResults,
|
|
||||||
VideoMetadata,
|
|
||||||
VideoPredictions,
|
|
||||||
load_benchmark_result,
|
|
||||||
load_job_results,
|
|
||||||
load_video_predictions,
|
|
||||||
save_benchmark_result,
|
|
||||||
save_job_results,
|
|
||||||
save_video_predictions,
|
|
||||||
)
|
|
||||||
from neuropose.migrations import (
|
|
||||||
CURRENT_VERSION,
|
|
||||||
FutureSchemaError,
|
|
||||||
MigrationError,
|
|
||||||
MigrationNotFoundError,
|
|
||||||
migrate_analysis_report,
|
|
||||||
migrate_benchmark_result,
|
|
||||||
migrate_job_results,
|
|
||||||
migrate_video_predictions,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Fixtures
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _minimal_video_predictions_payload() -> dict:
|
|
||||||
"""A valid VideoPredictions payload at the current schema version."""
|
|
||||||
return {
|
|
||||||
"schema_version": CURRENT_VERSION,
|
|
||||||
"metadata": {
|
|
||||||
"frame_count": 1,
|
|
||||||
"fps": 30.0,
|
|
||||||
"width": 32,
|
|
||||||
"height": 32,
|
|
||||||
},
|
|
||||||
"frames": {
|
|
||||||
"frame_000000": {
|
|
||||||
"boxes": [[0.0, 0.0, 32.0, 32.0, 0.95]],
|
|
||||||
"poses3d": [[[1.0, 2.0, 3.0]]],
|
|
||||||
"poses2d": [[[10.0, 20.0]]],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"segmentations": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _minimal_video_predictions_object() -> VideoPredictions:
|
|
||||||
"""Same payload, as a validated pydantic object."""
|
|
||||||
return VideoPredictions(
|
|
||||||
metadata=VideoMetadata(frame_count=1, fps=30.0, width=32, height=32),
|
|
||||||
frames={
|
|
||||||
"frame_000000": FramePrediction(
|
|
||||||
boxes=[[0.0, 0.0, 32.0, 32.0, 0.95]],
|
|
||||||
poses3d=[[[1.0, 2.0, 3.0]]],
|
|
||||||
poses2d=[[[10.0, 20.0]]],
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def fake_two_version_chain(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""Patch the module to look like CURRENT_VERSION=2 with a v1->v2 migration.
|
|
||||||
|
|
||||||
Lets the tests exercise the full migration loop even though the real
|
|
||||||
codebase is still at CURRENT_VERSION=1.
|
|
||||||
"""
|
|
||||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 2)
|
|
||||||
|
|
||||||
def _v1_to_v2(payload: dict) -> dict:
|
|
||||||
payload = dict(payload)
|
|
||||||
payload["schema_version"] = 2
|
|
||||||
payload["added_in_v2"] = "hello"
|
|
||||||
return payload
|
|
||||||
|
|
||||||
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {1: _v1_to_v2})
|
|
||||||
monkeypatch.setattr(migrations, "_BENCHMARK_RESULT_MIGRATIONS", {1: _v1_to_v2})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def fake_three_version_chain(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""Patch the module to look like CURRENT_VERSION=3 with v1->v2 and v2->v3.
|
|
||||||
|
|
||||||
Exercises multi-step migration chaining.
|
|
||||||
"""
|
|
||||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 3)
|
|
||||||
|
|
||||||
def _v1_to_v2(payload: dict) -> dict:
|
|
||||||
payload = dict(payload)
|
|
||||||
payload["schema_version"] = 2
|
|
||||||
payload["added_in_v2"] = "alpha"
|
|
||||||
return payload
|
|
||||||
|
|
||||||
def _v2_to_v3(payload: dict) -> dict:
|
|
||||||
payload = dict(payload)
|
|
||||||
payload["schema_version"] = 3
|
|
||||||
payload["added_in_v3"] = "beta"
|
|
||||||
return payload
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
migrations,
|
|
||||||
"_VIDEO_PREDICTIONS_MIGRATIONS",
|
|
||||||
{1: _v1_to_v2, 2: _v2_to_v3},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# migrate_video_predictions — driver behavior
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMigrateVideoPredictions:
|
|
||||||
def test_current_version_payload_is_noop(self) -> None:
|
|
||||||
payload = {"schema_version": CURRENT_VERSION, "hello": "world"}
|
|
||||||
result = migrate_video_predictions(payload)
|
|
||||||
assert result == payload
|
|
||||||
|
|
||||||
def test_missing_version_key_treated_as_v1(self) -> None:
|
|
||||||
"""A payload with no schema_version is treated as version 1.
|
|
||||||
|
|
||||||
With CURRENT_VERSION == 2, the legacy payload is run through
|
|
||||||
the registered v1 → v2 migration (which stamps ``provenance =
|
|
||||||
None``) on the way to the current version.
|
|
||||||
"""
|
|
||||||
payload = {"hello": "world"}
|
|
||||||
result = migrate_video_predictions(payload)
|
|
||||||
assert result["hello"] == "world"
|
|
||||||
assert result["schema_version"] == CURRENT_VERSION
|
|
||||||
assert result["provenance"] is None
|
|
||||||
|
|
||||||
def test_future_version_raises(self) -> None:
|
|
||||||
payload = {"schema_version": CURRENT_VERSION + 99}
|
|
||||||
with pytest.raises(FutureSchemaError, match="newer than"):
|
|
||||||
migrate_video_predictions(payload)
|
|
||||||
|
|
||||||
def test_non_integer_version_raises(self) -> None:
|
|
||||||
payload = {"schema_version": "1.0"}
|
|
||||||
with pytest.raises(MigrationError, match="invalid schema_version"):
|
|
||||||
migrate_video_predictions(payload)
|
|
||||||
|
|
||||||
def test_zero_version_raises(self) -> None:
|
|
||||||
payload = {"schema_version": 0}
|
|
||||||
with pytest.raises(MigrationError, match="invalid schema_version"):
|
|
||||||
migrate_video_predictions(payload)
|
|
||||||
|
|
||||||
def test_single_step_migration(self, fake_two_version_chain: None) -> None:
|
|
||||||
del fake_two_version_chain
|
|
||||||
payload = {"schema_version": 1, "original_field": "keep_me"}
|
|
||||||
result = migrate_video_predictions(payload)
|
|
||||||
assert result == {
|
|
||||||
"schema_version": 2,
|
|
||||||
"original_field": "keep_me",
|
|
||||||
"added_in_v2": "hello",
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_missing_version_under_patched_chain_migrates_from_v1(
|
|
||||||
self, fake_two_version_chain: None
|
|
||||||
) -> None:
|
|
||||||
del fake_two_version_chain
|
|
||||||
# Legacy file with no version stamp: should be treated as v1 and
|
|
||||||
# upgraded to v2.
|
|
||||||
payload = {"legacy": True}
|
|
||||||
result = migrate_video_predictions(payload)
|
|
||||||
assert result["schema_version"] == 2
|
|
||||||
assert result["added_in_v2"] == "hello"
|
|
||||||
assert result["legacy"] is True
|
|
||||||
|
|
||||||
def test_multi_step_migration_chains(self, fake_three_version_chain: None) -> None:
|
|
||||||
del fake_three_version_chain
|
|
||||||
payload = {"schema_version": 1, "original": "yes"}
|
|
||||||
result = migrate_video_predictions(payload)
|
|
||||||
assert result == {
|
|
||||||
"schema_version": 3,
|
|
||||||
"original": "yes",
|
|
||||||
"added_in_v2": "alpha",
|
|
||||||
"added_in_v3": "beta",
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_missing_intermediate_migration_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""If CURRENT advances past a version with no registered migration, fail loud."""
|
|
||||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 3)
|
|
||||||
# Only v1 -> v2 registered; v2 -> v3 is the missing link.
|
|
||||||
monkeypatch.setattr(
|
|
||||||
migrations,
|
|
||||||
"_VIDEO_PREDICTIONS_MIGRATIONS",
|
|
||||||
{1: lambda p: {**p, "schema_version": 2}},
|
|
||||||
)
|
|
||||||
with pytest.raises(MigrationNotFoundError, match="from schema_version 2"):
|
|
||||||
migrate_video_predictions({"schema_version": 1})
|
|
||||||
|
|
||||||
def test_logs_at_info_on_migration(
|
|
||||||
self,
|
|
||||||
fake_two_version_chain: None,
|
|
||||||
caplog: pytest.LogCaptureFixture,
|
|
||||||
) -> None:
|
|
||||||
del fake_two_version_chain
|
|
||||||
caplog.set_level(logging.INFO, logger="neuropose.migrations")
|
|
||||||
migrate_video_predictions({"schema_version": 1})
|
|
||||||
assert any("Migrating VideoPredictions" in record.message for record in caplog.records)
|
|
||||||
|
|
||||||
def test_starting_from_current_logs_nothing(
|
|
||||||
self,
|
|
||||||
fake_two_version_chain: None,
|
|
||||||
caplog: pytest.LogCaptureFixture,
|
|
||||||
) -> None:
|
|
||||||
del fake_two_version_chain
|
|
||||||
caplog.set_level(logging.INFO, logger="neuropose.migrations")
|
|
||||||
migrate_video_predictions({"schema_version": 2})
|
|
||||||
assert not any("Migrating" in record.message for record in caplog.records)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# migrate_benchmark_result — same driver, sibling registry
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMigrateBenchmarkResult:
|
|
||||||
def test_uses_benchmark_registry_not_video_registry(
|
|
||||||
self, monkeypatch: pytest.MonkeyPatch
|
|
||||||
) -> None:
|
|
||||||
"""Each schema has its own migration registry; they must not cross-pollinate."""
|
|
||||||
monkeypatch.setattr(migrations, "CURRENT_VERSION", 2)
|
|
||||||
# Register video migration but NOT benchmark migration.
|
|
||||||
monkeypatch.setattr(
|
|
||||||
migrations,
|
|
||||||
"_VIDEO_PREDICTIONS_MIGRATIONS",
|
|
||||||
{1: lambda p: {**p, "schema_version": 2, "from_video_registry": True}},
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(migrations, "_BENCHMARK_RESULT_MIGRATIONS", {})
|
|
||||||
# Video migration works:
|
|
||||||
assert migrate_video_predictions({"schema_version": 1})["from_video_registry"] is True
|
|
||||||
# Benchmark migration should fail — no entry in its registry.
|
|
||||||
with pytest.raises(MigrationNotFoundError):
|
|
||||||
migrate_benchmark_result({"schema_version": 1})
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# migrate_job_results — per-entry dispatch
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMigrateJobResults:
|
|
||||||
def test_empty_dict_is_noop(self) -> None:
|
|
||||||
assert migrate_job_results({}) == {}
|
|
||||||
|
|
||||||
def test_each_video_is_migrated(self, fake_two_version_chain: None) -> None:
|
|
||||||
del fake_two_version_chain
|
|
||||||
payload = {
|
|
||||||
"video_a.mp4": {"schema_version": 1, "content_a": True},
|
|
||||||
"video_b.mp4": {"schema_version": 1, "content_b": True},
|
|
||||||
}
|
|
||||||
result = migrate_job_results(payload)
|
|
||||||
assert result["video_a.mp4"]["schema_version"] == 2
|
|
||||||
assert result["video_a.mp4"]["content_a"] is True
|
|
||||||
assert result["video_a.mp4"]["added_in_v2"] == "hello"
|
|
||||||
assert result["video_b.mp4"]["schema_version"] == 2
|
|
||||||
assert result["video_b.mp4"]["content_b"] is True
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# migrate_analysis_report
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMigrateAnalysisReport:
|
|
||||||
def test_current_version_is_noop(self) -> None:
|
|
||||||
"""A payload already at CURRENT_VERSION passes through unchanged."""
|
|
||||||
payload = {"schema_version": CURRENT_VERSION, "foo": "bar"}
|
|
||||||
assert migrate_analysis_report(payload) == payload
|
|
||||||
|
|
||||||
def test_missing_version_defaults_to_v1_and_fails(self) -> None:
|
|
||||||
"""AnalysisReport first shipped at CURRENT_VERSION, so a payload
|
|
||||||
without schema_version (defaulting to 1) would require a
|
|
||||||
non-existent v1→v2 migration and fail with a clear error."""
|
|
||||||
with pytest.raises(MigrationNotFoundError, match="AnalysisReport"):
|
|
||||||
migrate_analysis_report({})
|
|
||||||
|
|
||||||
def test_future_version_rejected(self) -> None:
|
|
||||||
with pytest.raises(FutureSchemaError, match="AnalysisReport"):
|
|
||||||
migrate_analysis_report({"schema_version": CURRENT_VERSION + 5})
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# register_video_predictions_migration / register_benchmark_result_migration
|
|
||||||
# / register_analysis_report_migration
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegistration:
|
|
||||||
def test_duplicate_registration_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {})
|
|
||||||
|
|
||||||
@migrations.register_video_predictions_migration(from_version=1)
|
|
||||||
def _first(p: dict) -> dict:
|
|
||||||
return p
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="already registered"):
|
|
||||||
|
|
||||||
@migrations.register_video_predictions_migration(from_version=1)
|
|
||||||
def _second(p: dict) -> dict:
|
|
||||||
return p
|
|
||||||
|
|
||||||
def test_decorator_returns_callable_unchanged(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""The decorator must not wrap or rename the function."""
|
|
||||||
monkeypatch.setattr(migrations, "_VIDEO_PREDICTIONS_MIGRATIONS", {})
|
|
||||||
|
|
||||||
@migrations.register_video_predictions_migration(from_version=1)
|
|
||||||
def _fn(p: dict) -> dict:
|
|
||||||
return p
|
|
||||||
|
|
||||||
assert _fn.__name__ == "_fn"
|
|
||||||
assert _fn({"x": 1}) == {"x": 1}
|
|
||||||
|
|
||||||
def test_analysis_report_duplicate_registration_raises(
|
|
||||||
self, monkeypatch: pytest.MonkeyPatch
|
|
||||||
) -> None:
|
|
||||||
monkeypatch.setattr(migrations, "_ANALYSIS_REPORT_MIGRATIONS", {})
|
|
||||||
|
|
||||||
@migrations.register_analysis_report_migration(from_version=2)
|
|
||||||
def _first(p: dict) -> dict:
|
|
||||||
return p
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="already registered"):
|
|
||||||
|
|
||||||
@migrations.register_analysis_report_migration(from_version=2)
|
|
||||||
def _second(p: dict) -> dict:
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Integration: load_* functions run migrations before validation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadIntegration:
|
|
||||||
def test_load_video_predictions_accepts_legacy_payload(self, tmp_path: Path) -> None:
|
|
||||||
"""A VideoPredictions JSON written before schema_version existed loads cleanly."""
|
|
||||||
legacy = _minimal_video_predictions_payload()
|
|
||||||
del legacy["schema_version"] # Pretend this file predates versioning.
|
|
||||||
path = tmp_path / "legacy.json"
|
|
||||||
path.write_text(json.dumps(legacy))
|
|
||||||
|
|
||||||
loaded = load_video_predictions(path)
|
|
||||||
assert loaded.schema_version == CURRENT_VERSION
|
|
||||||
|
|
||||||
def test_load_video_predictions_rejects_future_version(self, tmp_path: Path) -> None:
|
|
||||||
payload = _minimal_video_predictions_payload()
|
|
||||||
payload["schema_version"] = CURRENT_VERSION + 42
|
|
||||||
path = tmp_path / "future.json"
|
|
||||||
path.write_text(json.dumps(payload))
|
|
||||||
|
|
||||||
with pytest.raises(FutureSchemaError):
|
|
||||||
load_video_predictions(path)
|
|
||||||
|
|
||||||
def test_save_then_load_roundtrips(self, tmp_path: Path) -> None:
|
|
||||||
obj = _minimal_video_predictions_object()
|
|
||||||
path = tmp_path / "out.json"
|
|
||||||
save_video_predictions(path, obj)
|
|
||||||
loaded = load_video_predictions(path)
|
|
||||||
assert loaded == obj
|
|
||||||
assert loaded.schema_version == CURRENT_VERSION
|
|
||||||
|
|
||||||
def test_load_job_results_migrates_each_video(self, tmp_path: Path) -> None:
|
|
||||||
video_a = _minimal_video_predictions_payload()
|
|
||||||
video_b = _minimal_video_predictions_payload()
|
|
||||||
# Strip schema_version from both to simulate legacy file.
|
|
||||||
del video_a["schema_version"]
|
|
||||||
del video_b["schema_version"]
|
|
||||||
payload = {"a.mp4": video_a, "b.mp4": video_b}
|
|
||||||
path = tmp_path / "job.json"
|
|
||||||
path.write_text(json.dumps(payload))
|
|
||||||
|
|
||||||
loaded = load_job_results(path)
|
|
||||||
assert len(loaded) == 2
|
|
||||||
for video in ("a.mp4", "b.mp4"):
|
|
||||||
assert loaded[video].schema_version == CURRENT_VERSION
|
|
||||||
|
|
||||||
def test_save_then_load_job_results_roundtrips(self, tmp_path: Path) -> None:
|
|
||||||
obj = JobResults(root={"video_a.mp4": _minimal_video_predictions_object()})
|
|
||||||
path = tmp_path / "job.json"
|
|
||||||
save_job_results(path, obj)
|
|
||||||
loaded = load_job_results(path)
|
|
||||||
assert loaded == obj
|
|
||||||
|
|
||||||
def test_load_benchmark_result_roundtrips(self, tmp_path: Path) -> None:
|
|
||||||
"""Save → load round-trip for a realistic benchmark result."""
|
|
||||||
from neuropose.io import BenchmarkAggregate, PerformanceMetrics
|
|
||||||
|
|
||||||
metrics = PerformanceMetrics(
|
|
||||||
total_seconds=1.0,
|
|
||||||
per_frame_latencies_ms=[10.0, 11.0],
|
|
||||||
peak_rss_mb=100.0,
|
|
||||||
active_device="/CPU:0",
|
|
||||||
tensorflow_version="2.18.0",
|
|
||||||
)
|
|
||||||
aggregate = BenchmarkAggregate(
|
|
||||||
repeats_measured=1,
|
|
||||||
warmup_frames_per_pass=0,
|
|
||||||
mean_frame_latency_ms=10.5,
|
|
||||||
p50_frame_latency_ms=10.5,
|
|
||||||
p95_frame_latency_ms=11.0,
|
|
||||||
p99_frame_latency_ms=11.0,
|
|
||||||
stddev_frame_latency_ms=0.5,
|
|
||||||
mean_throughput_fps=95.0,
|
|
||||||
peak_rss_mb_max=100.0,
|
|
||||||
active_device="/CPU:0",
|
|
||||||
tensorflow_version="2.18.0",
|
|
||||||
)
|
|
||||||
result = BenchmarkResult(
|
|
||||||
video_name="test.mp4",
|
|
||||||
repeats=2,
|
|
||||||
warmup_frames=0,
|
|
||||||
warmup_pass=metrics,
|
|
||||||
measured_passes=[metrics],
|
|
||||||
aggregate=aggregate,
|
|
||||||
)
|
|
||||||
path = tmp_path / "bench.json"
|
|
||||||
save_benchmark_result(path, result)
|
|
||||||
loaded = load_benchmark_result(path)
|
|
||||||
assert loaded == result
|
|
||||||
assert loaded.schema_version == CURRENT_VERSION
|
|
||||||
|
|
||||||
def test_load_benchmark_result_rejects_future_version(self, tmp_path: Path) -> None:
|
|
||||||
"""Future-versioned benchmark file should raise with a clear message."""
|
|
||||||
from neuropose.io import BenchmarkAggregate, PerformanceMetrics
|
|
||||||
|
|
||||||
metrics = PerformanceMetrics(
|
|
||||||
total_seconds=1.0,
|
|
||||||
per_frame_latencies_ms=[10.0],
|
|
||||||
peak_rss_mb=100.0,
|
|
||||||
active_device="/CPU:0",
|
|
||||||
tensorflow_version="2.18.0",
|
|
||||||
)
|
|
||||||
aggregate = BenchmarkAggregate(
|
|
||||||
repeats_measured=1,
|
|
||||||
warmup_frames_per_pass=0,
|
|
||||||
mean_frame_latency_ms=10.0,
|
|
||||||
p50_frame_latency_ms=10.0,
|
|
||||||
p95_frame_latency_ms=10.0,
|
|
||||||
p99_frame_latency_ms=10.0,
|
|
||||||
stddev_frame_latency_ms=0.0,
|
|
||||||
mean_throughput_fps=100.0,
|
|
||||||
peak_rss_mb_max=100.0,
|
|
||||||
active_device="/CPU:0",
|
|
||||||
tensorflow_version="2.18.0",
|
|
||||||
)
|
|
||||||
result = BenchmarkResult(
|
|
||||||
video_name="x.mp4",
|
|
||||||
repeats=1,
|
|
||||||
warmup_frames=0,
|
|
||||||
warmup_pass=metrics,
|
|
||||||
measured_passes=[metrics],
|
|
||||||
aggregate=aggregate,
|
|
||||||
)
|
|
||||||
# Serialize then hand-edit to inject a future version.
|
|
||||||
payload = result.model_dump(mode="json")
|
|
||||||
payload["schema_version"] = CURRENT_VERSION + 1
|
|
||||||
path = tmp_path / "bench_future.json"
|
|
||||||
path.write_text(json.dumps(payload))
|
|
||||||
|
|
||||||
with pytest.raises(FutureSchemaError):
|
|
||||||
load_benchmark_result(path)
|
|
||||||
|
|
@ -1,451 +0,0 @@
|
||||||
"""Tests for :mod:`neuropose.reset` and the ``neuropose reset`` CLI command.
|
|
||||||
|
|
||||||
Coverage:
|
|
||||||
|
|
||||||
- :func:`find_neuropose_processes` filters by cmdline marker, classifies
|
|
||||||
daemon vs monitor, and excludes the calling process.
|
|
||||||
- :func:`terminate_processes` sends SIGINT, escalates to SIGKILL when
|
|
||||||
asked, and reports survivors.
|
|
||||||
- :func:`wipe_state` removes contents of in/, out/, failed/, the lock
|
|
||||||
file, and ``.ingest_*`` staging dirs; honors ``keep_failed`` and
|
|
||||||
``dry_run``.
|
|
||||||
- :func:`reset_pipeline` skips the wipe phase when termination leaves
|
|
||||||
survivors.
|
|
||||||
- The ``neuropose reset`` CLI command renders previews, honors
|
|
||||||
``--dry-run`` and ``--yes``, and exits non-zero on survivors.
|
|
||||||
|
|
||||||
The process-killing tests use monkeypatched ``psutil.process_iter``
|
|
||||||
and ``os.kill`` so the suite never touches the real process table or
|
|
||||||
sends real signals.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import signal
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from typer.testing import CliRunner
|
|
||||||
|
|
||||||
from neuropose.cli import EXIT_OK, EXIT_USAGE, app
|
|
||||||
from neuropose.config import Settings
|
|
||||||
from neuropose.interfacer import LOCK_FILENAME
|
|
||||||
from neuropose.reset import (
|
|
||||||
DEFAULT_GRACE_SECONDS,
|
|
||||||
RunningProcess,
|
|
||||||
TerminationReport,
|
|
||||||
WipeReport,
|
|
||||||
find_neuropose_processes,
|
|
||||||
reset_pipeline,
|
|
||||||
terminate_processes,
|
|
||||||
wipe_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeProc:
|
|
||||||
"""Minimal stand-in for ``psutil.Process`` for the discovery tests."""
|
|
||||||
|
|
||||||
def __init__(self, pid: int, cmdline: list[str]) -> None:
|
|
||||||
self.info = {"pid": pid, "cmdline": cmdline}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def settings(tmp_path: Path) -> Settings:
|
|
||||||
"""A Settings pointing at an isolated temp data dir, with subdirs created."""
|
|
||||||
s = Settings(data_dir=tmp_path / "jobs", model_cache_dir=tmp_path / "models")
|
|
||||||
s.ensure_dirs()
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# find_neuropose_processes
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestFindNeuroposeProcesses:
|
|
||||||
def test_classifies_watch_as_daemon(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
procs = [_FakeProc(1234, ["python", "-m", "neuropose", "watch"])]
|
|
||||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
|
||||||
found = find_neuropose_processes()
|
|
||||||
assert len(found) == 1
|
|
||||||
assert found[0].pid == 1234
|
|
||||||
assert found[0].role == "daemon"
|
|
||||||
|
|
||||||
def test_classifies_serve_as_monitor(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
procs = [_FakeProc(5678, ["uv", "run", "neuropose", "serve", "--port", "8765"])]
|
|
||||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
|
||||||
found = find_neuropose_processes()
|
|
||||||
assert len(found) == 1
|
|
||||||
assert found[0].role == "monitor"
|
|
||||||
|
|
||||||
def test_ignores_unrelated_processes(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
procs = [
|
|
||||||
_FakeProc(1, ["bash"]),
|
|
||||||
_FakeProc(2, ["python", "-m", "pip", "install", "neuropose"]),
|
|
||||||
_FakeProc(3, ["neuropose", "--help"]),
|
|
||||||
]
|
|
||||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
|
||||||
assert find_neuropose_processes() == []
|
|
||||||
|
|
||||||
def test_excludes_self(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
import os
|
|
||||||
|
|
||||||
self_pid = os.getpid()
|
|
||||||
procs = [
|
|
||||||
_FakeProc(self_pid, ["python", "-m", "neuropose", "watch"]),
|
|
||||||
_FakeProc(9999, ["python", "-m", "neuropose", "watch"]),
|
|
||||||
]
|
|
||||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
|
||||||
found = find_neuropose_processes()
|
|
||||||
assert [rp.pid for rp in found] == [9999]
|
|
||||||
|
|
||||||
def test_includes_self_when_disabled(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
import os
|
|
||||||
|
|
||||||
self_pid = os.getpid()
|
|
||||||
procs = [_FakeProc(self_pid, ["python", "-m", "neuropose", "watch"])]
|
|
||||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
|
||||||
found = find_neuropose_processes(exclude_self=False)
|
|
||||||
assert [rp.pid for rp in found] == [self_pid]
|
|
||||||
|
|
||||||
def test_handles_dead_processes(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
"""A NoSuchProcess raised mid-iteration must not crash the scan."""
|
|
||||||
import psutil
|
|
||||||
|
|
||||||
class _RaisingProc:
|
|
||||||
@property
|
|
||||||
def info(self) -> dict[str, Any]:
|
|
||||||
raise psutil.NoSuchProcess(pid=1)
|
|
||||||
|
|
||||||
procs = [
|
|
||||||
_RaisingProc(),
|
|
||||||
_FakeProc(2, ["python", "-m", "neuropose", "watch"]),
|
|
||||||
]
|
|
||||||
monkeypatch.setattr("psutil.process_iter", lambda attrs: iter(procs))
|
|
||||||
found = find_neuropose_processes()
|
|
||||||
assert [rp.pid for rp in found] == [2]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# terminate_processes
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestTerminateProcesses:
|
|
||||||
def test_empty_list_is_noop(self) -> None:
|
|
||||||
report = terminate_processes([])
|
|
||||||
assert report.stopped == []
|
|
||||||
assert report.survivors == []
|
|
||||||
assert report.force_killed == []
|
|
||||||
|
|
||||||
def test_sigint_to_each_process(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
sent: list[tuple[int, int]] = []
|
|
||||||
monkeypatch.setattr("os.kill", lambda pid, sig: sent.append((pid, sig)))
|
|
||||||
# Mark all processes immediately dead so the wait loop exits fast.
|
|
||||||
monkeypatch.setattr("neuropose.reset._is_alive", lambda pid: False)
|
|
||||||
|
|
||||||
rps = [
|
|
||||||
RunningProcess(pid=10, role="daemon", cmdline="x"),
|
|
||||||
RunningProcess(pid=20, role="monitor", cmdline="y"),
|
|
||||||
]
|
|
||||||
report = terminate_processes(rps, grace_seconds=0.0)
|
|
||||||
assert sent == [(10, signal.SIGINT), (20, signal.SIGINT)]
|
|
||||||
assert {p.pid for p in report.stopped} == {10, 20}
|
|
||||||
assert report.survivors == []
|
|
||||||
|
|
||||||
def test_survivors_reported_when_force_kill_off(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
monkeypatch.setattr("os.kill", lambda pid, sig: None)
|
|
||||||
# Process 10 always alive; process 20 dies after SIGINT.
|
|
||||||
alive = {10}
|
|
||||||
monkeypatch.setattr("neuropose.reset._is_alive", lambda pid: pid in alive)
|
|
||||||
|
|
||||||
rps = [
|
|
||||||
RunningProcess(pid=10, role="daemon", cmdline="x"),
|
|
||||||
RunningProcess(pid=20, role="monitor", cmdline="y"),
|
|
||||||
]
|
|
||||||
# Drop pid 20 so it appears dead immediately.
|
|
||||||
alive.discard(20)
|
|
||||||
report = terminate_processes(rps, grace_seconds=0.0, force_kill=False)
|
|
||||||
assert {p.pid for p in report.stopped} == {20}
|
|
||||||
assert {p.pid for p in report.survivors} == {10}
|
|
||||||
assert report.force_killed == []
|
|
||||||
|
|
||||||
def test_force_kill_escalates_to_sigkill(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
sent: list[tuple[int, int]] = []
|
|
||||||
monkeypatch.setattr("os.kill", lambda pid, sig: sent.append((pid, sig)))
|
|
||||||
|
|
||||||
# Start alive; SIGKILL "kills" by toggling the flag from inside _is_alive.
|
|
||||||
alive = {10}
|
|
||||||
|
|
||||||
def fake_is_alive(pid: int) -> bool:
|
|
||||||
if (pid, signal.SIGKILL) in sent:
|
|
||||||
return False
|
|
||||||
return pid in alive
|
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.reset._is_alive", fake_is_alive)
|
|
||||||
|
|
||||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
|
||||||
report = terminate_processes([rp], grace_seconds=0.0, force_kill=True)
|
|
||||||
assert (10, signal.SIGINT) in sent
|
|
||||||
assert (10, signal.SIGKILL) in sent
|
|
||||||
assert {p.pid for p in report.force_killed} == {10}
|
|
||||||
assert report.survivors == []
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# wipe_state
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestWipeState:
|
|
||||||
def test_no_op_on_empty_dirs(self, settings: Settings) -> None:
|
|
||||||
report = wipe_state(settings)
|
|
||||||
assert report.removed_paths == []
|
|
||||||
assert report.bytes_freed == 0
|
|
||||||
|
|
||||||
def test_removes_in_out_failed_contents(self, settings: Settings) -> None:
|
|
||||||
(settings.input_dir / "job_a").mkdir()
|
|
||||||
(settings.input_dir / "job_a" / "video.mp4").write_bytes(b"x" * 100)
|
|
||||||
(settings.output_dir / "status.json").write_text("{}")
|
|
||||||
(settings.failed_dir / "job_b").mkdir()
|
|
||||||
|
|
||||||
report = wipe_state(settings)
|
|
||||||
names = {p.name for p in report.removed_paths}
|
|
||||||
assert names == {"job_a", "status.json", "job_b"}
|
|
||||||
# Containers themselves preserved.
|
|
||||||
assert settings.input_dir.exists()
|
|
||||||
assert settings.output_dir.exists()
|
|
||||||
assert settings.failed_dir.exists()
|
|
||||||
|
|
||||||
def test_keep_failed_preserves_failed_contents(self, settings: Settings) -> None:
|
|
||||||
(settings.input_dir / "job_a").mkdir()
|
|
||||||
(settings.failed_dir / "job_b").mkdir()
|
|
||||||
(settings.failed_dir / "job_b" / "evidence.log").write_text("crash")
|
|
||||||
|
|
||||||
report = wipe_state(settings, keep_failed=True)
|
|
||||||
names = {p.name for p in report.removed_paths}
|
|
||||||
assert "job_a" in names
|
|
||||||
assert "job_b" not in names
|
|
||||||
assert (settings.failed_dir / "job_b" / "evidence.log").exists()
|
|
||||||
|
|
||||||
def test_removes_lock_file(self, settings: Settings) -> None:
|
|
||||||
(settings.data_dir / LOCK_FILENAME).write_text("12345\n")
|
|
||||||
report = wipe_state(settings)
|
|
||||||
assert (settings.data_dir / LOCK_FILENAME) in report.removed_paths
|
|
||||||
assert not (settings.data_dir / LOCK_FILENAME).exists()
|
|
||||||
|
|
||||||
def test_removes_ingest_staging_dirs(self, settings: Settings) -> None:
|
|
||||||
staging_a = settings.data_dir / ".ingest_abc123"
|
|
||||||
staging_b = settings.data_dir / ".ingest_def456"
|
|
||||||
staging_a.mkdir()
|
|
||||||
staging_b.mkdir()
|
|
||||||
(staging_a / "leftover.mp4").write_bytes(b"y" * 50)
|
|
||||||
|
|
||||||
report = wipe_state(settings)
|
|
||||||
assert staging_a in report.removed_paths
|
|
||||||
assert staging_b in report.removed_paths
|
|
||||||
assert not staging_a.exists()
|
|
||||||
assert not staging_b.exists()
|
|
||||||
|
|
||||||
def test_dry_run_reports_without_removing(self, settings: Settings) -> None:
|
|
||||||
(settings.input_dir / "job_a").mkdir()
|
|
||||||
(settings.input_dir / "job_a" / "video.mp4").write_bytes(b"z" * 200)
|
|
||||||
|
|
||||||
report = wipe_state(settings, dry_run=True)
|
|
||||||
assert len(report.removed_paths) == 1
|
|
||||||
assert report.bytes_freed == 200
|
|
||||||
# Nothing actually deleted.
|
|
||||||
assert (settings.input_dir / "job_a" / "video.mp4").exists()
|
|
||||||
|
|
||||||
def test_bytes_freed_recurses_into_subdirs(self, settings: Settings) -> None:
|
|
||||||
job = settings.input_dir / "job_a"
|
|
||||||
job.mkdir()
|
|
||||||
(job / "a.mp4").write_bytes(b"a" * 100)
|
|
||||||
(job / "nested").mkdir()
|
|
||||||
(job / "nested" / "b.mp4").write_bytes(b"b" * 250)
|
|
||||||
|
|
||||||
report = wipe_state(settings, dry_run=True)
|
|
||||||
assert report.bytes_freed == 350
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# reset_pipeline
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestResetPipeline:
|
|
||||||
def test_dry_run_skips_termination(
|
|
||||||
self,
|
|
||||||
settings: Settings,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
|
||||||
|
|
||||||
# Sentinel to detect termination calls.
|
|
||||||
def _should_not_be_called(*args: Any, **kwargs: Any) -> TerminationReport:
|
|
||||||
raise AssertionError("dry_run must not invoke terminate_processes")
|
|
||||||
|
|
||||||
monkeypatch.setattr("neuropose.reset.terminate_processes", _should_not_be_called)
|
|
||||||
|
|
||||||
report = reset_pipeline(settings, dry_run=True)
|
|
||||||
assert report.dry_run is True
|
|
||||||
assert report.discovered == [rp]
|
|
||||||
assert report.termination.stopped == []
|
|
||||||
|
|
||||||
def test_skips_wipe_when_survivors_remain(
|
|
||||||
self,
|
|
||||||
settings: Settings,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"neuropose.reset.terminate_processes",
|
|
||||||
lambda procs, **_: TerminationReport(survivors=list(procs)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Seed something to wipe so we can confirm it's untouched.
|
|
||||||
(settings.input_dir / "job_a").mkdir()
|
|
||||||
|
|
||||||
report = reset_pipeline(settings, dry_run=False)
|
|
||||||
assert report.wipe_skipped_due_to_survivors is True
|
|
||||||
assert report.wipe.removed_paths == []
|
|
||||||
assert (settings.input_dir / "job_a").exists()
|
|
||||||
|
|
||||||
def test_wipes_when_all_processes_stopped(
|
|
||||||
self,
|
|
||||||
settings: Settings,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
rp = RunningProcess(pid=10, role="daemon", cmdline="x")
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"neuropose.reset.terminate_processes",
|
|
||||||
lambda procs, **_: TerminationReport(stopped=list(procs)),
|
|
||||||
)
|
|
||||||
|
|
||||||
(settings.input_dir / "job_a").mkdir()
|
|
||||||
|
|
||||||
report = reset_pipeline(settings, dry_run=False)
|
|
||||||
assert report.wipe_skipped_due_to_survivors is False
|
|
||||||
assert any(p.name == "job_a" for p in report.wipe.removed_paths)
|
|
||||||
assert not (settings.input_dir / "job_a").exists()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# CLI: neuropose reset
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def runner() -> CliRunner:
|
|
||||||
return CliRunner()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def env_data_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
|
|
||||||
"""Point NEUROPOSE_DATA_DIR at an isolated temp dir for CLI tests."""
|
|
||||||
data_dir = tmp_path / "jobs"
|
|
||||||
data_dir.mkdir()
|
|
||||||
(data_dir / "in").mkdir()
|
|
||||||
(data_dir / "out").mkdir()
|
|
||||||
(data_dir / "failed").mkdir()
|
|
||||||
monkeypatch.setenv("NEUROPOSE_DATA_DIR", str(data_dir))
|
|
||||||
return data_dir
|
|
||||||
|
|
||||||
|
|
||||||
class TestResetCli:
|
|
||||||
def test_reset_dry_run_does_not_modify(
|
|
||||||
self,
|
|
||||||
runner: CliRunner,
|
|
||||||
env_data_dir: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
(env_data_dir / "in" / "job_a").mkdir()
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
|
||||||
|
|
||||||
result = runner.invoke(app, ["reset", "--dry-run"])
|
|
||||||
assert result.exit_code == EXIT_OK, result.output
|
|
||||||
assert "would remove" in result.output
|
|
||||||
assert "(dry-run; no changes made)" in result.output
|
|
||||||
assert (env_data_dir / "in" / "job_a").exists()
|
|
||||||
|
|
||||||
def test_reset_yes_skips_confirmation_and_wipes(
|
|
||||||
self,
|
|
||||||
runner: CliRunner,
|
|
||||||
env_data_dir: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
(env_data_dir / "in" / "job_a").mkdir()
|
|
||||||
(env_data_dir / "in" / "job_a" / "video.mp4").write_bytes(b"x" * 100)
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
|
||||||
|
|
||||||
result = runner.invoke(app, ["reset", "--yes"])
|
|
||||||
assert result.exit_code == EXIT_OK, result.output
|
|
||||||
assert "removed" in result.output
|
|
||||||
assert not (env_data_dir / "in" / "job_a").exists()
|
|
||||||
|
|
||||||
def test_reset_aborts_on_no_confirmation(
|
|
||||||
self,
|
|
||||||
runner: CliRunner,
|
|
||||||
env_data_dir: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
(env_data_dir / "in" / "job_a").mkdir()
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
|
||||||
|
|
||||||
# typer.confirm reads from stdin; "n\n" declines.
|
|
||||||
result = runner.invoke(app, ["reset"], input="n\n")
|
|
||||||
assert result.exit_code == EXIT_USAGE, result.output
|
|
||||||
assert "aborted" in result.output
|
|
||||||
assert (env_data_dir / "in" / "job_a").exists()
|
|
||||||
|
|
||||||
def test_reset_clean_state_is_noop(
|
|
||||||
self,
|
|
||||||
runner: CliRunner,
|
|
||||||
env_data_dir: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
del env_data_dir
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", list)
|
|
||||||
|
|
||||||
result = runner.invoke(app, ["reset"])
|
|
||||||
assert result.exit_code == EXIT_OK, result.output
|
|
||||||
assert "nothing to do" in result.output
|
|
||||||
|
|
||||||
def test_reset_reports_survivors_with_nonzero_exit(
|
|
||||||
self,
|
|
||||||
runner: CliRunner,
|
|
||||||
env_data_dir: Path,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
del env_data_dir
|
|
||||||
rp = RunningProcess(pid=42, role="daemon", cmdline="neuropose watch")
|
|
||||||
monkeypatch.setattr("neuropose.reset.find_neuropose_processes", lambda: [rp])
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"neuropose.reset.terminate_processes",
|
|
||||||
lambda procs, **_: TerminationReport(survivors=list(procs)),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = runner.invoke(app, ["reset", "--yes"])
|
|
||||||
assert result.exit_code == EXIT_USAGE, result.output
|
|
||||||
assert "did not exit" in result.output
|
|
||||||
assert "pid 42" in result.output
|
|
||||||
assert "--force-kill" in result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_grace_seconds_constant_is_reasonable() -> None:
|
|
||||||
"""Lock the default grace period so a refactor cannot silently lower it."""
|
|
||||||
assert 5.0 <= DEFAULT_GRACE_SECONDS <= 60.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_wipe_report_default_construction() -> None:
|
|
||||||
r = WipeReport()
|
|
||||||
assert r.removed_paths == []
|
|
||||||
assert r.bytes_freed == 0
|
|
||||||
Loading…
Reference in New Issue