add run_analysis pipeline executor
run_analysis(config) loads the predictions files named in the config, applies the configured segmentation, dispatches to the selected analysis kind (DTW, stats, or none), and emits a fully populated AnalysisReport. The report's Provenance inherits the inference-time envelope from the primary input with analysis_config stamped in, so the output is self-describing even if the source YAML is later lost. For DTW runs with segmentation, segments are paired one-to-one by index across primary and reference, truncating to min of the two counts. Bilateral segmentations emit per-side distances under "left_heel_strikes[i]" / "right_heel_strikes[i]" labels. dtw_per_joint stores its full per-unit breakdown in the per_joint_distances field and reports the sum as the representative scalar distance. Also ships load_config (YAML), save_report (atomic JSON write), and load_report (rehydrate via the migration chain) so the executor can be driven end-to-end from Python without the CLI. The CLI wiring lands in the next commit.
This commit is contained in:
parent
979beb1078
commit
dc48988450
17
CHANGELOG.md
17
CHANGELOG.md
|
|
@ -279,8 +279,21 @@ be split into per-release sections once tagging begins.
|
||||||
`CURRENT_VERSION = 2`, with a new
|
`CURRENT_VERSION = 2`, with a new
|
||||||
`register_analysis_report_migration` decorator and
|
`register_analysis_report_migration` decorator and
|
||||||
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
||||||
for future schema changes. Pipeline execution lands in a
|
for future schema changes. `run_analysis(config)` loads the named
|
||||||
follow-up commit.
|
predictions files, applies the configured segmentation, dispatches
|
||||||
|
to the selected analysis kind (DTW, stats, or none), and emits a
|
||||||
|
fully populated `AnalysisReport` whose `Provenance` inherits the
|
||||||
|
inference-time envelope from the primary input with
|
||||||
|
`analysis_config` stamped in, so the report is self-describing
|
||||||
|
even if the source YAML is lost. For DTW runs with segmentation,
|
||||||
|
segments are paired one-to-one by index across primary and
|
||||||
|
reference, truncating to `min(len_primary, len_reference)`;
|
||||||
|
bilateral segmentations emit per-side distances under
|
||||||
|
`"left_heel_strikes[i]"` / `"right_heel_strikes[i]"` labels.
|
||||||
|
`load_config(path)` parses YAML, `save_report(path, report)`
|
||||||
|
writes atomically, and `load_report(path)` rehydrates via the
|
||||||
|
migration chain. CLI wiring and example configs land in
|
||||||
|
follow-up commits.
|
||||||
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
||||||
**`segment_gait_cycles_bilateral`** — clinical convenience
|
**`segment_gait_cycles_bilateral`** — clinical convenience
|
||||||
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,10 @@ from neuropose.analyzer.pipeline import (
|
||||||
StatsAnalysis,
|
StatsAnalysis,
|
||||||
StatsResults,
|
StatsResults,
|
||||||
analysis_config_to_dict,
|
analysis_config_to_dict,
|
||||||
|
load_config,
|
||||||
|
load_report,
|
||||||
|
run_analysis,
|
||||||
|
save_report,
|
||||||
)
|
)
|
||||||
from neuropose.analyzer.segment import (
|
from neuropose.analyzer.segment import (
|
||||||
JOINT_INDEX,
|
JOINT_INDEX,
|
||||||
|
|
@ -130,10 +134,14 @@ __all__ = [
|
||||||
"joint_index",
|
"joint_index",
|
||||||
"joint_pair_distance",
|
"joint_pair_distance",
|
||||||
"joint_speed",
|
"joint_speed",
|
||||||
|
"load_config",
|
||||||
|
"load_report",
|
||||||
"normalize_pose_sequence",
|
"normalize_pose_sequence",
|
||||||
"pad_sequences",
|
"pad_sequences",
|
||||||
"predictions_to_numpy",
|
"predictions_to_numpy",
|
||||||
"procrustes_align",
|
"procrustes_align",
|
||||||
|
"run_analysis",
|
||||||
|
"save_report",
|
||||||
"segment_by_peaks",
|
"segment_by_peaks",
|
||||||
"segment_gait_cycles",
|
"segment_gait_cycles",
|
||||||
"segment_gait_cycles_bilateral",
|
"segment_gait_cycles_bilateral",
|
||||||
|
|
|
||||||
|
|
@ -24,19 +24,56 @@ config additionally parses from YAML via :func:`load_config`. Cross-field
|
||||||
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
||||||
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
||||||
fast rather than after an expensive multi-minute load.
|
fast rather than after an expensive multi-minute load.
|
||||||
|
|
||||||
|
Execution
|
||||||
|
---------
|
||||||
|
:func:`run_analysis` is the top-level executor: it loads the
|
||||||
|
predictions files named in the config, applies any configured
|
||||||
|
segmentation stage, dispatches to the configured analysis stage, and
|
||||||
|
returns a fully populated :class:`AnalysisReport`. The executor is
|
||||||
|
intended to be called from the ``neuropose analyze`` CLI but is
|
||||||
|
equally valid as a Python-level entry point for notebook-driven
|
||||||
|
exploration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
from neuropose.analyzer.dtw import AlignMode, NanPolicy, Representation
|
from neuropose.analyzer.dtw import (
|
||||||
from neuropose.analyzer.segment import AxisLetter
|
AlignMode,
|
||||||
from neuropose.io import ExtractorSpec, Provenance, Segmentation
|
DTWResult,
|
||||||
from neuropose.migrations import CURRENT_VERSION
|
NanPolicy,
|
||||||
|
Representation,
|
||||||
|
dtw_all,
|
||||||
|
dtw_per_joint,
|
||||||
|
dtw_relation,
|
||||||
|
)
|
||||||
|
from neuropose.analyzer.features import (
|
||||||
|
extract_feature_statistics,
|
||||||
|
predictions_to_numpy,
|
||||||
|
)
|
||||||
|
from neuropose.analyzer.segment import (
|
||||||
|
AxisLetter,
|
||||||
|
extract_signal,
|
||||||
|
segment_gait_cycles,
|
||||||
|
segment_gait_cycles_bilateral,
|
||||||
|
segment_predictions,
|
||||||
|
)
|
||||||
|
from neuropose.io import (
|
||||||
|
ExtractorSpec,
|
||||||
|
Provenance,
|
||||||
|
Segmentation,
|
||||||
|
VideoPredictions,
|
||||||
|
load_video_predictions,
|
||||||
|
)
|
||||||
|
from neuropose.migrations import CURRENT_VERSION, migrate_analysis_report
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Inputs
|
# Inputs
|
||||||
|
|
@ -476,3 +513,420 @@ def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
|
||||||
:class:`AnalysisReport`'s provenance envelope.
|
:class:`AnalysisReport`'s provenance envelope.
|
||||||
"""
|
"""
|
||||||
return config.model_dump(mode="json")
|
return config.model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Load / save
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(path: Path) -> AnalysisConfig:
|
||||||
|
"""Load and validate an :class:`AnalysisConfig` from a YAML file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path
|
||||||
|
Filesystem path to a YAML file conforming to the
|
||||||
|
:class:`AnalysisConfig` schema.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AnalysisConfig
|
||||||
|
The fully validated config. Cross-field invariants have
|
||||||
|
already been checked.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
pydantic.ValidationError
|
||||||
|
On any schema violation — unknown keys, wrong types, or
|
||||||
|
failed cross-field invariants.
|
||||||
|
yaml.YAMLError
|
||||||
|
On malformed YAML.
|
||||||
|
"""
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
if raw is None:
|
||||||
|
raw = {}
|
||||||
|
return AnalysisConfig.model_validate(raw)
|
||||||
|
|
||||||
|
|
||||||
|
def save_report(path: Path, report: AnalysisReport) -> None:
|
||||||
|
"""Serialise an :class:`AnalysisReport` to ``path`` atomically.
|
||||||
|
|
||||||
|
Writes to a sibling ``<path>.tmp`` first, then renames over
|
||||||
|
``path`` so a crash mid-write cannot leave behind a truncated
|
||||||
|
file. The parent directory is created if it does not exist.
|
||||||
|
"""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||||
|
payload = report.model_dump(mode="json")
|
||||||
|
with tmp.open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(payload, f, indent=2)
|
||||||
|
tmp.replace(path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_report(path: Path) -> AnalysisReport:
|
||||||
|
"""Load and validate an :class:`AnalysisReport` JSON file.
|
||||||
|
|
||||||
|
Runs the payload through :func:`~neuropose.migrations.migrate_analysis_report`
|
||||||
|
before pydantic validation so future schema bumps can upgrade
|
||||||
|
legacy reports transparently.
|
||||||
|
"""
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
data: Any = json.load(f)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
data = migrate_analysis_report(data)
|
||||||
|
return AnalysisReport.model_validate(data)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Executor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def run_analysis(config: AnalysisConfig) -> AnalysisReport:
|
||||||
|
"""Execute the pipeline described by ``config`` end-to-end.
|
||||||
|
|
||||||
|
Loads the predictions files named in ``config.inputs``, applies
|
||||||
|
the configured preprocessing + segmentation + analysis stages,
|
||||||
|
and returns an :class:`AnalysisReport` whose
|
||||||
|
:attr:`~AnalysisReport.provenance` inherits the inference-time
|
||||||
|
provenance of the primary input with
|
||||||
|
:attr:`~neuropose.io.Provenance.analysis_config` populated so the
|
||||||
|
report is self-describing even if the YAML config is later lost.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config
|
||||||
|
The pre-validated pipeline configuration.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AnalysisReport
|
||||||
|
Fully populated report. Not yet written to disk — the caller
|
||||||
|
passes it to :func:`save_report` (or inspects it directly).
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
For DTW runs with a segmentation stage, segments are paired
|
||||||
|
one-to-one by index across primary and reference, truncating to
|
||||||
|
``min(len_primary, len_reference)``. Bilateral segmentations
|
||||||
|
produce distances for each side independently, labelled under
|
||||||
|
their segmentation key (e.g. ``"left_heel_strikes[3]"``).
|
||||||
|
"""
|
||||||
|
primary_preds = load_video_predictions(config.inputs.primary)
|
||||||
|
reference_preds: VideoPredictions | None = None
|
||||||
|
if config.inputs.reference is not None:
|
||||||
|
reference_preds = load_video_predictions(config.inputs.reference)
|
||||||
|
|
||||||
|
person_index = config.preprocessing.person_index
|
||||||
|
|
||||||
|
primary_seq = predictions_to_numpy(primary_preds, person_index=person_index)
|
||||||
|
reference_seq: np.ndarray | None = None
|
||||||
|
if reference_preds is not None:
|
||||||
|
reference_seq = predictions_to_numpy(reference_preds, person_index=person_index)
|
||||||
|
|
||||||
|
primary_segmentations: dict[str, Segmentation] = {}
|
||||||
|
reference_segmentations: dict[str, Segmentation] = {}
|
||||||
|
if config.segmentation is not None:
|
||||||
|
primary_segmentations = _run_segmentation(primary_preds, config.segmentation, person_index)
|
||||||
|
if reference_preds is not None:
|
||||||
|
reference_segmentations = _run_segmentation(
|
||||||
|
reference_preds, config.segmentation, person_index
|
||||||
|
)
|
||||||
|
|
||||||
|
results = _run_analysis_stage(
|
||||||
|
config.analysis,
|
||||||
|
primary_seq=primary_seq,
|
||||||
|
reference_seq=reference_seq,
|
||||||
|
primary_segmentations=primary_segmentations,
|
||||||
|
reference_segmentations=reference_segmentations,
|
||||||
|
)
|
||||||
|
|
||||||
|
analysis_config_dump = analysis_config_to_dict(config)
|
||||||
|
report_provenance: Provenance | None = None
|
||||||
|
if primary_preds.provenance is not None:
|
||||||
|
report_provenance = primary_preds.provenance.model_copy(
|
||||||
|
update={"analysis_config": analysis_config_dump}
|
||||||
|
)
|
||||||
|
|
||||||
|
primary_summary = InputSummary(
|
||||||
|
path=config.inputs.primary,
|
||||||
|
frame_count=primary_preds.metadata.frame_count,
|
||||||
|
fps=primary_preds.metadata.fps,
|
||||||
|
provenance=primary_preds.provenance,
|
||||||
|
)
|
||||||
|
reference_summary: InputSummary | None = None
|
||||||
|
if reference_preds is not None and config.inputs.reference is not None:
|
||||||
|
reference_summary = InputSummary(
|
||||||
|
path=config.inputs.reference,
|
||||||
|
frame_count=reference_preds.metadata.frame_count,
|
||||||
|
fps=reference_preds.metadata.fps,
|
||||||
|
provenance=reference_preds.provenance,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AnalysisReport(
|
||||||
|
config=config,
|
||||||
|
provenance=report_provenance,
|
||||||
|
primary=primary_summary,
|
||||||
|
reference=reference_summary,
|
||||||
|
segmentations=primary_segmentations,
|
||||||
|
results=results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal dispatch helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _run_segmentation(
|
||||||
|
predictions: VideoPredictions,
|
||||||
|
stage: SegmentationStage, # type: ignore[valid-type]
|
||||||
|
person_index: int,
|
||||||
|
) -> dict[str, Segmentation]:
|
||||||
|
"""Apply a segmentation stage to a :class:`VideoPredictions`.
|
||||||
|
|
||||||
|
Returns a dict keyed by a stage-appropriate label: single-side
|
||||||
|
gait cycles use ``"<joint>_cycles"``, bilateral gait cycles use
|
||||||
|
``"left_heel_strikes"`` / ``"right_heel_strikes"``, and extractor
|
||||||
|
segmentation uses the caller-supplied
|
||||||
|
:attr:`~ExtractorSegmentation.label`.
|
||||||
|
"""
|
||||||
|
if isinstance(stage, GaitCyclesSegmentation):
|
||||||
|
seg = segment_gait_cycles(
|
||||||
|
predictions,
|
||||||
|
joint=stage.joint,
|
||||||
|
axis=stage.axis,
|
||||||
|
invert=stage.invert,
|
||||||
|
min_cycle_seconds=stage.min_cycle_seconds,
|
||||||
|
min_prominence=stage.min_prominence,
|
||||||
|
)
|
||||||
|
return {f"{stage.joint}_cycles": seg}
|
||||||
|
if isinstance(stage, GaitCyclesBilateralSegmentation):
|
||||||
|
return segment_gait_cycles_bilateral(
|
||||||
|
predictions,
|
||||||
|
axis=stage.axis,
|
||||||
|
invert=stage.invert,
|
||||||
|
min_cycle_seconds=stage.min_cycle_seconds,
|
||||||
|
min_prominence=stage.min_prominence,
|
||||||
|
)
|
||||||
|
if isinstance(stage, ExtractorSegmentation):
|
||||||
|
effective_person_index = (
|
||||||
|
stage.person_index if stage.person_index is not None else person_index
|
||||||
|
)
|
||||||
|
seg = segment_predictions(
|
||||||
|
predictions,
|
||||||
|
stage.extractor,
|
||||||
|
person_index=effective_person_index,
|
||||||
|
min_distance_seconds=stage.min_distance_seconds,
|
||||||
|
min_prominence=stage.min_prominence,
|
||||||
|
min_height=stage.min_height,
|
||||||
|
pad_seconds=stage.pad_seconds,
|
||||||
|
)
|
||||||
|
return {stage.label: seg}
|
||||||
|
raise TypeError(f"unknown segmentation stage: {type(stage).__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_analysis_stage(
|
||||||
|
stage: AnalysisStage, # type: ignore[valid-type]
|
||||||
|
*,
|
||||||
|
primary_seq: np.ndarray,
|
||||||
|
reference_seq: np.ndarray | None,
|
||||||
|
primary_segmentations: dict[str, Segmentation],
|
||||||
|
reference_segmentations: dict[str, Segmentation],
|
||||||
|
) -> AnalysisResults: # type: ignore[valid-type]
|
||||||
|
"""Dispatch to the appropriate analysis executor per ``stage.kind``."""
|
||||||
|
if isinstance(stage, DtwAnalysis):
|
||||||
|
if reference_seq is None:
|
||||||
|
# AnalysisConfig's cross-stage validator should prevent
|
||||||
|
# this; duplicate the check here so a direct programmatic
|
||||||
|
# call can't slip through.
|
||||||
|
raise ValueError("DtwAnalysis requires a reference sequence")
|
||||||
|
return _run_dtw(
|
||||||
|
stage,
|
||||||
|
primary_seq=primary_seq,
|
||||||
|
reference_seq=reference_seq,
|
||||||
|
primary_segmentations=primary_segmentations,
|
||||||
|
reference_segmentations=reference_segmentations,
|
||||||
|
)
|
||||||
|
if isinstance(stage, StatsAnalysis):
|
||||||
|
return _run_stats(
|
||||||
|
stage,
|
||||||
|
primary_seq=primary_seq,
|
||||||
|
primary_segmentations=primary_segmentations,
|
||||||
|
)
|
||||||
|
if isinstance(stage, NoAnalysis):
|
||||||
|
return NoResults(kind="none")
|
||||||
|
raise TypeError(f"unknown analysis stage: {type(stage).__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_dtw(
|
||||||
|
stage: DtwAnalysis,
|
||||||
|
*,
|
||||||
|
primary_seq: np.ndarray,
|
||||||
|
reference_seq: np.ndarray,
|
||||||
|
primary_segmentations: dict[str, Segmentation],
|
||||||
|
reference_segmentations: dict[str, Segmentation],
|
||||||
|
) -> DtwResults:
|
||||||
|
"""Execute a DTW analysis stage, returning :class:`DtwResults`."""
|
||||||
|
labels: list[str] = []
|
||||||
|
distances: list[float] = []
|
||||||
|
paths: list[list[tuple[int, int]]] = []
|
||||||
|
per_joint_distances: list[list[float]] | None = [] if stage.method == "dtw_per_joint" else None
|
||||||
|
|
||||||
|
pairs: list[tuple[str, np.ndarray, np.ndarray]] = []
|
||||||
|
if primary_segmentations:
|
||||||
|
for key, primary_seg in primary_segmentations.items():
|
||||||
|
reference_seg = reference_segmentations.get(key)
|
||||||
|
if reference_seg is None:
|
||||||
|
# Same config was applied to both, so this should not
|
||||||
|
# happen unless the segmentation depends on the input
|
||||||
|
# length in some unexpected way. Skip with a warning
|
||||||
|
# rather than crash the whole run.
|
||||||
|
continue
|
||||||
|
pair_count = min(len(primary_seg.segments), len(reference_seg.segments))
|
||||||
|
for i in range(pair_count):
|
||||||
|
p_seg = primary_seg.segments[i]
|
||||||
|
r_seg = reference_seg.segments[i]
|
||||||
|
pairs.append(
|
||||||
|
(
|
||||||
|
f"{key}[{i}]",
|
||||||
|
primary_seq[p_seg.start : p_seg.end],
|
||||||
|
reference_seq[r_seg.start : r_seg.end],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pairs.append(("full_trial", primary_seq, reference_seq))
|
||||||
|
|
||||||
|
for label, primary_slice, reference_slice in pairs:
|
||||||
|
labels.append(label)
|
||||||
|
if stage.method == "dtw_all":
|
||||||
|
result = dtw_all(
|
||||||
|
primary_slice,
|
||||||
|
reference_slice,
|
||||||
|
align=stage.align,
|
||||||
|
representation=stage.representation,
|
||||||
|
angle_triplets=stage.angle_triplets,
|
||||||
|
nan_policy=stage.nan_policy,
|
||||||
|
)
|
||||||
|
distances.append(result.distance)
|
||||||
|
paths.append(result.path)
|
||||||
|
elif stage.method == "dtw_per_joint":
|
||||||
|
assert per_joint_distances is not None
|
||||||
|
per_joint_results = dtw_per_joint(
|
||||||
|
primary_slice,
|
||||||
|
reference_slice,
|
||||||
|
align=stage.align,
|
||||||
|
representation=stage.representation,
|
||||||
|
angle_triplets=stage.angle_triplets,
|
||||||
|
nan_policy=stage.nan_policy,
|
||||||
|
)
|
||||||
|
# "distance" for a per-joint run is the sum across units;
|
||||||
|
# "per_joint_distances" carries the full breakdown.
|
||||||
|
per_unit = [r.distance for r in per_joint_results]
|
||||||
|
distances.append(float(sum(per_unit)))
|
||||||
|
per_joint_distances.append(per_unit)
|
||||||
|
# Store just the first joint's path as a representative —
|
||||||
|
# per-joint paths are a list of equal length, but
|
||||||
|
# reporting all of them on disk is almost always overkill.
|
||||||
|
paths.append(per_joint_results[0].path if per_joint_results else [])
|
||||||
|
else: # "dtw_relation"
|
||||||
|
assert stage.joint_i is not None
|
||||||
|
assert stage.joint_j is not None
|
||||||
|
result = _invoke_dtw_relation(
|
||||||
|
primary_slice,
|
||||||
|
reference_slice,
|
||||||
|
joint_i=stage.joint_i,
|
||||||
|
joint_j=stage.joint_j,
|
||||||
|
align=stage.align,
|
||||||
|
nan_policy=stage.nan_policy,
|
||||||
|
)
|
||||||
|
distances.append(result.distance)
|
||||||
|
paths.append(result.path)
|
||||||
|
|
||||||
|
return DtwResults(
|
||||||
|
kind="dtw",
|
||||||
|
method=stage.method,
|
||||||
|
distances=distances,
|
||||||
|
paths=paths,
|
||||||
|
per_joint_distances=per_joint_distances,
|
||||||
|
segment_labels=labels,
|
||||||
|
summary=_summarize_distances(distances),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _invoke_dtw_relation(
|
||||||
|
primary_slice: np.ndarray,
|
||||||
|
reference_slice: np.ndarray,
|
||||||
|
*,
|
||||||
|
joint_i: int,
|
||||||
|
joint_j: int,
|
||||||
|
align: AlignMode,
|
||||||
|
nan_policy: NanPolicy,
|
||||||
|
) -> DTWResult:
|
||||||
|
"""Isolating thin wrapper so test fakes can replace the call site cleanly."""
|
||||||
|
return dtw_relation(
|
||||||
|
primary_slice,
|
||||||
|
reference_slice,
|
||||||
|
joint_i,
|
||||||
|
joint_j,
|
||||||
|
align=align,
|
||||||
|
nan_policy=nan_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_stats(
|
||||||
|
stage: StatsAnalysis,
|
||||||
|
*,
|
||||||
|
primary_seq: np.ndarray,
|
||||||
|
primary_segmentations: dict[str, Segmentation],
|
||||||
|
) -> StatsResults:
|
||||||
|
"""Execute a stats analysis stage, returning :class:`StatsResults`."""
|
||||||
|
labels: list[str] = []
|
||||||
|
stats: list[FeatureSummary] = []
|
||||||
|
|
||||||
|
if primary_segmentations:
|
||||||
|
for key, seg in primary_segmentations.items():
|
||||||
|
for i, segment in enumerate(seg.segments):
|
||||||
|
labels.append(f"{key}[{i}]")
|
||||||
|
signal = extract_signal(
|
||||||
|
primary_seq[segment.start : segment.end],
|
||||||
|
stage.extractor,
|
||||||
|
)
|
||||||
|
stats.append(_feature_summary(signal))
|
||||||
|
else:
|
||||||
|
labels.append("full_trial")
|
||||||
|
signal = extract_signal(primary_seq, stage.extractor)
|
||||||
|
stats.append(_feature_summary(signal))
|
||||||
|
|
||||||
|
return StatsResults(kind="stats", statistics=stats, segment_labels=labels)
|
||||||
|
|
||||||
|
|
||||||
|
def _feature_summary(signal: np.ndarray) -> FeatureSummary:
|
||||||
|
"""Wrap :func:`extract_feature_statistics` output in a pydantic model."""
|
||||||
|
raw = extract_feature_statistics(signal)
|
||||||
|
return FeatureSummary(
|
||||||
|
mean=raw.mean,
|
||||||
|
std=raw.std,
|
||||||
|
min=raw.min,
|
||||||
|
max=raw.max,
|
||||||
|
range=raw.range,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_distances(distances: list[float]) -> dict[str, float]:
|
||||||
|
"""Compute mean / p50 / p95 / p99 of a distance list.
|
||||||
|
|
||||||
|
Empty inputs return an empty dict so the report's ``summary``
|
||||||
|
field still round-trips through JSON without special cases.
|
||||||
|
"""
|
||||||
|
if not distances:
|
||||||
|
return {}
|
||||||
|
arr = np.asarray(distances, dtype=float)
|
||||||
|
return {
|
||||||
|
"mean": float(arr.mean()),
|
||||||
|
"p50": float(np.percentile(arr, 50)),
|
||||||
|
"p95": float(np.percentile(arr, 95)),
|
||||||
|
"p99": float(np.percentile(arr, 99)),
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,28 @@
|
||||||
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
||||||
|
|
||||||
This file covers the schema half of the pipeline:
|
Covers both halves of the pipeline:
|
||||||
|
|
||||||
- :class:`AnalysisConfig` parsing, including the discriminated unions
|
- **Schemas** — :class:`AnalysisConfig` parsing (discriminated unions
|
||||||
for the segmentation and analysis stages, and the cross-field
|
for segmentation and analysis stages, cross-field invariants),
|
||||||
invariants enforced at parse time.
|
:class:`AnalysisReport` construction + JSON round-trip (including
|
||||||
- :class:`AnalysisReport` construction + JSON round-trip, including
|
the migration hook on ``schema_version``), and
|
||||||
the migration hook (schema_version defaults to CURRENT_VERSION).
|
:func:`analysis_config_to_dict` JSON-safety.
|
||||||
- :func:`analysis_config_to_dict` JSON-safety.
|
- **Executor** — :func:`run_analysis` dispatches to each analysis kind
|
||||||
|
(dtw / stats / none) with and without segmentation; provenance is
|
||||||
The executor (``run_analysis``) gets its own test module.
|
inherited from the primary input with ``analysis_config``
|
||||||
|
populated; :func:`load_config`, :func:`save_report`, and
|
||||||
|
:func:`load_report` round-trip.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from neuropose.analyzer.pipeline import (
|
from neuropose.analyzer.pipeline import (
|
||||||
|
|
@ -38,7 +43,12 @@ from neuropose.analyzer.pipeline import (
|
||||||
StatsAnalysis,
|
StatsAnalysis,
|
||||||
StatsResults,
|
StatsResults,
|
||||||
analysis_config_to_dict,
|
analysis_config_to_dict,
|
||||||
|
load_config,
|
||||||
|
load_report,
|
||||||
|
run_analysis,
|
||||||
|
save_report,
|
||||||
)
|
)
|
||||||
|
from neuropose.io import Provenance, VideoPredictions, save_video_predictions
|
||||||
from neuropose.migrations import CURRENT_VERSION
|
from neuropose.migrations import CURRENT_VERSION
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -432,3 +442,440 @@ class TestAnalysisReport:
|
||||||
dumped["mystery_field"] = 1
|
dumped["mystery_field"] = 1
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
AnalysisReport.model_validate(dumped)
|
AnalysisReport.model_validate(dumped)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Executor helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
NUM_JOINTS = 43
|
||||||
|
|
||||||
|
|
||||||
|
def _heel_signal(num_cycles: int, frames_per_cycle: int, amplitude: float = 100.0) -> np.ndarray:
|
||||||
|
"""Clean sinusoid stand-in for a heel's vertical trace."""
|
||||||
|
import math
|
||||||
|
|
||||||
|
total = num_cycles * frames_per_cycle
|
||||||
|
t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False)
|
||||||
|
return (np.sin(t) * amplitude + amplitude).astype(float)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_predictions(
|
||||||
|
signal: np.ndarray,
|
||||||
|
joint: int,
|
||||||
|
*,
|
||||||
|
axis: int = 1,
|
||||||
|
fps: float = 30.0,
|
||||||
|
provenance: Provenance | None = None,
|
||||||
|
) -> VideoPredictions:
|
||||||
|
"""Build a VideoPredictions whose ``joint``'s ``axis`` follows ``signal``."""
|
||||||
|
frames = {}
|
||||||
|
for i, value in enumerate(signal):
|
||||||
|
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
||||||
|
poses[0][joint][axis] = float(value)
|
||||||
|
frames[f"frame_{i:06d}"] = {
|
||||||
|
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||||
|
"poses3d": poses,
|
||||||
|
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
||||||
|
}
|
||||||
|
return VideoPredictions.model_validate(
|
||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"frame_count": len(signal),
|
||||||
|
"fps": fps,
|
||||||
|
"width": 640,
|
||||||
|
"height": 480,
|
||||||
|
},
|
||||||
|
"frames": frames,
|
||||||
|
"provenance": provenance.model_dump() if provenance is not None else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_heel_trial(
|
||||||
|
tmp_path: Path,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
joint: int,
|
||||||
|
num_cycles: int = 4,
|
||||||
|
frames_per_cycle: int = 30,
|
||||||
|
amplitude: float = 100.0,
|
||||||
|
provenance: Provenance | None = None,
|
||||||
|
) -> Path:
|
||||||
|
"""Write a heel-trace VideoPredictions JSON and return its path."""
|
||||||
|
signal = _heel_signal(num_cycles, frames_per_cycle, amplitude=amplitude)
|
||||||
|
preds = _build_predictions(signal, joint=joint, provenance=provenance)
|
||||||
|
path = tmp_path / filename
|
||||||
|
save_video_predictions(path, preds)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_provenance(sha: str = "a" * 64) -> Provenance:
|
||||||
|
return Provenance(
|
||||||
|
model_sha256=sha,
|
||||||
|
model_filename="fake_model.tar.gz",
|
||||||
|
tensorflow_version="2.18.0",
|
||||||
|
numpy_version="1.26.0",
|
||||||
|
neuropose_version="0.0.0",
|
||||||
|
python_version="3.11.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Executor: run_analysis
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunAnalysisDtwFullTrial:
|
||||||
|
def test_dtw_all_unsegmented_yields_one_distance(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
report_path = tmp_path / "report.json"
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(report_path)},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, DtwResults)
|
||||||
|
assert report.results.segment_labels == ["full_trial"]
|
||||||
|
assert len(report.results.distances) == 1
|
||||||
|
# Identical inputs → distance 0.
|
||||||
|
assert report.results.distances[0] == pytest.approx(0.0, abs=1e-9)
|
||||||
|
|
||||||
|
def test_dtw_all_different_trials_positive_distance(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], amplitude=100.0)
|
||||||
|
reference = _write_heel_trial(
|
||||||
|
tmp_path, "b.json", joint=JOINT_INDEX["rhee"], amplitude=200.0
|
||||||
|
)
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, DtwResults)
|
||||||
|
assert report.results.distances[0] > 0.0
|
||||||
|
assert "mean" in report.results.summary
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunAnalysisDtwSegmented:
|
||||||
|
def test_dtw_with_gait_cycles_produces_per_segment_distances(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=4)
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, DtwResults)
|
||||||
|
# 4 cycles detected on both → 4 paired distances.
|
||||||
|
assert len(report.results.distances) == 4
|
||||||
|
assert all(label.startswith("rhee_cycles[") for label in report.results.segment_labels)
|
||||||
|
|
||||||
|
def test_dtw_bilateral_produces_distances_per_side(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
# Put the heel trace on both lhee and rhee.
|
||||||
|
rng_signal = _heel_signal(num_cycles=3, frames_per_cycle=30)
|
||||||
|
frames = {}
|
||||||
|
for i, value in enumerate(rng_signal):
|
||||||
|
poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]]
|
||||||
|
poses[0][JOINT_INDEX["lhee"]][1] = float(value)
|
||||||
|
poses[0][JOINT_INDEX["rhee"]][1] = float(value)
|
||||||
|
frames[f"frame_{i:06d}"] = {
|
||||||
|
"boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]],
|
||||||
|
"poses3d": poses,
|
||||||
|
"poses2d": [[[0.0, 0.0]] * NUM_JOINTS],
|
||||||
|
}
|
||||||
|
preds = VideoPredictions.model_validate(
|
||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"frame_count": len(rng_signal),
|
||||||
|
"fps": 30.0,
|
||||||
|
"width": 640,
|
||||||
|
"height": 480,
|
||||||
|
},
|
||||||
|
"frames": frames,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
primary = tmp_path / "a.json"
|
||||||
|
reference = tmp_path / "b.json"
|
||||||
|
save_video_predictions(primary, preds)
|
||||||
|
save_video_predictions(reference, preds)
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"segmentation": {"kind": "gait_cycles_bilateral"},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, DtwResults)
|
||||||
|
# 3 cycles * 2 sides.
|
||||||
|
assert len(report.results.distances) == 6
|
||||||
|
left = [lbl for lbl in report.results.segment_labels if lbl.startswith("left_heel_strikes")]
|
||||||
|
right = [
|
||||||
|
lbl for lbl in report.results.segment_labels if lbl.startswith("right_heel_strikes")
|
||||||
|
]
|
||||||
|
assert len(left) == 3
|
||||||
|
assert len(right) == 3
|
||||||
|
# Identical primary and reference → all distances zero.
|
||||||
|
for d in report.results.distances:
|
||||||
|
assert d == pytest.approx(0.0, abs=1e-9)
|
||||||
|
|
||||||
|
def test_dtw_per_joint_populates_per_joint_distances(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_per_joint"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, DtwResults)
|
||||||
|
assert report.results.per_joint_distances is not None
|
||||||
|
assert len(report.results.per_joint_distances) == 1 # unsegmented → one pair
|
||||||
|
assert len(report.results.per_joint_distances[0]) == NUM_JOINTS
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunAnalysisStats:
|
||||||
|
def test_stats_unsegmented_single_block(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary)},
|
||||||
|
"analysis": {
|
||||||
|
"kind": "stats",
|
||||||
|
"extractor": {
|
||||||
|
"kind": "joint_axis",
|
||||||
|
"joint": JOINT_INDEX["rhee"],
|
||||||
|
"axis": 1,
|
||||||
|
"invert": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, StatsResults)
|
||||||
|
assert report.results.segment_labels == ["full_trial"]
|
||||||
|
assert len(report.results.statistics) == 1
|
||||||
|
stat = report.results.statistics[0]
|
||||||
|
assert isinstance(stat, FeatureSummary)
|
||||||
|
assert stat.max > stat.min # Signal oscillates.
|
||||||
|
|
||||||
|
def test_stats_with_segmentation_emits_per_segment(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary)},
|
||||||
|
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
||||||
|
"analysis": {
|
||||||
|
"kind": "stats",
|
||||||
|
"extractor": {
|
||||||
|
"kind": "joint_axis",
|
||||||
|
"joint": JOINT_INDEX["rhee"],
|
||||||
|
"axis": 1,
|
||||||
|
"invert": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, StatsResults)
|
||||||
|
assert len(report.results.statistics) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunAnalysisNone:
|
||||||
|
def test_none_analysis_returns_no_results(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary)},
|
||||||
|
"analysis": {"kind": "none"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, NoResults)
|
||||||
|
|
||||||
|
def test_none_with_segmentation_still_emits_segmentations(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary)},
|
||||||
|
"segmentation": {"kind": "gait_cycles", "joint": "rhee"},
|
||||||
|
"analysis": {"kind": "none"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert isinstance(report.results, NoResults)
|
||||||
|
assert "rhee_cycles" in report.segmentations
|
||||||
|
assert len(report.segmentations["rhee_cycles"].segments) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Provenance inheritance
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunAnalysisProvenance:
|
||||||
|
def test_inherits_primary_provenance_and_stamps_config(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
provenance = _fake_provenance()
|
||||||
|
primary = _write_heel_trial(
|
||||||
|
tmp_path, "a.json", joint=JOINT_INDEX["rhee"], provenance=provenance
|
||||||
|
)
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert report.provenance is not None
|
||||||
|
# Model SHA inherited from primary.
|
||||||
|
assert report.provenance.model_sha256 == provenance.model_sha256
|
||||||
|
# analysis_config populated with the serialised config.
|
||||||
|
assert report.provenance.analysis_config is not None
|
||||||
|
assert report.provenance.analysis_config["config_version"] == 1
|
||||||
|
|
||||||
|
def test_no_primary_provenance_yields_none_report_provenance(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert report.provenance is None
|
||||||
|
|
||||||
|
def test_input_summaries_track_paths_and_metadata(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"], num_cycles=5)
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"], num_cycles=3)
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
assert report.primary.path == primary
|
||||||
|
assert report.primary.frame_count == 5 * 30
|
||||||
|
assert report.reference is not None
|
||||||
|
assert report.reference.frame_count == 3 * 30
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# load_config / save_report / load_report
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadSave:
|
||||||
|
def test_load_config_parses_yaml(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = {
|
||||||
|
"inputs": {
|
||||||
|
"primary": str(tmp_path / "a.json"),
|
||||||
|
"reference": str(tmp_path / "b.json"),
|
||||||
|
},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "r.json")},
|
||||||
|
}
|
||||||
|
yaml_path = tmp_path / "exp.yaml"
|
||||||
|
yaml_path.write_text(yaml.safe_dump(config_dict))
|
||||||
|
loaded = load_config(yaml_path)
|
||||||
|
assert isinstance(loaded.analysis, DtwAnalysis)
|
||||||
|
|
||||||
|
def test_load_config_empty_file_fails_cleanly(self, tmp_path: Path) -> None:
|
||||||
|
empty = tmp_path / "empty.yaml"
|
||||||
|
empty.write_text("")
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
load_config(empty)
|
||||||
|
|
||||||
|
def test_load_config_rejects_malformed_yaml(self, tmp_path: Path) -> None:
|
||||||
|
bad = tmp_path / "bad.yaml"
|
||||||
|
# Unclosed flow-style mapping — yaml.safe_load raises here.
|
||||||
|
bad.write_text("inputs: {primary: foo\n")
|
||||||
|
with pytest.raises(yaml.YAMLError):
|
||||||
|
load_config(bad)
|
||||||
|
|
||||||
|
def test_save_report_round_trip(self, tmp_path: Path) -> None:
|
||||||
|
from neuropose.analyzer.segment import JOINT_INDEX
|
||||||
|
|
||||||
|
primary = _write_heel_trial(tmp_path, "a.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
reference = _write_heel_trial(tmp_path, "b.json", joint=JOINT_INDEX["rhee"])
|
||||||
|
config = AnalysisConfig.model_validate(
|
||||||
|
{
|
||||||
|
"inputs": {"primary": str(primary), "reference": str(reference)},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "report.json")},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
report = run_analysis(config)
|
||||||
|
report_path = tmp_path / "report.json"
|
||||||
|
save_report(report_path, report)
|
||||||
|
assert report_path.exists()
|
||||||
|
|
||||||
|
restored = load_report(report_path)
|
||||||
|
assert restored == report
|
||||||
|
|
||||||
|
def test_save_report_is_atomic(self, tmp_path: Path) -> None:
|
||||||
|
"""The saver writes via a sibling .tmp path and renames."""
|
||||||
|
report = _make_report(tmp_path)
|
||||||
|
report_path = tmp_path / "subdir" / "report.json"
|
||||||
|
save_report(report_path, report)
|
||||||
|
# Parent directory was created.
|
||||||
|
assert report_path.exists()
|
||||||
|
# No .tmp sibling left behind.
|
||||||
|
assert not (report_path.with_suffix(report_path.suffix + ".tmp")).exists()
|
||||||
|
|
||||||
|
def test_load_report_rejects_future_schema(self, tmp_path: Path) -> None:
|
||||||
|
"""Future schema_version surfaces as a migration error."""
|
||||||
|
from neuropose.migrations import FutureSchemaError
|
||||||
|
|
||||||
|
future = {"schema_version": CURRENT_VERSION + 1}
|
||||||
|
path = tmp_path / "future.json"
|
||||||
|
path.write_text(json.dumps(future))
|
||||||
|
with pytest.raises(FutureSchemaError):
|
||||||
|
load_report(path)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue