add AnalysisConfig and AnalysisReport schemas
neuropose.analyzer.pipeline ships two top-level pydantic schemas: AnalysisConfig — what a user writes in YAML. Inputs (primary plus optional reference), preprocessing (person_index, room to grow), optional segmentation as a discriminated union of gait_cycles, gait_cycles_bilateral, and extractor, and a required analysis stage as a discriminated union of dtw, stats, none. AnalysisReport — runtime output with config, Provenance envelope, per-input summaries, produced segmentations, and a results payload whose shape mirrors the stage (DtwResults, StatsResults, NoResults). schema_version defaults to CURRENT_VERSION. Cross-field invariants enforced at parse time via model_validator: method='dtw_relation' requires joint_i/joint_j and refuses representation='angles'; representation='angles' requires non-empty angle_triplets; analysis.kind='dtw' requires inputs.reference; analysis.kind='stats' refuses a reference. Typos fail in milliseconds instead of after a multi-minute predictions load. neuropose.migrations gains a third registry for AnalysisReport (_ANALYSIS_REPORT_MIGRATIONS + register_analysis_report_migration + migrate_analysis_report), ready for future schema changes. No v1→v2 migration is registered because AnalysisReport first shipped at v2. Execution, CLI wiring, and example configs land in follow-up commits.
This commit is contained in:
parent
87461a17d0
commit
979beb1078
24
CHANGELOG.md
24
CHANGELOG.md
|
|
@ -257,6 +257,30 @@ be split into per-release sections once tagging begins.
|
||||||
two-joint displacement DTW; users who prefer a unified API can
|
two-joint displacement DTW; users who prefer a unified API can
|
||||||
express the same computation via `dtw_all` with an appropriate
|
express the same computation via `dtw_all` with an appropriate
|
||||||
pair of angle triplets or run `dtw_relation` directly.
|
pair of angle triplets or run `dtw_relation` directly.
|
||||||
|
- **`neuropose.analyzer.pipeline`** (schemas) — declarative
|
||||||
|
analysis-pipeline configuration and output artifact, parseable from
|
||||||
|
YAML or JSON via pydantic. `AnalysisConfig` captures a full
|
||||||
|
experiment: inputs (primary + optional reference predictions
|
||||||
|
files), preprocessing (person index, with room to grow),
|
||||||
|
optional segmentation (`gait_cycles` / `gait_cycles_bilateral` /
|
||||||
|
`extractor` discriminated union), and a required analysis stage
|
||||||
|
(`dtw` / `stats` / `none` discriminated union). `AnalysisReport`
|
||||||
|
is the runtime output: carries the originating config, a
|
||||||
|
`Provenance` envelope with `analysis_config` populated, per-input
|
||||||
|
summaries, produced segmentations, and an analysis-result payload
|
||||||
|
that mirrors the stage choice (`DtwResults`, `StatsResults`, or
|
||||||
|
`NoResults`). Cross-field invariants — `method="dtw_relation"`
|
||||||
|
requires `joint_i`/`joint_j`, `representation="angles"` requires
|
||||||
|
non-empty `angle_triplets`, `analysis.kind="dtw"` requires
|
||||||
|
`inputs.reference`, `analysis.kind="stats"` refuses a reference —
|
||||||
|
are enforced at parse time via `model_validator` so typos fail in
|
||||||
|
milliseconds instead of after a multi-minute predictions load.
|
||||||
|
`AnalysisReport` carries a `schema_version` field defaulting to
|
||||||
|
`CURRENT_VERSION = 2`, with a new
|
||||||
|
`register_analysis_report_migration` decorator and
|
||||||
|
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
||||||
|
for future schema changes. Pipeline execution lands in a
|
||||||
|
follow-up commit.
|
||||||
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
||||||
**`segment_gait_cycles_bilateral`** — clinical convenience
|
**`segment_gait_cycles_bilateral`** — clinical convenience
|
||||||
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,28 @@ from neuropose.analyzer.features import (
|
||||||
predictions_to_numpy,
|
predictions_to_numpy,
|
||||||
procrustes_align,
|
procrustes_align,
|
||||||
)
|
)
|
||||||
|
from neuropose.analyzer.pipeline import (
|
||||||
|
AnalysisConfig,
|
||||||
|
AnalysisReport,
|
||||||
|
AnalysisResults,
|
||||||
|
AnalysisStage,
|
||||||
|
DtwAnalysis,
|
||||||
|
DtwResults,
|
||||||
|
ExtractorSegmentation,
|
||||||
|
FeatureSummary,
|
||||||
|
GaitCyclesBilateralSegmentation,
|
||||||
|
GaitCyclesSegmentation,
|
||||||
|
InputsConfig,
|
||||||
|
InputSummary,
|
||||||
|
NoAnalysis,
|
||||||
|
NoResults,
|
||||||
|
OutputConfig,
|
||||||
|
PreprocessingConfig,
|
||||||
|
SegmentationStage,
|
||||||
|
StatsAnalysis,
|
||||||
|
StatsResults,
|
||||||
|
analysis_config_to_dict,
|
||||||
|
)
|
||||||
from neuropose.analyzer.segment import (
|
from neuropose.analyzer.segment import (
|
||||||
JOINT_INDEX,
|
JOINT_INDEX,
|
||||||
JOINT_NAMES,
|
JOINT_NAMES,
|
||||||
|
|
@ -70,12 +92,32 @@ __all__ = [
|
||||||
"JOINT_NAMES",
|
"JOINT_NAMES",
|
||||||
"AlignMode",
|
"AlignMode",
|
||||||
"AlignmentDiagnostics",
|
"AlignmentDiagnostics",
|
||||||
|
"AnalysisConfig",
|
||||||
|
"AnalysisReport",
|
||||||
|
"AnalysisResults",
|
||||||
|
"AnalysisStage",
|
||||||
"AxisLetter",
|
"AxisLetter",
|
||||||
"DTWResult",
|
"DTWResult",
|
||||||
|
"DtwAnalysis",
|
||||||
|
"DtwResults",
|
||||||
|
"ExtractorSegmentation",
|
||||||
"FeatureStatistics",
|
"FeatureStatistics",
|
||||||
|
"FeatureSummary",
|
||||||
|
"GaitCyclesBilateralSegmentation",
|
||||||
|
"GaitCyclesSegmentation",
|
||||||
|
"InputSummary",
|
||||||
|
"InputsConfig",
|
||||||
"NanPolicy",
|
"NanPolicy",
|
||||||
|
"NoAnalysis",
|
||||||
|
"NoResults",
|
||||||
|
"OutputConfig",
|
||||||
|
"PreprocessingConfig",
|
||||||
"ProcrustesMode",
|
"ProcrustesMode",
|
||||||
"Representation",
|
"Representation",
|
||||||
|
"SegmentationStage",
|
||||||
|
"StatsAnalysis",
|
||||||
|
"StatsResults",
|
||||||
|
"analysis_config_to_dict",
|
||||||
"dtw_all",
|
"dtw_all",
|
||||||
"dtw_per_joint",
|
"dtw_per_joint",
|
||||||
"dtw_relation",
|
"dtw_relation",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,478 @@
|
||||||
|
"""YAML-configurable analysis pipeline.
|
||||||
|
|
||||||
|
This module unifies the analyzer's individual primitives (Procrustes
|
||||||
|
alignment, gait-cycle segmentation, DTW on coords or angles, feature
|
||||||
|
statistics) behind a single declarative configuration object so an
|
||||||
|
experiment can be reproduced from one file that lives in a git
|
||||||
|
repository, carries through to the :class:`~neuropose.io.Provenance`
|
||||||
|
envelope on the output artifact, and can be cited unambiguously in
|
||||||
|
accompanying papers.
|
||||||
|
|
||||||
|
Two top-level schemas live here:
|
||||||
|
|
||||||
|
- :class:`AnalysisConfig` — what the user writes in YAML. Describes
|
||||||
|
the full pipeline: inputs, preprocessing, optional segmentation,
|
||||||
|
required analysis stage, and output path.
|
||||||
|
- :class:`AnalysisReport` — what :func:`run_analysis` emits. Carries
|
||||||
|
the config, a :class:`~neuropose.io.Provenance` envelope (with the
|
||||||
|
config serialised into :attr:`~neuropose.io.Provenance.analysis_config`),
|
||||||
|
per-input summaries, segmentation results, and the analysis results
|
||||||
|
themselves.
|
||||||
|
|
||||||
|
Both schemas parse from (and serialise to) JSON via pydantic; the
|
||||||
|
config additionally parses from YAML via :func:`load_config`. Cross-field
|
||||||
|
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
||||||
|
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
||||||
|
fast rather than after an expensive multi-minute load.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
|
from neuropose.analyzer.dtw import AlignMode, NanPolicy, Representation
|
||||||
|
from neuropose.analyzer.segment import AxisLetter
|
||||||
|
from neuropose.io import ExtractorSpec, Provenance, Segmentation
|
||||||
|
from neuropose.migrations import CURRENT_VERSION
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Inputs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class InputsConfig(BaseModel):
|
||||||
|
"""Predictions files consumed by the pipeline.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
primary
|
||||||
|
Path to a :class:`~neuropose.io.VideoPredictions` JSON file.
|
||||||
|
Always required.
|
||||||
|
reference
|
||||||
|
Optional second predictions file. When provided,
|
||||||
|
:class:`DtwAnalysis` runs comparative DTW between primary and
|
||||||
|
reference; when absent, analysis stages that require a
|
||||||
|
reference (i.e. DTW) raise a validation error at parse time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
primary: Path
|
||||||
|
reference: Path | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Preprocessing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessingConfig(BaseModel):
|
||||||
|
"""Per-input preprocessing applied before segmentation and analysis.
|
||||||
|
|
||||||
|
Minimal today — just picks which detected person to extract from
|
||||||
|
each frame. Left as a named stage so future extensions (coordinate
|
||||||
|
normalisation, smoothing) can land here without reshuffling the
|
||||||
|
config shape.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
person_index
|
||||||
|
Which detected person to extract per frame. Defaults to ``0``
|
||||||
|
(the first detected person), matching the single-subject
|
||||||
|
clinical case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
person_index: int = Field(default=0, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Segmentation stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class GaitCyclesSegmentation(BaseModel):
|
||||||
|
"""Single-heel gait-cycle segmentation via peak detection.
|
||||||
|
|
||||||
|
Produces one :class:`~neuropose.io.Segmentation` keyed under the
|
||||||
|
joint name (e.g. ``"rhee_cycles"``). See
|
||||||
|
:func:`~neuropose.analyzer.segment.segment_gait_cycles` for the
|
||||||
|
underlying implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["gait_cycles"]
|
||||||
|
joint: str = "rhee"
|
||||||
|
axis: AxisLetter = "y"
|
||||||
|
invert: bool = False
|
||||||
|
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
|
||||||
|
min_prominence: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GaitCyclesBilateralSegmentation(BaseModel):
|
||||||
|
"""Bilateral (both heels) gait-cycle segmentation.
|
||||||
|
|
||||||
|
Produces two :class:`~neuropose.io.Segmentation` objects keyed as
|
||||||
|
``"left_heel_strikes"`` and ``"right_heel_strikes"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["gait_cycles_bilateral"]
|
||||||
|
axis: AxisLetter = "y"
|
||||||
|
invert: bool = False
|
||||||
|
min_cycle_seconds: float = Field(default=0.4, gt=0.0)
|
||||||
|
min_prominence: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractorSegmentation(BaseModel):
|
||||||
|
"""Generic extractor-driven segmentation.
|
||||||
|
|
||||||
|
Wraps :func:`~neuropose.analyzer.segment.segment_predictions` with
|
||||||
|
a caller-supplied :class:`~neuropose.io.ExtractorSpec`. Use this
|
||||||
|
when the signal of interest is not the vertical heel trace — e.g.
|
||||||
|
wrist-hip distance for a reach-and-grasp task, or elbow flexion
|
||||||
|
angle for a range-of-motion trial.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["extractor"]
|
||||||
|
extractor: ExtractorSpec
|
||||||
|
label: str = Field(
|
||||||
|
default="segmentation",
|
||||||
|
description="Key under which the resulting Segmentation is stored.",
|
||||||
|
)
|
||||||
|
person_index: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Overrides preprocessing.person_index for this stage. "
|
||||||
|
"None defers to the global preprocessing setting."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
min_distance_seconds: float | None = Field(default=None, ge=0.0)
|
||||||
|
min_prominence: float | None = None
|
||||||
|
min_height: float | None = None
|
||||||
|
pad_seconds: float = Field(default=0.0, ge=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
SegmentationStage = Annotated[
|
||||||
|
GaitCyclesSegmentation | GaitCyclesBilateralSegmentation | ExtractorSegmentation,
|
||||||
|
Field(discriminator="kind"),
|
||||||
|
]
|
||||||
|
"""Discriminated-union alias for the three segmentation variants.
|
||||||
|
|
||||||
|
Pydantic dispatches on the ``kind`` field. A config without a
|
||||||
|
``segmentation`` key at all skips this stage entirely
|
||||||
|
(see :class:`AnalysisConfig.segmentation`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Analysis stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DtwAnalysis(BaseModel):
|
||||||
|
"""Dynamic Time Warping between the primary and reference inputs.
|
||||||
|
|
||||||
|
Dispatches to one of :func:`~neuropose.analyzer.dtw.dtw_all`,
|
||||||
|
:func:`~neuropose.analyzer.dtw.dtw_per_joint`, or
|
||||||
|
:func:`~neuropose.analyzer.dtw.dtw_relation` per the ``method``
|
||||||
|
field. Cross-field invariants — ``method="dtw_relation"`` requires
|
||||||
|
``joint_i`` and ``joint_j``, ``representation="angles"`` requires
|
||||||
|
a non-empty ``angle_triplets`` — are enforced at parse time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["dtw"]
|
||||||
|
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] = "dtw_all"
|
||||||
|
align: AlignMode = "none"
|
||||||
|
representation: Representation = "coords"
|
||||||
|
angle_triplets: list[tuple[int, int, int]] | None = None
|
||||||
|
joint_i: int | None = None
|
||||||
|
joint_j: int | None = None
|
||||||
|
nan_policy: NanPolicy = "propagate"
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _check_method_fields(self) -> DtwAnalysis:
|
||||||
|
if self.method == "dtw_relation":
|
||||||
|
if self.joint_i is None or self.joint_j is None:
|
||||||
|
raise ValueError("method='dtw_relation' requires joint_i and joint_j")
|
||||||
|
if self.representation != "coords":
|
||||||
|
raise ValueError(
|
||||||
|
"method='dtw_relation' only supports representation='coords' "
|
||||||
|
"(a two-joint displacement is not a joint-angle signal)"
|
||||||
|
)
|
||||||
|
if self.representation == "angles":
|
||||||
|
if not self.angle_triplets:
|
||||||
|
raise ValueError("representation='angles' requires a non-empty angle_triplets list")
|
||||||
|
if self.method == "dtw_relation":
|
||||||
|
# Guarded by the earlier branch, but make the invariant explicit.
|
||||||
|
raise ValueError(
|
||||||
|
"representation='angles' is incompatible with method='dtw_relation'"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class StatsAnalysis(BaseModel):
|
||||||
|
"""Summary statistics over a scalar signal extracted from the primary input.
|
||||||
|
|
||||||
|
Runs :func:`~neuropose.analyzer.segment.extract_signal` with the
|
||||||
|
caller-supplied :class:`~neuropose.io.ExtractorSpec`, then
|
||||||
|
computes :func:`~neuropose.analyzer.features.extract_feature_statistics`
|
||||||
|
on each segment (or on the full trial if no segmentation stage
|
||||||
|
runs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["stats"]
|
||||||
|
extractor: ExtractorSpec
|
||||||
|
|
||||||
|
|
||||||
|
class NoAnalysis(BaseModel):
|
||||||
|
"""Terminal stage placeholder; produces no per-segment results.
|
||||||
|
|
||||||
|
Useful when the pipeline's goal is just to segment the input and
|
||||||
|
persist the :class:`~neuropose.io.Segmentation` plus an
|
||||||
|
:class:`AnalysisReport` with provenance — the ``none`` analysis
|
||||||
|
kind makes that explicit rather than requiring the absence of the
|
||||||
|
stage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["none"]
|
||||||
|
|
||||||
|
|
||||||
|
AnalysisStage = Annotated[
|
||||||
|
DtwAnalysis | StatsAnalysis | NoAnalysis,
|
||||||
|
Field(discriminator="kind"),
|
||||||
|
]
|
||||||
|
"""Discriminated-union alias for the three analysis variants.
|
||||||
|
|
||||||
|
Pydantic dispatches on ``kind``. One of the three must always be
|
||||||
|
present in a valid config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Output
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class OutputConfig(BaseModel):
|
||||||
|
"""Where :func:`run_analysis` should write its :class:`AnalysisReport`.
|
||||||
|
|
||||||
|
Kept as a sub-object rather than a bare path so downstream
|
||||||
|
extensions (figure paths, supplementary distance-matrix files)
|
||||||
|
can land here without changing the config's top-level shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
report: Path
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisConfig(BaseModel):
|
||||||
|
"""Declarative description of a full analysis run.
|
||||||
|
|
||||||
|
Parsed from YAML (via :func:`load_config`) or JSON (via
|
||||||
|
:meth:`pydantic.BaseModel.model_validate`). Every field is
|
||||||
|
cross-validated at parse time so a typo in a nested sub-field
|
||||||
|
fails in milliseconds rather than after a multi-minute
|
||||||
|
predictions load.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
config_version
|
||||||
|
Schema version for the config itself. Only ``1`` is valid at
|
||||||
|
this release. Future config-format breaks bump this and a
|
||||||
|
sibling migration registry handles legacy YAML in place.
|
||||||
|
inputs
|
||||||
|
Predictions-file paths.
|
||||||
|
preprocessing
|
||||||
|
Per-input preprocessing (person-index selection today).
|
||||||
|
segmentation
|
||||||
|
Optional segmentation stage. ``None`` skips segmentation
|
||||||
|
entirely and analysis runs over each full trial as a single
|
||||||
|
"segment".
|
||||||
|
analysis
|
||||||
|
Required analysis stage. Exactly one of
|
||||||
|
:class:`DtwAnalysis` / :class:`StatsAnalysis` / :class:`NoAnalysis`.
|
||||||
|
output
|
||||||
|
Output paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
config_version: Literal[1] = 1
|
||||||
|
inputs: InputsConfig
|
||||||
|
preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
|
||||||
|
segmentation: SegmentationStage | None = None
|
||||||
|
analysis: AnalysisStage
|
||||||
|
output: OutputConfig
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _check_cross_stage_invariants(self) -> AnalysisConfig:
|
||||||
|
# DTW is comparative — it needs a reference input.
|
||||||
|
if isinstance(self.analysis, DtwAnalysis) and self.inputs.reference is None:
|
||||||
|
raise ValueError("analysis.kind='dtw' requires inputs.reference to be set")
|
||||||
|
# Stats is non-comparative — a reference without a use is
|
||||||
|
# almost certainly an operator error.
|
||||||
|
if isinstance(self.analysis, StatsAnalysis) and self.inputs.reference is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"analysis.kind='stats' operates on inputs.primary only; "
|
||||||
|
"remove inputs.reference or switch analysis.kind to 'dtw'"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Report pieces
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class InputSummary(BaseModel):
|
||||||
|
"""Capsule of an input predictions file's headline metadata.
|
||||||
|
|
||||||
|
Stored in the :class:`AnalysisReport` so a reader of the report
|
||||||
|
can tell at a glance what was analysed without having to load the
|
||||||
|
underlying predictions JSONs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
path: Path
|
||||||
|
frame_count: int = Field(ge=0)
|
||||||
|
fps: float = Field(ge=0.0)
|
||||||
|
provenance: Provenance | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureSummary(BaseModel):
|
||||||
|
"""Pydantic twin of :class:`~neuropose.analyzer.features.FeatureStatistics`.
|
||||||
|
|
||||||
|
The dataclass is used throughout the analyzer for ad-hoc Python
|
||||||
|
consumption; the report path needs a pydantic model for
|
||||||
|
round-tripping through JSON.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
mean: float
|
||||||
|
std: float
|
||||||
|
min: float
|
||||||
|
max: float
|
||||||
|
range: float
|
||||||
|
|
||||||
|
|
||||||
|
class DtwResults(BaseModel):
|
||||||
|
"""DTW results attached to an :class:`AnalysisReport`.
|
||||||
|
|
||||||
|
``distances`` is parallel to ``segment_labels``. For an
|
||||||
|
unsegmented run the lists have length 1 and the label is
|
||||||
|
``"full_trial"``. For a segmented run each label takes the form
|
||||||
|
``"<segmentation_key>[<index>]"`` (e.g. ``"left_heel_strikes[3]"``).
|
||||||
|
``per_joint_distances`` carries a per-unit breakdown for
|
||||||
|
``method="dtw_per_joint"`` only; its outer length matches
|
||||||
|
``distances``, inner length matches either ``num_joints`` (coords)
|
||||||
|
or ``len(angle_triplets)`` (angles).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["dtw"]
|
||||||
|
method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"]
|
||||||
|
distances: list[float]
|
||||||
|
paths: list[list[tuple[int, int]]]
|
||||||
|
per_joint_distances: list[list[float]] | None = None
|
||||||
|
segment_labels: list[str]
|
||||||
|
summary: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
class StatsResults(BaseModel):
|
||||||
|
"""Feature-statistics results attached to an :class:`AnalysisReport`.
|
||||||
|
|
||||||
|
``statistics`` is parallel to ``segment_labels``; see
|
||||||
|
:class:`DtwResults` for the labelling convention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["stats"]
|
||||||
|
statistics: list[FeatureSummary]
|
||||||
|
segment_labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class NoResults(BaseModel):
|
||||||
|
"""Empty results payload for ``analysis.kind='none'`` runs."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
kind: Literal["none"]
|
||||||
|
|
||||||
|
|
||||||
|
AnalysisResults = Annotated[
|
||||||
|
DtwResults | StatsResults | NoResults,
|
||||||
|
Field(discriminator="kind"),
|
||||||
|
]
|
||||||
|
"""Discriminated-union alias for the three analysis-result shapes.
|
||||||
|
|
||||||
|
Mirrors :data:`AnalysisStage` one-for-one: ``DtwAnalysis`` produces
|
||||||
|
:class:`DtwResults`, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level report
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisReport(BaseModel):
|
||||||
|
"""Self-describing output artifact of :func:`run_analysis`.
|
||||||
|
|
||||||
|
Serialised to JSON on disk. Carries the originating config, the
|
||||||
|
:class:`~neuropose.io.Provenance` envelope (with the config
|
||||||
|
serialised into :attr:`~neuropose.io.Provenance.analysis_config`
|
||||||
|
so the report is self-describing even if the YAML is lost), each
|
||||||
|
input's headline metadata plus its own provenance if available,
|
||||||
|
any segmentations produced, and the analysis results themselves.
|
||||||
|
|
||||||
|
Lives in the schema-migration registry under ``"AnalysisReport"``
|
||||||
|
at ``CURRENT_VERSION``; see :mod:`neuropose.migrations`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||||
|
|
||||||
|
schema_version: int = Field(default=CURRENT_VERSION, ge=1)
|
||||||
|
config: AnalysisConfig
|
||||||
|
provenance: Provenance | None = None
|
||||||
|
primary: InputSummary
|
||||||
|
reference: InputSummary | None = None
|
||||||
|
segmentations: dict[str, Segmentation] = Field(default_factory=dict)
|
||||||
|
results: AnalysisResults
|
||||||
|
|
||||||
|
|
||||||
|
def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
|
||||||
|
"""Serialise an :class:`AnalysisConfig` to a JSON-safe dict.
|
||||||
|
|
||||||
|
Returned shape is identical to what pydantic would produce via
|
||||||
|
:meth:`~pydantic.BaseModel.model_dump` in ``mode="json"`` — paths
|
||||||
|
become strings, tuples become lists, enums become their values.
|
||||||
|
Useful for stamping
|
||||||
|
:attr:`~neuropose.io.Provenance.analysis_config` on the
|
||||||
|
:class:`AnalysisReport`'s provenance envelope.
|
||||||
|
"""
|
||||||
|
return config.model_dump(mode="json")
|
||||||
|
|
@ -60,7 +60,9 @@ Version history
|
||||||
---------------
|
---------------
|
||||||
- **v1:** initial schema, pre-Phase-0.
|
- **v1:** initial schema, pre-Phase-0.
|
||||||
- **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions`
|
- **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions`
|
||||||
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope)."""
|
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope).
|
||||||
|
:class:`~neuropose.analyzer.pipeline.AnalysisReport` also enters the registry at v2
|
||||||
|
(no legacy v1 payloads ever existed for it, so no migration is registered)."""
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(Exception):
|
class MigrationError(Exception):
|
||||||
|
|
@ -90,6 +92,7 @@ class MigrationNotFoundError(MigrationError):
|
||||||
# dict at ``source + 1``.
|
# dict at ``source + 1``.
|
||||||
_VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
_VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||||
_BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
_BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||||
|
_ANALYSIS_REPORT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_video_predictions_migration(
|
def register_video_predictions_migration(
|
||||||
|
|
@ -142,6 +145,29 @@ def register_benchmark_result_migration(
|
||||||
return wrap
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
def register_analysis_report_migration(
|
||||||
|
from_version: int,
|
||||||
|
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
|
||||||
|
"""Register a :class:`~neuropose.analyzer.pipeline.AnalysisReport` migration.
|
||||||
|
|
||||||
|
See :func:`register_video_predictions_migration` for usage — this
|
||||||
|
is the same pattern for the analysis-report registry. Unlike the
|
||||||
|
other two schemas, :class:`AnalysisReport` first appeared at
|
||||||
|
:data:`CURRENT_VERSION = 2`, so no ``from_version=1`` migration
|
||||||
|
exists (and none is expected).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
|
||||||
|
if from_version in _ANALYSIS_REPORT_MIGRATIONS:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"analysis-report migration already registered from version {from_version}"
|
||||||
|
)
|
||||||
|
_ANALYSIS_REPORT_MIGRATIONS[from_version] = fn
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
def migrate_video_predictions(payload: dict) -> dict:
|
def migrate_video_predictions(payload: dict) -> dict:
|
||||||
"""Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current.
|
"""Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current.
|
||||||
|
|
||||||
|
|
@ -179,6 +205,18 @@ def migrate_benchmark_result(payload: dict) -> dict:
|
||||||
return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult")
|
return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult")
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_analysis_report(payload: dict) -> dict:
|
||||||
|
"""Migrate a raw :class:`~neuropose.analyzer.pipeline.AnalysisReport` dict.
|
||||||
|
|
||||||
|
See :func:`migrate_video_predictions` for semantics. Because
|
||||||
|
:class:`AnalysisReport` first shipped at schema_version 2, a
|
||||||
|
payload missing the key still defaults to 1 (and would require a
|
||||||
|
not-yet-registered v1 → v2 migration); this is only reachable for
|
||||||
|
deliberately malformed inputs.
|
||||||
|
"""
|
||||||
|
return _migrate(payload, _ANALYSIS_REPORT_MIGRATIONS, schema_name="AnalysisReport")
|
||||||
|
|
||||||
|
|
||||||
def migrate_job_results(payload: dict) -> dict:
|
def migrate_job_results(payload: dict) -> dict:
|
||||||
"""Migrate a :class:`~neuropose.io.JobResults` root dict to current.
|
"""Migrate a :class:`~neuropose.io.JobResults` root dict to current.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,434 @@
|
||||||
|
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
||||||
|
|
||||||
|
This file covers the schema half of the pipeline:
|
||||||
|
|
||||||
|
- :class:`AnalysisConfig` parsing, including the discriminated unions
|
||||||
|
for the segmentation and analysis stages, and the cross-field
|
||||||
|
invariants enforced at parse time.
|
||||||
|
- :class:`AnalysisReport` construction + JSON round-trip, including
|
||||||
|
the migration hook (schema_version defaults to CURRENT_VERSION).
|
||||||
|
- :func:`analysis_config_to_dict` JSON-safety.
|
||||||
|
|
||||||
|
The executor (``run_analysis``) gets its own test module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from neuropose.analyzer.pipeline import (
|
||||||
|
AnalysisConfig,
|
||||||
|
AnalysisReport,
|
||||||
|
DtwAnalysis,
|
||||||
|
DtwResults,
|
||||||
|
ExtractorSegmentation,
|
||||||
|
FeatureSummary,
|
||||||
|
GaitCyclesBilateralSegmentation,
|
||||||
|
GaitCyclesSegmentation,
|
||||||
|
InputsConfig,
|
||||||
|
InputSummary,
|
||||||
|
NoAnalysis,
|
||||||
|
NoResults,
|
||||||
|
OutputConfig,
|
||||||
|
PreprocessingConfig,
|
||||||
|
StatsAnalysis,
|
||||||
|
StatsResults,
|
||||||
|
analysis_config_to_dict,
|
||||||
|
)
|
||||||
|
from neuropose.migrations import CURRENT_VERSION
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
|
||||||
|
"""A minimal AnalysisConfig dict with dtw_all + reference."""
|
||||||
|
return {
|
||||||
|
"inputs": {
|
||||||
|
"primary": str(tmp_path / "primary.json"),
|
||||||
|
"reference": str(tmp_path / "reference.json"),
|
||||||
|
},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "report.json")},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"inputs": {"primary": str(tmp_path / "primary.json")},
|
||||||
|
"analysis": {
|
||||||
|
"kind": "stats",
|
||||||
|
"extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False},
|
||||||
|
},
|
||||||
|
"output": {"report": str(tmp_path / "report.json")},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# InputsConfig / PreprocessingConfig / OutputConfig
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestInputsConfig:
|
||||||
|
def test_primary_only(self, tmp_path: Path) -> None:
|
||||||
|
cfg = InputsConfig(primary=tmp_path / "a.json")
|
||||||
|
assert cfg.reference is None
|
||||||
|
|
||||||
|
def test_primary_and_reference(self, tmp_path: Path) -> None:
|
||||||
|
cfg = InputsConfig(
|
||||||
|
primary=tmp_path / "a.json",
|
||||||
|
reference=tmp_path / "b.json",
|
||||||
|
)
|
||||||
|
assert cfg.reference == tmp_path / "b.json"
|
||||||
|
|
||||||
|
def test_extra_field_rejected(self, tmp_path: Path) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="Extra inputs"):
|
||||||
|
InputsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"primary": str(tmp_path / "a.json"),
|
||||||
|
"extra": "field",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_frozen(self, tmp_path: Path) -> None:
|
||||||
|
cfg = InputsConfig(primary=tmp_path / "a.json")
|
||||||
|
with pytest.raises(ValidationError, match="frozen"):
|
||||||
|
cfg.primary = tmp_path / "b.json" # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreprocessingConfig:
|
||||||
|
def test_default_person_index_zero(self) -> None:
|
||||||
|
cfg = PreprocessingConfig()
|
||||||
|
assert cfg.person_index == 0
|
||||||
|
|
||||||
|
def test_negative_person_index_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
PreprocessingConfig(person_index=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOutputConfig:
|
||||||
|
def test_report_path(self, tmp_path: Path) -> None:
|
||||||
|
cfg = OutputConfig(report=tmp_path / "out.json")
|
||||||
|
assert cfg.report == tmp_path / "out.json"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Segmentation stage discriminated union
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSegmentationStage:
|
||||||
|
def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["segmentation"] = {
|
||||||
|
"kind": "gait_cycles",
|
||||||
|
"joint": "lhee",
|
||||||
|
"axis": "y",
|
||||||
|
"min_cycle_seconds": 0.5,
|
||||||
|
}
|
||||||
|
cfg = AnalysisConfig.model_validate(config_dict)
|
||||||
|
assert isinstance(cfg.segmentation, GaitCyclesSegmentation)
|
||||||
|
assert cfg.segmentation.joint == "lhee"
|
||||||
|
assert cfg.segmentation.min_cycle_seconds == 0.5
|
||||||
|
|
||||||
|
def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["segmentation"] = {"kind": "gait_cycles_bilateral"}
|
||||||
|
cfg = AnalysisConfig.model_validate(config_dict)
|
||||||
|
assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation)
|
||||||
|
|
||||||
|
def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["segmentation"] = {
|
||||||
|
"kind": "extractor",
|
||||||
|
"extractor": {
|
||||||
|
"kind": "joint_axis",
|
||||||
|
"joint": 15,
|
||||||
|
"axis": 1,
|
||||||
|
"invert": False,
|
||||||
|
},
|
||||||
|
"label": "wrist_cycles",
|
||||||
|
"min_distance_seconds": 0.5,
|
||||||
|
}
|
||||||
|
cfg = AnalysisConfig.model_validate(config_dict)
|
||||||
|
assert isinstance(cfg.segmentation, ExtractorSegmentation)
|
||||||
|
assert cfg.segmentation.label == "wrist_cycles"
|
||||||
|
|
||||||
|
def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["segmentation"] = {"kind": "unknown_method"}
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AnalysisConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
|
||||||
|
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||||
|
assert cfg.segmentation is None
|
||||||
|
|
||||||
|
def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["segmentation"] = {
|
||||||
|
"kind": "gait_cycles",
|
||||||
|
"min_cycle_seconds": 0.0, # must be > 0
|
||||||
|
}
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AnalysisConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Analysis stage discriminated union + cross-field invariants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDtwAnalysisValidation:
|
||||||
|
def test_dtw_relation_requires_joints(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="joint_i and joint_j"):
|
||||||
|
DtwAnalysis(kind="dtw", method="dtw_relation")
|
||||||
|
|
||||||
|
def test_dtw_relation_rejects_angles_representation(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="only supports representation='coords'"):
|
||||||
|
DtwAnalysis(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_relation",
|
||||||
|
joint_i=0,
|
||||||
|
joint_j=1,
|
||||||
|
representation="angles",
|
||||||
|
angle_triplets=[(0, 1, 2)],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_angles_requires_triplets(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="angle_triplets"):
|
||||||
|
DtwAnalysis(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_all",
|
||||||
|
representation="angles",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_angles_with_empty_triplets_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="angle_triplets"):
|
||||||
|
DtwAnalysis(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_all",
|
||||||
|
representation="angles",
|
||||||
|
angle_triplets=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_happy_path_dtw_all_coords(self) -> None:
|
||||||
|
analysis = DtwAnalysis(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_all",
|
||||||
|
align="procrustes_per_sequence",
|
||||||
|
nan_policy="interpolate",
|
||||||
|
)
|
||||||
|
assert analysis.align == "procrustes_per_sequence"
|
||||||
|
|
||||||
|
def test_happy_path_dtw_all_angles(self) -> None:
|
||||||
|
analysis = DtwAnalysis(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_all",
|
||||||
|
representation="angles",
|
||||||
|
angle_triplets=[(0, 1, 2), (3, 4, 5)],
|
||||||
|
)
|
||||||
|
assert len(analysis.angle_triplets or []) == 2
|
||||||
|
|
||||||
|
def test_happy_path_dtw_relation(self) -> None:
|
||||||
|
analysis = DtwAnalysis(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_relation",
|
||||||
|
joint_i=15,
|
||||||
|
joint_j=23,
|
||||||
|
)
|
||||||
|
assert analysis.joint_i == 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysisCrossStage:
|
||||||
|
def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = {
|
||||||
|
"inputs": {"primary": str(tmp_path / "a.json")},
|
||||||
|
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||||
|
"output": {"report": str(tmp_path / "out.json")},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValidationError, match=r"inputs\.reference"):
|
||||||
|
AnalysisConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_stats_config(tmp_path)
|
||||||
|
config_dict["inputs"]["reference"] = str(tmp_path / "b.json")
|
||||||
|
with pytest.raises(ValidationError, match="primary only"):
|
||||||
|
AnalysisConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
def test_none_analysis_requires_no_reference(self, tmp_path: Path) -> None:
|
||||||
|
# NoAnalysis is fine with either reference present or absent.
|
||||||
|
config_dict = {
|
||||||
|
"inputs": {"primary": str(tmp_path / "a.json")},
|
||||||
|
"analysis": {"kind": "none"},
|
||||||
|
"output": {"report": str(tmp_path / "out.json")},
|
||||||
|
}
|
||||||
|
cfg = AnalysisConfig.model_validate(config_dict)
|
||||||
|
assert isinstance(cfg.analysis, NoAnalysis)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level AnalysisConfig
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysisConfig:
|
||||||
|
def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
|
||||||
|
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||||
|
assert cfg.config_version == 1
|
||||||
|
assert isinstance(cfg.analysis, DtwAnalysis)
|
||||||
|
assert cfg.preprocessing.person_index == 0 # default
|
||||||
|
|
||||||
|
def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
|
||||||
|
cfg = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
|
||||||
|
assert isinstance(cfg.analysis, StatsAnalysis)
|
||||||
|
|
||||||
|
def test_config_version_must_be_1(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["config_version"] = 99
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AnalysisConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
def test_round_trip_json(self, tmp_path: Path) -> None:
|
||||||
|
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||||
|
serialised = original.model_dump_json()
|
||||||
|
restored = AnalysisConfig.model_validate_json(serialised)
|
||||||
|
assert restored == original
|
||||||
|
|
||||||
|
def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
|
||||||
|
config_dict = _minimal_dtw_config(tmp_path)
|
||||||
|
config_dict["unknown_key"] = "typo"
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AnalysisConfig.model_validate(config_dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# analysis_config_to_dict
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysisConfigToDict:
|
||||||
|
def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
|
||||||
|
cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||||
|
dumped = analysis_config_to_dict(cfg)
|
||||||
|
# Paths must have become strings.
|
||||||
|
assert isinstance(dumped["inputs"]["primary"], str)
|
||||||
|
assert isinstance(dumped["output"]["report"], str)
|
||||||
|
|
||||||
|
def test_round_trips_through_dict(self, tmp_path: Path) -> None:
|
||||||
|
original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||||
|
dumped = analysis_config_to_dict(original)
|
||||||
|
restored = AnalysisConfig.model_validate(dumped)
|
||||||
|
assert restored == original
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Result sub-schemas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDtwResults:
|
||||||
|
def test_minimal_construction(self) -> None:
|
||||||
|
res = DtwResults(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_all",
|
||||||
|
distances=[0.5],
|
||||||
|
paths=[[(0, 0), (1, 1)]],
|
||||||
|
segment_labels=["full_trial"],
|
||||||
|
summary={"mean": 0.5},
|
||||||
|
)
|
||||||
|
assert res.kind == "dtw"
|
||||||
|
|
||||||
|
def test_per_joint_distances_shape_is_free(self) -> None:
|
||||||
|
# No validator enforces that per_joint_distances outer length
|
||||||
|
# matches distances — that's run-time semantics of the
|
||||||
|
# executor. Still, verify the field round-trips.
|
||||||
|
res = DtwResults(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_per_joint",
|
||||||
|
distances=[0.1, 0.2],
|
||||||
|
paths=[[(0, 0)], [(0, 0)]],
|
||||||
|
per_joint_distances=[[0.05, 0.05], [0.1, 0.1]],
|
||||||
|
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
|
||||||
|
summary={"mean": 0.15},
|
||||||
|
)
|
||||||
|
assert res.per_joint_distances is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestStatsResults:
|
||||||
|
def test_round_trip(self) -> None:
|
||||||
|
res = StatsResults(
|
||||||
|
kind="stats",
|
||||||
|
statistics=[
|
||||||
|
FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
|
||||||
|
FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
|
||||||
|
],
|
||||||
|
segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
|
||||||
|
)
|
||||||
|
dumped = res.model_dump_json()
|
||||||
|
restored = StatsResults.model_validate_json(dumped)
|
||||||
|
assert restored == res
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoResults:
|
||||||
|
def test_construction(self) -> None:
|
||||||
|
res = NoResults(kind="none")
|
||||||
|
assert res.kind == "none"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AnalysisReport
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_report(tmp_path: Path) -> AnalysisReport:
|
||||||
|
config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
|
||||||
|
return AnalysisReport(
|
||||||
|
config=config,
|
||||||
|
primary=InputSummary(
|
||||||
|
path=tmp_path / "primary.json",
|
||||||
|
frame_count=300,
|
||||||
|
fps=30.0,
|
||||||
|
),
|
||||||
|
reference=InputSummary(
|
||||||
|
path=tmp_path / "reference.json",
|
||||||
|
frame_count=300,
|
||||||
|
fps=30.0,
|
||||||
|
),
|
||||||
|
results=DtwResults(
|
||||||
|
kind="dtw",
|
||||||
|
method="dtw_all",
|
||||||
|
distances=[0.42],
|
||||||
|
paths=[[(0, 0), (1, 1)]],
|
||||||
|
segment_labels=["full_trial"],
|
||||||
|
summary={"mean": 0.42, "p50": 0.42},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysisReport:
|
||||||
|
def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
|
||||||
|
report = _make_report(tmp_path)
|
||||||
|
assert report.schema_version == CURRENT_VERSION
|
||||||
|
|
||||||
|
def test_round_trip_json(self, tmp_path: Path) -> None:
|
||||||
|
report = _make_report(tmp_path)
|
||||||
|
serialised = report.model_dump_json()
|
||||||
|
restored = AnalysisReport.model_validate_json(serialised)
|
||||||
|
assert restored == report
|
||||||
|
|
||||||
|
def test_empty_segmentations_default(self, tmp_path: Path) -> None:
|
||||||
|
report = _make_report(tmp_path)
|
||||||
|
assert report.segmentations == {}
|
||||||
|
|
||||||
|
def test_extra_field_rejected(self, tmp_path: Path) -> None:
|
||||||
|
report = _make_report(tmp_path)
|
||||||
|
dumped = report.model_dump(mode="json")
|
||||||
|
dumped["mystery_field"] = 1
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AnalysisReport.model_validate(dumped)
|
||||||
|
|
@ -38,6 +38,7 @@ from neuropose.migrations import (
|
||||||
FutureSchemaError,
|
FutureSchemaError,
|
||||||
MigrationError,
|
MigrationError,
|
||||||
MigrationNotFoundError,
|
MigrationNotFoundError,
|
||||||
|
migrate_analysis_report,
|
||||||
migrate_benchmark_result,
|
migrate_benchmark_result,
|
||||||
migrate_job_results,
|
migrate_job_results,
|
||||||
migrate_video_predictions,
|
migrate_video_predictions,
|
||||||
|
|
@ -282,8 +283,32 @@ class TestMigrateJobResults:
|
||||||
assert result["video_b.mp4"]["content_b"] is True
|
assert result["video_b.mp4"]["content_b"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# migrate_analysis_report
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMigrateAnalysisReport:
|
||||||
|
def test_current_version_is_noop(self) -> None:
|
||||||
|
"""A payload already at CURRENT_VERSION passes through unchanged."""
|
||||||
|
payload = {"schema_version": CURRENT_VERSION, "foo": "bar"}
|
||||||
|
assert migrate_analysis_report(payload) == payload
|
||||||
|
|
||||||
|
def test_missing_version_defaults_to_v1_and_fails(self) -> None:
|
||||||
|
"""AnalysisReport first shipped at CURRENT_VERSION, so a payload
|
||||||
|
without schema_version (defaulting to 1) would require a
|
||||||
|
non-existent v1→v2 migration and fail with a clear error."""
|
||||||
|
with pytest.raises(MigrationNotFoundError, match="AnalysisReport"):
|
||||||
|
migrate_analysis_report({})
|
||||||
|
|
||||||
|
def test_future_version_rejected(self) -> None:
|
||||||
|
with pytest.raises(FutureSchemaError, match="AnalysisReport"):
|
||||||
|
migrate_analysis_report({"schema_version": CURRENT_VERSION + 5})
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# register_video_predictions_migration / register_benchmark_result_migration
|
# register_video_predictions_migration / register_benchmark_result_migration
|
||||||
|
# / register_analysis_report_migration
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -312,6 +337,21 @@ class TestRegistration:
|
||||||
assert _fn.__name__ == "_fn"
|
assert _fn.__name__ == "_fn"
|
||||||
assert _fn({"x": 1}) == {"x": 1}
|
assert _fn({"x": 1}) == {"x": 1}
|
||||||
|
|
||||||
|
def test_analysis_report_duplicate_registration_raises(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr(migrations, "_ANALYSIS_REPORT_MIGRATIONS", {})
|
||||||
|
|
||||||
|
@migrations.register_analysis_report_migration(from_version=2)
|
||||||
|
def _first(p: dict) -> dict:
|
||||||
|
return p
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="already registered"):
|
||||||
|
|
||||||
|
@migrations.register_analysis_report_migration(from_version=2)
|
||||||
|
def _second(p: dict) -> dict:
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: load_* functions run migrations before validation
|
# Integration: load_* functions run migrations before validation
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue