add run_analysis pipeline executor
run_analysis(config) loads the predictions files named in the config, applies the configured segmentation, dispatches to the selected analysis kind (DTW, stats, or none), and emits a fully populated AnalysisReport. The report's Provenance inherits the inference-time envelope from the primary input with analysis_config stamped in, so the output is self-describing even if the source YAML is later lost. For DTW runs with segmentation, segments are paired one-to-one by index across primary and reference, truncating to min of the two counts. Bilateral segmentations emit per-side distances under "left_heel_strikes[i]" / "right_heel_strikes[i]" labels. dtw_per_joint stores its full per-unit breakdown in the per_joint_distances field and reports the sum as the representative scalar distance. Also ships load_config (YAML), save_report (atomic JSON write), and load_report (rehydrate via the migration chain) so the executor can be driven end-to-end from Python without the CLI. The CLI wiring lands in the next commit.
This commit is contained in:
parent
979beb1078
commit
dc48988450
17
CHANGELOG.md
17
CHANGELOG.md
|
|
@ -279,8 +279,21 @@ be split into per-release sections once tagging begins.
|
|||
`CURRENT_VERSION = 2`, with a new
|
||||
`register_analysis_report_migration` decorator and
|
||||
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
||||
for future schema changes. Pipeline execution lands in a
|
||||
follow-up commit.
|
||||
for future schema changes. `run_analysis(config)` loads the named
|
||||
predictions files, applies the configured segmentation, dispatches
|
||||
to the selected analysis kind (DTW, stats, or none), and emits a
|
||||
fully populated `AnalysisReport` whose `Provenance` inherits the
|
||||
inference-time envelope from the primary input with
|
||||
`analysis_config` stamped in, so the report is self-describing
|
||||
even if the source YAML is lost. For DTW runs with segmentation,
|
||||
segments are paired one-to-one by index across primary and
|
||||
reference, truncating to `min(len_primary, len_reference)`;
|
||||
bilateral segmentations emit per-side distances under
|
||||
`"left_heel_strikes[i]"` / `"right_heel_strikes[i]"` labels.
|
||||
`load_config(path)` parses YAML, `save_report(path, report)`
|
||||
writes atomically, and `load_report(path)` rehydrates via the
|
||||
migration chain. CLI wiring and example configs land in
|
||||
follow-up commits.
|
||||
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
||||
**`segment_gait_cycles_bilateral`** — clinical convenience
|
||||
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
||||
|
|
|
|||
|
|
@ -69,6 +69,10 @@ from neuropose.analyzer.pipeline import (
|
|||
StatsAnalysis,
|
||||
StatsResults,
|
||||
analysis_config_to_dict,
|
||||
load_config,
|
||||
load_report,
|
||||
run_analysis,
|
||||
save_report,
|
||||
)
|
||||
from neuropose.analyzer.segment import (
|
||||
JOINT_INDEX,
|
||||
|
|
@ -130,10 +134,14 @@ __all__ = [
|
|||
"joint_index",
|
||||
"joint_pair_distance",
|
||||
"joint_speed",
|
||||
"load_config",
|
||||
"load_report",
|
||||
"normalize_pose_sequence",
|
||||
"pad_sequences",
|
||||
"predictions_to_numpy",
|
||||
"procrustes_align",
|
||||
"run_analysis",
|
||||
"save_report",
|
||||
"segment_by_peaks",
|
||||
"segment_gait_cycles",
|
||||
"segment_gait_cycles_bilateral",
|
||||
|
|
|
|||
|
|
@ -24,19 +24,56 @@ config additionally parses from YAML via :func:`load_config`. Cross-field
|
|||
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
||||
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
||||
fast rather than after an expensive multi-minute load.
|
||||
|
||||
Execution
|
||||
---------
|
||||
:func:`run_analysis` is the top-level executor: it loads the
|
||||
predictions files named in the config, applies any configured
|
||||
segmentation stage, dispatches to the configured analysis stage, and
|
||||
returns a fully populated :class:`AnalysisReport`. The executor is
|
||||
intended to be called from the ``neuropose analyze`` CLI but is
|
||||
equally valid as a Python-level entry point for notebook-driven
|
||||
exploration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from neuropose.analyzer.dtw import AlignMode, NanPolicy, Representation
|
||||
from neuropose.analyzer.segment import AxisLetter
|
||||
from neuropose.io import ExtractorSpec, Provenance, Segmentation
|
||||
from neuropose.migrations import CURRENT_VERSION
|
||||
from neuropose.analyzer.dtw import (
|
||||
AlignMode,
|
||||
DTWResult,
|
||||
NanPolicy,
|
||||
Representation,
|
||||
dtw_all,
|
||||
dtw_per_joint,
|
||||
dtw_relation,
|
||||
)
|
||||
from neuropose.analyzer.features import (
|
||||
extract_feature_statistics,
|
||||
predictions_to_numpy,
|
||||
)
|
||||
from neuropose.analyzer.segment import (
|
||||
AxisLetter,
|
||||
extract_signal,
|
||||
segment_gait_cycles,
|
||||
segment_gait_cycles_bilateral,
|
||||
segment_predictions,
|
||||
)
|
||||
from neuropose.io import (
|
||||
ExtractorSpec,
|
||||
Provenance,
|
||||
Segmentation,
|
||||
VideoPredictions,
|
||||
load_video_predictions,
|
||||
)
|
||||
from neuropose.migrations import CURRENT_VERSION, migrate_analysis_report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inputs
|
||||
|
|
@ -476,3 +513,420 @@ def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
|
|||
:class:`AnalysisReport`'s provenance envelope.
|
||||
"""
|
||||
return config.model_dump(mode="json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load / save
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_config(path: Path) -> AnalysisConfig:
    """Read a YAML file and validate it into an :class:`AnalysisConfig`.

    Parameters
    ----------
    path
        Filesystem path to a YAML file conforming to the
        :class:`AnalysisConfig` schema.

    Returns
    -------
    AnalysisConfig
        The fully validated config. Cross-field invariants have
        already been checked.

    Raises
    ------
    pydantic.ValidationError
        On any schema violation — unknown keys, wrong types, or
        failed cross-field invariants.
    yaml.YAMLError
        On malformed YAML.
    """
    parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    # An empty YAML document parses to None; treat it as an empty mapping
    # so defaults-only configs are valid.
    return AnalysisConfig.model_validate({} if parsed is None else parsed)
|
||||
|
||||
|
||||
def save_report(path: Path, report: AnalysisReport) -> None:
    """Serialise an :class:`AnalysisReport` to ``path`` atomically.

    Writes to a sibling ``<path>.tmp`` first, then renames over
    ``path`` so a crash mid-write cannot leave behind a truncated
    file. The parent directory is created if it does not exist.

    If serialisation or the write itself fails, the partial ``.tmp``
    sibling is removed before the exception propagates, so repeated
    failed runs do not accumulate stale temp files next to the report.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    payload = report.model_dump(mode="json")
    try:
        with tmp.open("w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        # Atomic on POSIX: readers see either the old report or the new one.
        tmp.replace(path)
    except BaseException:
        # Best-effort cleanup of the half-written sibling; re-raise the
        # original error untouched.
        tmp.unlink(missing_ok=True)
        raise
|
||||
|
||||
|
||||
def load_report(path: Path) -> AnalysisReport:
    """Load and validate an :class:`AnalysisReport` JSON file.

    The raw payload is passed through
    :func:`~neuropose.migrations.migrate_analysis_report` before
    pydantic validation, so legacy reports written under older schema
    versions are upgraded transparently.
    """
    payload: Any = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        payload = migrate_analysis_report(payload)
    return AnalysisReport.model_validate(payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_analysis(config: AnalysisConfig) -> AnalysisReport:
    """Execute the pipeline described by ``config`` end-to-end.

    Loads the predictions files named in ``config.inputs``, applies
    the configured preprocessing + segmentation + analysis stages,
    and returns an :class:`AnalysisReport` whose
    :attr:`~AnalysisReport.provenance` inherits the inference-time
    provenance of the primary input with
    :attr:`~neuropose.io.Provenance.analysis_config` populated so the
    report is self-describing even if the YAML config is later lost.

    Parameters
    ----------
    config
        The pre-validated pipeline configuration.

    Returns
    -------
    AnalysisReport
        Fully populated report. Not yet written to disk — the caller
        passes it to :func:`save_report` (or inspects it directly).

    Notes
    -----
    For DTW runs with a segmentation stage, segments are paired
    one-to-one by index across primary and reference, truncating to
    ``min(len_primary, len_reference)``. Bilateral segmentations
    produce distances for each side independently, labelled under
    their segmentation key (e.g. ``"left_heel_strikes[3]"``).
    """
    # Stage 1: load inputs. The reference is optional; stats/none runs
    # only need the primary.
    primary_preds = load_video_predictions(config.inputs.primary)
    reference_preds: VideoPredictions | None = None
    if config.inputs.reference is not None:
        reference_preds = load_video_predictions(config.inputs.reference)

    person_index = config.preprocessing.person_index

    # Stage 2: flatten each input into a numeric sequence for analysis.
    primary_seq = predictions_to_numpy(primary_preds, person_index=person_index)
    reference_seq: np.ndarray | None = None
    if reference_preds is not None:
        reference_seq = predictions_to_numpy(reference_preds, person_index=person_index)

    # Stage 3: segmentation. The same stage config is applied to both
    # inputs so _run_dtw can pair segments by key + index later.
    primary_segmentations: dict[str, Segmentation] = {}
    reference_segmentations: dict[str, Segmentation] = {}
    if config.segmentation is not None:
        primary_segmentations = _run_segmentation(primary_preds, config.segmentation, person_index)
        if reference_preds is not None:
            reference_segmentations = _run_segmentation(
                reference_preds, config.segmentation, person_index
            )

    # Stage 4: dispatch to the configured analysis kind.
    results = _run_analysis_stage(
        config.analysis,
        primary_seq=primary_seq,
        reference_seq=reference_seq,
        primary_segmentations=primary_segmentations,
        reference_segmentations=reference_segmentations,
    )

    # Stamp the (JSON-safe) config into the primary input's provenance
    # envelope. If the primary carries no provenance, the report's
    # provenance stays None rather than being fabricated.
    analysis_config_dump = analysis_config_to_dict(config)
    report_provenance: Provenance | None = None
    if primary_preds.provenance is not None:
        report_provenance = primary_preds.provenance.model_copy(
            update={"analysis_config": analysis_config_dump}
        )

    primary_summary = InputSummary(
        path=config.inputs.primary,
        frame_count=primary_preds.metadata.frame_count,
        fps=primary_preds.metadata.fps,
        provenance=primary_preds.provenance,
    )
    reference_summary: InputSummary | None = None
    if reference_preds is not None and config.inputs.reference is not None:
        reference_summary = InputSummary(
            path=config.inputs.reference,
            frame_count=reference_preds.metadata.frame_count,
            fps=reference_preds.metadata.fps,
            provenance=reference_preds.provenance,
        )

    # Note: only the primary's segmentations are persisted on the report;
    # the reference's were consumed transiently by the DTW pairing above.
    return AnalysisReport(
        config=config,
        provenance=report_provenance,
        primary=primary_summary,
        reference=reference_summary,
        segmentations=primary_segmentations,
        results=results,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal dispatch helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_segmentation(
    predictions: VideoPredictions,
    stage: SegmentationStage,  # type: ignore[valid-type]
    person_index: int,
) -> dict[str, Segmentation]:
    """Run one segmentation stage over ``predictions``.

    The result dict is keyed by a label derived from the stage:
    single-side gait cycles yield ``"<joint>_cycles"``, bilateral gait
    cycles yield ``"left_heel_strikes"`` / ``"right_heel_strikes"``,
    and extractor segmentation uses the caller-supplied
    :attr:`~ExtractorSegmentation.label`.
    """
    if isinstance(stage, GaitCyclesSegmentation):
        cycles = segment_gait_cycles(
            predictions,
            joint=stage.joint,
            axis=stage.axis,
            invert=stage.invert,
            min_cycle_seconds=stage.min_cycle_seconds,
            min_prominence=stage.min_prominence,
        )
        return {f"{stage.joint}_cycles": cycles}

    if isinstance(stage, GaitCyclesBilateralSegmentation):
        # The bilateral helper already returns a per-side dict.
        return segment_gait_cycles_bilateral(
            predictions,
            axis=stage.axis,
            invert=stage.invert,
            min_cycle_seconds=stage.min_cycle_seconds,
            min_prominence=stage.min_prominence,
        )

    if isinstance(stage, ExtractorSegmentation):
        # A stage-level person_index overrides the preprocessing-level one.
        chosen_index = person_index if stage.person_index is None else stage.person_index
        extracted = segment_predictions(
            predictions,
            stage.extractor,
            person_index=chosen_index,
            min_distance_seconds=stage.min_distance_seconds,
            min_prominence=stage.min_prominence,
            min_height=stage.min_height,
            pad_seconds=stage.pad_seconds,
        )
        return {stage.label: extracted}

    raise TypeError(f"unknown segmentation stage: {type(stage).__name__}")
|
||||
|
||||
|
||||
def _run_analysis_stage(
    stage: AnalysisStage,  # type: ignore[valid-type]
    *,
    primary_seq: np.ndarray,
    reference_seq: np.ndarray | None,
    primary_segmentations: dict[str, Segmentation],
    reference_segmentations: dict[str, Segmentation],
) -> AnalysisResults:  # type: ignore[valid-type]
    """Route ``stage`` to the matching executor based on its concrete type."""
    if isinstance(stage, DtwAnalysis):
        if reference_seq is None:
            # The config-level cross-stage validator should already rule
            # this out; re-check so a direct programmatic call fails
            # loudly here rather than deep inside the DTW code.
            raise ValueError("DtwAnalysis requires a reference sequence")
        return _run_dtw(
            stage,
            primary_seq=primary_seq,
            reference_seq=reference_seq,
            primary_segmentations=primary_segmentations,
            reference_segmentations=reference_segmentations,
        )

    if isinstance(stage, StatsAnalysis):
        # Stats runs are single-input: the reference (if any) is ignored.
        return _run_stats(
            stage,
            primary_seq=primary_seq,
            primary_segmentations=primary_segmentations,
        )

    if isinstance(stage, NoAnalysis):
        return NoResults(kind="none")

    raise TypeError(f"unknown analysis stage: {type(stage).__name__}")
|
||||
|
||||
|
||||
def _run_dtw(
    stage: DtwAnalysis,
    *,
    primary_seq: np.ndarray,
    reference_seq: np.ndarray,
    primary_segmentations: dict[str, Segmentation],
    reference_segmentations: dict[str, Segmentation],
) -> DtwResults:
    """Execute a DTW analysis stage, returning :class:`DtwResults`.

    With segmentation, segments are paired one-to-one by index per
    segmentation key, truncating to the shorter side; without
    segmentation the whole trials are compared as one ``"full_trial"``
    pair. ``per_joint_distances`` is populated only for the
    ``"dtw_per_joint"`` method.
    """
    import warnings

    labels: list[str] = []
    distances: list[float] = []
    paths: list[list[tuple[int, int]]] = []
    per_joint_distances: list[list[float]] | None = [] if stage.method == "dtw_per_joint" else None

    # Build the (label, primary_slice, reference_slice) work list first
    # so the per-method loop below stays uniform.
    pairs: list[tuple[str, np.ndarray, np.ndarray]] = []
    if primary_segmentations:
        for key, primary_seg in primary_segmentations.items():
            reference_seg = reference_segmentations.get(key)
            if reference_seg is None:
                # Same config was applied to both, so this should not
                # happen unless the segmentation depends on the input
                # length in some unexpected way. Warn and skip rather
                # than crash the whole run. (Previously this skipped
                # silently, contradicting the stated intent.)
                warnings.warn(
                    f"segmentation key {key!r} present on primary but missing on "
                    "reference; skipping its segments",
                    stacklevel=2,
                )
                continue
            pair_count = min(len(primary_seg.segments), len(reference_seg.segments))
            for i in range(pair_count):
                p_seg = primary_seg.segments[i]
                r_seg = reference_seg.segments[i]
                pairs.append(
                    (
                        f"{key}[{i}]",
                        primary_seq[p_seg.start : p_seg.end],
                        reference_seq[r_seg.start : r_seg.end],
                    )
                )
    else:
        pairs.append(("full_trial", primary_seq, reference_seq))

    for label, primary_slice, reference_slice in pairs:
        labels.append(label)
        if stage.method == "dtw_all":
            result = dtw_all(
                primary_slice,
                reference_slice,
                align=stage.align,
                representation=stage.representation,
                angle_triplets=stage.angle_triplets,
                nan_policy=stage.nan_policy,
            )
            distances.append(result.distance)
            paths.append(result.path)
        elif stage.method == "dtw_per_joint":
            assert per_joint_distances is not None
            per_joint_results = dtw_per_joint(
                primary_slice,
                reference_slice,
                align=stage.align,
                representation=stage.representation,
                angle_triplets=stage.angle_triplets,
                nan_policy=stage.nan_policy,
            )
            # "distance" for a per-joint run is the sum across units;
            # "per_joint_distances" carries the full breakdown.
            per_unit = [r.distance for r in per_joint_results]
            distances.append(float(sum(per_unit)))
            per_joint_distances.append(per_unit)
            # Store just the first joint's path as a representative —
            # per-joint paths are a list of equal length, but
            # reporting all of them on disk is almost always overkill.
            paths.append(per_joint_results[0].path if per_joint_results else [])
        else:  # "dtw_relation"
            # Parse-time validation guarantees both joints for this method.
            assert stage.joint_i is not None
            assert stage.joint_j is not None
            result = _invoke_dtw_relation(
                primary_slice,
                reference_slice,
                joint_i=stage.joint_i,
                joint_j=stage.joint_j,
                align=stage.align,
                nan_policy=stage.nan_policy,
            )
            distances.append(result.distance)
            paths.append(result.path)

    return DtwResults(
        kind="dtw",
        method=stage.method,
        distances=distances,
        paths=paths,
        per_joint_distances=per_joint_distances,
        segment_labels=labels,
        summary=_summarize_distances(distances),
    )
|
||||
|
||||
|
||||
def _invoke_dtw_relation(
    primary_slice: np.ndarray,
    reference_slice: np.ndarray,
    *,
    joint_i: int,
    joint_j: int,
    align: AlignMode,
    nan_policy: NanPolicy,
) -> DTWResult:
    """Thin seam around :func:`dtw_relation` so tests can fake one call site."""
    return dtw_relation(
        primary_slice, reference_slice, joint_i, joint_j, align=align, nan_policy=nan_policy
    )
|
||||
|
||||
|
||||
def _run_stats(
    stage: StatsAnalysis,
    *,
    primary_seq: np.ndarray,
    primary_segmentations: dict[str, Segmentation],
) -> StatsResults:
    """Run a stats stage, summarising the extractor signal per segment."""
    if not primary_segmentations:
        # No segmentation configured: summarise the whole trial once.
        whole = extract_signal(primary_seq, stage.extractor)
        return StatsResults(
            kind="stats",
            statistics=[_feature_summary(whole)],
            segment_labels=["full_trial"],
        )

    labels: list[str] = []
    summaries: list[FeatureSummary] = []
    for key, seg in primary_segmentations.items():
        for idx, segment in enumerate(seg.segments):
            labels.append(f"{key}[{idx}]")
            window = primary_seq[segment.start : segment.end]
            summaries.append(_feature_summary(extract_signal(window, stage.extractor)))

    return StatsResults(kind="stats", statistics=summaries, segment_labels=labels)
|
||||
|
||||
|
||||
def _feature_summary(signal: np.ndarray) -> FeatureSummary:
    """Lift :func:`extract_feature_statistics` output into a :class:`FeatureSummary`."""
    stats = extract_feature_statistics(signal)
    return FeatureSummary(
        mean=stats.mean, std=stats.std, min=stats.min, max=stats.max, range=stats.range
    )
|
||||
|
||||
|
||||
def _summarize_distances(distances: list[float]) -> dict[str, float]:
|
||||
"""Compute mean / p50 / p95 / p99 of a distance list.
|
||||
|
||||
Empty inputs return an empty dict so the report's ``summary``
|
||||
field still round-trips through JSON without special cases.
|
||||
"""
|
||||
if not distances:
|
||||
return {}
|
||||
arr = np.asarray(distances, dtype=float)
|
||||
return {
|
||||
"mean": float(arr.mean()),
|
||||
"p50": float(np.percentile(arr, 50)),
|
||||
"p95": float(np.percentile(arr, 95)),
|
||||
"p99": float(np.percentile(arr, 99)),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,23 +1,28 @@
|
|||
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
||||
|
||||
This file covers the schema half of the pipeline:
|
||||
Covers both halves of the pipeline:
|
||||
|
||||
- :class:`AnalysisConfig` parsing, including the discriminated unions
|
||||
for the segmentation and analysis stages, and the cross-field
|
||||
invariants enforced at parse time.
|
||||
- :class:`AnalysisReport` construction + JSON round-trip, including
|
||||
the migration hook (schema_version defaults to CURRENT_VERSION).
|
||||
- :func:`analysis_config_to_dict` JSON-safety.
|
||||
|
||||
The executor (``run_analysis``) gets its own test module.
|
||||
- **Schemas** — :class:`AnalysisConfig` parsing (discriminated unions
|
||||
for segmentation and analysis stages, cross-field invariants),
|
||||
:class:`AnalysisReport` construction + JSON round-trip (including
|
||||
the migration hook on ``schema_version``), and
|
||||
:func:`analysis_config_to_dict` JSON-safety.
|
||||
- **Executor** — :func:`run_analysis` dispatches to each analysis kind
|
||||
(dtw / stats / none) with and without segmentation; provenance is
|
||||
inherited from the primary input with ``analysis_config``
|
||||
populated; :func:`load_config`, :func:`save_report`, and
|
||||
:func:`load_report` round-trip.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from neuropose.analyzer.pipeline import (
|
||||
|
|
@ -38,7 +43,12 @@ from neuropose.analyzer.pipeline import (
|
|||
StatsAnalysis,
|
||||
StatsResults,
|
||||
analysis_config_to_dict,
|
||||
load_config,
|
||||
load_report,
|
||||
run_analysis,
|
||||
save_report,
|
||||
)
|
||||
from neuropose.io import Provenance, VideoPredictions, save_video_predictions
|
||||
from neuropose.migrations import CURRENT_VERSION
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -432,3 +442,440 @@ class TestAnalysisReport:
|
|||
dumped["mystery_field"] = 1
|
||||
with pytest.raises(ValidationError):
|
||||
AnalysisReport.model_validate(dumped)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
NUM_JOINTS = 43
|
||||
|
||||
|
||||
def _heel_signal(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
|
||||
"""Clean sinusoid stand-in for a heel's vertical trace."""
|
||||
import math
|
||||
|
||||
total = num_cycles * frames_per_cycle
|
||||
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
||||
return (np.sin(t) * amplitude + amplitude).astype(float)
|
||||
|
||||
|
||||
def _build_predictions(
    signal: np.ndarray,
    joint: int,
    *,
    axis: int = 1,
    fps: float = 30.0,
    provenance: Provenance | None = None,
) -> VideoPredictions:
    """Build a single-person VideoPredictions whose ``joint``'s ``axis`` tracks ``signal``.

    All other coordinates are zero; one detection box per frame.
    """
    frame_map: dict[str, Any] = {}
    for frame_idx, sample in enumerate(signal):
        skeleton = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
        skeleton[0][joint][axis] = float(sample)
        frame_map[f"frame_{frame_idx:06d}"] = {
            "boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
            "poses3d": skeleton,
            "poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
        }
    payload = {
        "metadata": {
            "frame_count": len(signal),
            "fps": fps,
            "width": 640,
            "height": 480,
        },
        "frames": frame_map,
        "provenance": None if provenance is None else provenance.model_dump(),
    }
    return VideoPredictions.model_validate(payload)
|
||||
|
||||
|
||||
def _write_heel_trial(
    tmp_path: Path,
    filename: str,
    *,
    joint: int,
    num_cycles: int = 4,
    frames_per_cycle: int = 30,
    amplitude: float = 100.0,
    provenance: Provenance | None = None,
) -> Path:
    """Persist a synthetic heel-trace trial to ``tmp_path / filename`` and return its path."""
    trace = _heel_signal(num_cycles, frames_per_cycle, amplitude=amplitude)
    destination = tmp_path / filename
    save_video_predictions(
        destination,
        _build_predictions(trace, joint=joint, provenance=provenance),
    )
    return destination
|
||||
|
||||
|
||||
def _fake_provenance(sha: str = "a" * 64) -> Provenance:
    """Minimal valid Provenance for tests; ``sha`` fills ``model_sha256``."""
    fields = {
        "model_sha256": sha,
        "model_filename": "fake_model.tar.gz",
        "tensorflow_version": "2.18.0",
        "numpy_version": "1.26.0",
        "neuropose_version": "0.0.0",
        "python_version": "3.11.0",
    }
    return Provenance(**fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor: run_analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunAnalysisDtwFullTrial:
    """Unsegmented DTW runs: one ``"full_trial"`` pair per run."""

    def test_dtw_all_unsegmented_yields_one_distance(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        # Two byte-identical trials so the DTW distance is exactly zero.
        primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
        reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
        report_path = tmp_path / "report.json"
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary), "reference": str(reference)},
                "analysis": {"kind": "dtw", "method": "dtw_all"},
                "output": {"report": str(report_path)},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, DtwResults)
        assert report.results.segment_labels == ["full_trial"]
        assert len(report.results.distances) == 1
        # Identical inputs → distance 0.
        assert report.results.distances[0] == pytest.approx(0.0, abs=1e-9)

    def test_dtw_all_different_trials_positive_distance(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        # Same waveform, doubled amplitude on the reference → nonzero distance.
        primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], amplitude=100.0)
        reference = _write_heel_trial(
            tmp_path, "b.json", joint=JOINT_INDEX["rhee"], amplitude=200.0
        )
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary), "reference": str(reference)},
                "analysis": {"kind": "dtw", "method": "dtw_all"},
                "output": {"report": str(tmp_path / "r.json")},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, DtwResults)
        assert report.results.distances[0] > 0.0
        # Summary stats are populated whenever there is at least one distance.
        assert "mean" in report.results.summary
|
||||
|
||||
|
||||
class TestRunAnalysisDtwSegmented:
    """Segmented DTW runs: per-segment pairing, bilateral labels, per-joint output."""

    def test_dtw_with_gait_cycles_produces_per_segment_distances(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
        reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary), "reference": str(reference)},
                "segmentation": {"kind": "gait_cycles", "joint": "rhee"},
                "analysis": {"kind": "dtw", "method": "dtw_all"},
                "output": {"report": str(tmp_path / "r.json")},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, DtwResults)
        # 4 cycles detected on both → 4 paired distances.
        assert len(report.results.distances) == 4
        assert all(label.startswith("rhee_cycles[") for label in report.results.segment_labels)

    def test_dtw_bilateral_produces_distances_per_side(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        # Put the heel trace on both lhee and rhee.
        rng_signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
        frames = {}
        for i, value in enumerate(rng_signal):
            poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
            poses[0][JOINT_INDEX["lhee"]][1] = float(value)
            poses[0][JOINT_INDEX["rhee"]][1] = float(value)
            frames[f"frame_{i:06d}"] = {
                "boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
                "poses3d": poses,
                "poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
            }
        preds = VideoPredictions.model_validate(
            {
                "metadata": {
                    "frame_count": len(rng_signal),
                    "fps": 30.0,
                    "width": 640,
                    "height": 480,
                },
                "frames": frames,
            }
        )
        primary = tmp_path / "a.json"
        reference = tmp_path / "b.json"
        save_video_predictions(primary, preds)
        save_video_predictions(reference, preds)
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary), "reference": str(reference)},
                "segmentation": {"kind": "gait_cycles_bilateral"},
                "analysis": {"kind": "dtw", "method": "dtw_all"},
                "output": {"report": str(tmp_path / "r.json")},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, DtwResults)
        # 3 cycles * 2 sides.
        assert len(report.results.distances) == 6
        left = [lbl for lbl in report.results.segment_labels if lbl.startswith("left_heel_strikes")]
        right = [
            lbl for lbl in report.results.segment_labels if lbl.startswith("right_heel_strikes")
        ]
        assert len(left) == 3
        assert len(right) == 3
        # Identical primary and reference → all distances zero.
        for d in report.results.distances:
            assert d == pytest.approx(0.0, abs=1e-9)

    def test_dtw_per_joint_populates_per_joint_distances(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
        reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary), "reference": str(reference)},
                "analysis": {"kind": "dtw", "method": "dtw_per_joint"},
                "output": {"report": str(tmp_path / "r.json")},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, DtwResults)
        # per_joint_distances is None for the other DTW methods.
        assert report.results.per_joint_distances is not None
        assert len(report.results.per_joint_distances) == 1  # unsegmented → one pair
        assert len(report.results.per_joint_distances[0]) == NUM_JOINTS
|
||||
|
||||
|
||||
class TestRunAnalysisStats:
    """Stats runs: single-input feature summaries, whole-trial and per-segment."""

    def test_stats_unsegmented_single_block(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary)},
                "analysis": {
                    "kind": "stats",
                    "extractor": {
                        "kind": "joint_axis",
                        "joint": JOINT_INDEX["rhee"],
                        "axis": 1,
                        "invert": False,
                    },
                },
                "output": {"report": str(tmp_path / "r.json")},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, StatsResults)
        # No segmentation configured → exactly one whole-trial summary.
        assert report.results.segment_labels == ["full_trial"]
        assert len(report.results.statistics) == 1
        stat = report.results.statistics[0]
        assert isinstance(stat, FeatureSummary)
        assert stat.max > stat.min  # Signal oscillates.

    def test_stats_with_segmentation_emits_per_segment(self, tmp_path: Path) -> None:
        from neuropose.analyzer.segment import JOINT_INDEX

        primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
        config = AnalysisConfig.model_validate(
            {
                "inputs": {"primary": str(primary)},
                "segmentation": {"kind": "gait_cycles", "joint": "rhee"},
                "analysis": {
                    "kind": "stats",
                    "extractor": {
                        "kind": "joint_axis",
                        "joint": JOINT_INDEX["rhee"],
                        "axis": 1,
                        "invert": False,
                    },
                },
                "output": {"report": str(tmp_path / "r.json")},
            }
        )
        report = run_analysis(config)
        assert isinstance(report.results, StatsResults)
        # One summary per detected gait cycle.
        assert len(report.results.statistics) == 3
|
||||
|
||||
|
||||
class TestRunAnalysisNone:
    """run_analysis with analysis.kind == "none" (segmentation-only runs)."""

    def test_none_analysis_returns_no_results(self, tmp_path: Path) -> None:
        """A "none" analysis still yields a report, carrying NoResults."""
        from neuropose.analyzer.segment import JOINT_INDEX

        trial = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
        raw = {
            "inputs": {"primary": str(trial)},
            "analysis": {"kind": "none"},
            "output": {"report": str(tmp_path / "r.json")},
        }

        report = run_analysis(AnalysisConfig.model_validate(raw))

        assert isinstance(report.results, NoResults)

    def test_none_with_segmentation_still_emits_segmentations(self, tmp_path: Path) -> None:
        """Segmentation output is recorded even when no analysis consumes it."""
        from neuropose.analyzer.segment import JOINT_INDEX

        trial = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
        raw = {
            "inputs": {"primary": str(trial)},
            "segmentation": {"kind": "gait_cycles", "joint": "rhee"},
            "analysis": {"kind": "none"},
            "output": {"report": str(tmp_path / "r.json")},
        }

        report = run_analysis(AnalysisConfig.model_validate(raw))

        assert isinstance(report.results, NoResults)
        assert "rhee_cycles" in report.segmentations
        assert len(report.segmentations["rhee_cycles"].segments) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provenance inheritance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunAnalysisProvenance:
    """Provenance inheritance and input summaries on the emitted report."""

    def test_inherits_primary_provenance_and_stamps_config(self, tmp_path: Path) -> None:
        """The report inherits the primary's envelope and stamps analysis_config."""
        from neuropose.analyzer.segment import JOINT_INDEX

        heel = JOINT_INDEX["rhee"]
        envelope = _fake_provenance()
        trial_a = _write_heel_trial(tmp_path, "a.json", joint=heel, provenance=envelope)
        trial_b = _write_heel_trial(tmp_path, "b.json", joint=heel)
        raw = {
            "inputs": {"primary": str(trial_a), "reference": str(trial_b)},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "r.json")},
        }

        report = run_analysis(AnalysisConfig.model_validate(raw))

        prov = report.provenance
        assert prov is not None
        # Inference-time fields come straight from the primary input.
        assert prov.model_sha256 == envelope.model_sha256
        # The serialised config is stamped in, making the report self-describing.
        assert prov.analysis_config is not None
        assert prov.analysis_config["config_version"] == 1

    def test_no_primary_provenance_yields_none_report_provenance(self, tmp_path: Path) -> None:
        """No envelope on the primary input means no envelope on the report."""
        from neuropose.analyzer.segment import JOINT_INDEX

        heel = JOINT_INDEX["rhee"]
        trial_a = _write_heel_trial(tmp_path, "a.json", joint=heel)
        trial_b = _write_heel_trial(tmp_path, "b.json", joint=heel)
        raw = {
            "inputs": {"primary": str(trial_a), "reference": str(trial_b)},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "r.json")},
        }

        report = run_analysis(AnalysisConfig.model_validate(raw))

        assert report.provenance is None

    def test_input_summaries_track_paths_and_metadata(self, tmp_path: Path) -> None:
        """primary/reference summaries carry source paths and frame counts."""
        from neuropose.analyzer.segment import JOINT_INDEX

        heel = JOINT_INDEX["rhee"]
        trial_a = _write_heel_trial(tmp_path, "a.json", joint=heel, num_cycles=5)
        trial_b = _write_heel_trial(tmp_path, "b.json", joint=heel, num_cycles=3)
        raw = {
            "inputs": {"primary": str(trial_a), "reference": str(trial_b)},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "r.json")},
        }

        report = run_analysis(AnalysisConfig.model_validate(raw))

        assert report.primary.path == trial_a
        assert report.primary.frame_count == 5 * 30
        assert report.reference is not None
        assert report.reference.frame_count == 3 * 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_config / save_report / load_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadSave:
    """load_config / save_report / load_report round-trips and failure modes."""

    def test_load_config_parses_yaml(self, tmp_path: Path) -> None:
        """A well-formed YAML document validates into a typed AnalysisConfig."""
        raw = {
            "inputs": {
                "primary": str(tmp_path / "a.json"),
                "reference": str(tmp_path / "b.json"),
            },
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "r.json")},
        }
        source = tmp_path / "exp.yaml"
        source.write_text(yaml.safe_dump(raw))

        parsed = load_config(source)

        assert isinstance(parsed.analysis, DtwAnalysis)

    def test_load_config_empty_file_fails_cleanly(self, tmp_path: Path) -> None:
        """An empty YAML file surfaces as a validation error, not a crash."""
        empty = tmp_path / "empty.yaml"
        empty.write_text("")
        with pytest.raises(ValidationError):
            load_config(empty)

    def test_load_config_rejects_malformed_yaml(self, tmp_path: Path) -> None:
        """Syntactically broken YAML propagates the parser's own error."""
        bad = tmp_path / "bad.yaml"
        # Unclosed flow-style mapping — yaml.safe_load raises here.
        bad.write_text("inputs: {primary: foo\n")
        with pytest.raises(yaml.YAMLError):
            load_config(bad)

    def test_save_report_round_trip(self, tmp_path: Path) -> None:
        """save_report followed by load_report reproduces the report exactly."""
        from neuropose.analyzer.segment import JOINT_INDEX

        heel = JOINT_INDEX["rhee"]
        trial_a = _write_heel_trial(tmp_path, "a.json", joint=heel)
        trial_b = _write_heel_trial(tmp_path, "b.json", joint=heel)
        raw = {
            "inputs": {"primary": str(trial_a), "reference": str(trial_b)},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "report.json")},
        }
        report = run_analysis(AnalysisConfig.model_validate(raw))

        destination = tmp_path / "report.json"
        save_report(destination, report)
        assert destination.exists()

        assert load_report(destination) == report

    def test_save_report_is_atomic(self, tmp_path: Path) -> None:
        """The saver writes via a sibling .tmp path and renames."""
        report = _make_report(tmp_path)
        destination = tmp_path / "subdir" / "report.json"

        save_report(destination, report)

        # Parent directory was created.
        assert destination.exists()
        # No .tmp sibling left behind.
        staging = destination.with_suffix(destination.suffix + ".tmp")
        assert not staging.exists()

    def test_load_report_rejects_future_schema(self, tmp_path: Path) -> None:
        """Future schema_version surfaces as a migration error."""
        from neuropose.migrations import FutureSchemaError

        stale = tmp_path / "future.json"
        stale.write_text(json.dumps({"schema_version": CURRENT_VERSION + 1}))
        with pytest.raises(FutureSchemaError):
            load_report(stale)
|
||||
|
|
|
|||
Loading…
Reference in New Issue