add AnalysisConfig and AnalysisReport schemas
neuropose.analyzer.pipeline ships two top-level pydantic schemas: AnalysisConfig — what a user writes in YAML. Inputs (primary plus optional reference), preprocessing (person_index, room to grow), optional segmentation as a discriminated union of gait_cycles, gait_cycles_bilateral, and extractor, and a required analysis stage as a discriminated union of dtw, stats, none. AnalysisReport — runtime output with config, Provenance envelope, per-input summaries, produced segmentations, and a results payload whose shape mirrors the stage (DtwResults, StatsResults, NoResults). schema_version defaults to CURRENT_VERSION. Cross-field invariants enforced at parse time via model_validator: method='dtw_relation' requires joint_i/joint_j and refuses representation='angles'; representation='angles' requires non-empty angle_triplets; analysis.kind='dtw' requires inputs.reference; analysis.kind='stats' refuses a reference. Typos fail in milliseconds instead of after a multi-minute predictions load. neuropose.migrations gains a third registry for AnalysisReport (_ANALYSIS_REPORT_MIGRATIONS + register_analysis_report_migration + migrate_analysis_report), ready for future schema changes. No v1→v2 migration is registered because AnalysisReport first shipped at v2. Execution, CLI wiring, and example configs land in follow-up commits.
This commit is contained in:
parent
87461a17d0
commit
979beb1078
24
CHANGELOG.md
24
CHANGELOG.md
|
|
@ -257,6 +257,30 @@ be split into per-release sections once tagging begins.
|
|||
two-joint displacement DTW; users who prefer a unified API can
|
||||
express the same computation via `dtw_all` with an appropriate
|
||||
pair of angle triplets or run `dtw_relation` directly.
|
||||
- **`neuropose.analyzer.pipeline`** (schemas) — declarative
|
||||
analysis-pipeline configuration and output artifact, parseable from
|
||||
YAML or JSON via pydantic. `AnalysisConfig` captures a full
|
||||
experiment: inputs (primary + optional reference predictions
|
||||
files), preprocessing (person index, with room to grow),
|
||||
optional segmentation (`gait_cycles` / `gait_cycles_bilateral` /
|
||||
`extractor` discriminated union), and a required analysis stage
|
||||
(`dtw` / `stats` / `none` discriminated union). `AnalysisReport`
|
||||
is the runtime output: carries the originating config, a
|
||||
`Provenance` envelope with `analysis_config` populated, per-input
|
||||
summaries, produced segmentations, and an analysis-result payload
|
||||
that mirrors the stage choice (`DtwResults`, `StatsResults`, or
|
||||
`NoResults`). Cross-field invariants — `method="dtw_relation"`
|
||||
requires `joint_i`/`joint_j`, `representation="angles"` requires
|
||||
non-empty `angle_triplets`, `analysis.kind="dtw"` requires
|
||||
`inputs.reference`, `analysis.kind="stats"` refuses a reference —
|
||||
are enforced at parse time via `model_validator` so typos fail in
|
||||
milliseconds instead of after a multi-minute predictions load.
|
||||
`AnalysisReport` carries a `schema_version` field defaulting to
|
||||
`CURRENT_VERSION = 2`, with a new
|
||||
`register_analysis_report_migration` decorator and
|
||||
`migrate_analysis_report` driver in `neuropose.migrations` ready
|
||||
for future schema changes. Pipeline execution lands in a
|
||||
follow-up commit.
|
||||
- **`neuropose.analyzer.segment.segment_gait_cycles`** and
|
||||
**`segment_gait_cycles_bilateral`** — clinical convenience
|
||||
wrappers over `segment_predictions` that pre-fill a `joint_axis`
|
||||
|
|
|
|||
|
|
@ -48,6 +48,28 @@ from neuropose.analyzer.features import (
|
|||
predictions_to_numpy,
|
||||
procrustes_align,
|
||||
)
|
||||
from neuropose.analyzer.pipeline import (
|
||||
AnalysisConfig,
|
||||
AnalysisReport,
|
||||
AnalysisResults,
|
||||
AnalysisStage,
|
||||
DtwAnalysis,
|
||||
DtwResults,
|
||||
ExtractorSegmentation,
|
||||
FeatureSummary,
|
||||
GaitCyclesBilateralSegmentation,
|
||||
GaitCyclesSegmentation,
|
||||
InputsConfig,
|
||||
InputSummary,
|
||||
NoAnalysis,
|
||||
NoResults,
|
||||
OutputConfig,
|
||||
PreprocessingConfig,
|
||||
SegmentationStage,
|
||||
StatsAnalysis,
|
||||
StatsResults,
|
||||
analysis_config_to_dict,
|
||||
)
|
||||
from neuropose.analyzer.segment import (
|
||||
JOINT_INDEX,
|
||||
JOINT_NAMES,
|
||||
|
|
@ -70,12 +92,32 @@ __all__ = [
|
|||
"JOINT_NAMES",
|
||||
"AlignMode",
|
||||
"AlignmentDiagnostics",
|
||||
"AnalysisConfig",
|
||||
"AnalysisReport",
|
||||
"AnalysisResults",
|
||||
"AnalysisStage",
|
||||
"AxisLetter",
|
||||
"DTWResult",
|
||||
"DtwAnalysis",
|
||||
"DtwResults",
|
||||
"ExtractorSegmentation",
|
||||
"FeatureStatistics",
|
||||
"FeatureSummary",
|
||||
"GaitCyclesBilateralSegmentation",
|
||||
"GaitCyclesSegmentation",
|
||||
"InputSummary",
|
||||
"InputsConfig",
|
||||
"NanPolicy",
|
||||
"NoAnalysis",
|
||||
"NoResults",
|
||||
"OutputConfig",
|
||||
"PreprocessingConfig",
|
||||
"ProcrustesMode",
|
||||
"Representation",
|
||||
"SegmentationStage",
|
||||
"StatsAnalysis",
|
||||
"StatsResults",
|
||||
"analysis_config_to_dict",
|
||||
"dtw_all",
|
||||
"dtw_per_joint",
|
||||
"dtw_relation",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,478 @@
|
|||
"""YAML-configurable analysis pipeline.
|
||||
|
||||
This module unifies the analyzer's individual primitives (Procrustes
|
||||
alignment, gait-cycle segmentation, DTW on coords or angles, feature
|
||||
statistics) behind a single declarative configuration object so an
|
||||
experiment can be reproduced from one file that lives in a git
|
||||
repository, carries through to the :class:`~neuropose.io.Provenance`
|
||||
envelope on the output artifact, and can be cited unambiguously in
|
||||
accompanying papers.
|
||||
|
||||
Two top-level schemas live here:
|
||||
|
||||
- :class:`AnalysisConfig` — what the user writes in YAML. Describes
|
||||
the full pipeline: inputs, preprocessing, optional segmentation,
|
||||
required analysis stage, and output path.
|
||||
- :class:`AnalysisReport` — what :func:`run_analysis` emits. Carries
|
||||
the config, a :class:`~neuropose.io.Provenance` envelope (with the
|
||||
config serialised into :attr:`~neuropose.io.Provenance.analysis_config`),
|
||||
per-input summaries, segmentation results, and the analysis results
|
||||
themselves.
|
||||
|
||||
Both schemas parse from (and serialise to) JSON via pydantic; the
|
||||
config additionally parses from YAML via :func:`load_config`. Cross-field
|
||||
invariants (for example, ``method="dtw_relation"`` requires ``joint_i``
|
||||
and ``joint_j``) are enforced at parse time so typo-laden configs fail
|
||||
fast rather than after an expensive multi-minute load.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from neuropose.analyzer.dtw import AlignMode, NanPolicy, Representation
|
||||
from neuropose.analyzer.segment import AxisLetter
|
||||
from neuropose.io import ExtractorSpec, Provenance, Segmentation
|
||||
from neuropose.migrations import CURRENT_VERSION
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InputsConfig(BaseModel):
    """Predictions files consumed by the pipeline.

    Attributes
    ----------
    primary
        Path to a :class:`~neuropose.io.VideoPredictions` JSON file.
        Always required.
    reference
        Optional second predictions file. When provided,
        :class:`DtwAnalysis` runs comparative DTW between primary and
        reference; when absent, analysis stages that require a
        reference (i.e. DTW) raise a validation error at parse time.
    """

    # extra="forbid" rejects unknown keys so YAML typos fail fast;
    # frozen=True makes a parsed config an immutable record of the run.
    model_config = ConfigDict(extra="forbid", frozen=True)

    primary: Path
    reference: Path | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseModel):
    """Per-input preprocessing applied before segmentation and analysis.

    Minimal today — just picks which detected person to extract from
    each frame. Left as a named stage so future extensions (coordinate
    normalisation, smoothing) can land here without reshuffling the
    config shape.

    Attributes
    ----------
    person_index
        Which detected person to extract per frame. Defaults to ``0``
        (the first detected person), matching the single-subject
        clinical case.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # ge=0 rejects negative indices at parse time.
    person_index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Segmentation stage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GaitCyclesSegmentation(BaseModel):
    """Single-heel gait-cycle segmentation via peak detection.

    Produces one :class:`~neuropose.io.Segmentation` keyed under the
    joint name (e.g. ``"rhee_cycles"``). See
    :func:`~neuropose.analyzer.segment.segment_gait_cycles` for the
    underlying implementation.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator for the SegmentationStage union.
    kind: Literal["gait_cycles"]
    # Name of the joint whose trace is peak-detected ("rhee" = right heel).
    joint: str = "rhee"
    axis: AxisLetter = "y"
    # Flip the signal before peak detection when the axis convention is inverted.
    invert: bool = False
    # Lower bound on a plausible cycle duration, in seconds; must be positive.
    min_cycle_seconds: float = Field(default=0.4, gt=0.0)
    # None defers the prominence threshold to the underlying detector.
    min_prominence: float | None = None
|
||||
|
||||
|
||||
class GaitCyclesBilateralSegmentation(BaseModel):
    """Bilateral (both heels) gait-cycle segmentation.

    Produces two :class:`~neuropose.io.Segmentation` objects keyed as
    ``"left_heel_strikes"`` and ``"right_heel_strikes"``.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator for the SegmentationStage union. Unlike the
    # single-heel variant there is no ``joint`` field — both heels are
    # always processed.
    kind: Literal["gait_cycles_bilateral"]
    axis: AxisLetter = "y"
    invert: bool = False
    min_cycle_seconds: float = Field(default=0.4, gt=0.0)
    min_prominence: float | None = None
|
||||
|
||||
|
||||
class ExtractorSegmentation(BaseModel):
    """Generic extractor-driven segmentation.

    Wraps :func:`~neuropose.analyzer.segment.segment_predictions` with
    a caller-supplied :class:`~neuropose.io.ExtractorSpec`. Use this
    when the signal of interest is not the vertical heel trace — e.g.
    wrist-hip distance for a reach-and-grasp task, or elbow flexion
    angle for a range-of-motion trial.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator for the SegmentationStage union.
    kind: Literal["extractor"]
    extractor: ExtractorSpec
    label: str = Field(
        default="segmentation",
        description="Key under which the resulting Segmentation is stored.",
    )
    person_index: int | None = Field(
        default=None,
        description=(
            "Overrides preprocessing.person_index for this stage. "
            "None defers to the global preprocessing setting."
        ),
    )
    # Peak-detection knobs; None leaves each threshold to the detector.
    min_distance_seconds: float | None = Field(default=None, ge=0.0)
    min_prominence: float | None = None
    min_height: float | None = None
    # Seconds of context added around each detected segment; must be >= 0.
    pad_seconds: float = Field(default=0.0, ge=0.0)
|
||||
|
||||
|
||||
# Tagged union: pydantic inspects the payload's ``kind`` field and
# validates against exactly one variant, giving precise error messages.
SegmentationStage = Annotated[
    GaitCyclesSegmentation | GaitCyclesBilateralSegmentation | ExtractorSegmentation,
    Field(discriminator="kind"),
]
"""Discriminated-union alias for the three segmentation variants.

Pydantic dispatches on the ``kind`` field. A config without a
``segmentation`` key at all skips this stage entirely
(see :class:`AnalysisConfig.segmentation`).
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Analysis stage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DtwAnalysis(BaseModel):
    """Dynamic Time Warping between the primary and reference inputs.

    The ``method`` field selects the underlying primitive:
    :func:`~neuropose.analyzer.dtw.dtw_all`,
    :func:`~neuropose.analyzer.dtw.dtw_per_joint`, or
    :func:`~neuropose.analyzer.dtw.dtw_relation`. Two cross-field
    invariants are checked at parse time rather than at run time:
    ``method="dtw_relation"`` needs both ``joint_i`` and ``joint_j``,
    and ``representation="angles"`` needs a non-empty
    ``angle_triplets`` list (and is never valid with ``dtw_relation``).
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    kind: Literal["dtw"]
    method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] = "dtw_all"
    align: AlignMode = "none"
    representation: Representation = "coords"
    angle_triplets: list[tuple[int, int, int]] | None = None
    joint_i: int | None = None
    joint_j: int | None = None
    nan_policy: NanPolicy = "propagate"

    @model_validator(mode="after")
    def _check_method_fields(self) -> DtwAnalysis:
        # Cross-field invariants that pydantic's per-field validation
        # cannot express. Checked in a fixed order so the first error a
        # user sees is the most actionable one.
        relation = self.method == "dtw_relation"

        if relation and (self.joint_i is None or self.joint_j is None):
            raise ValueError("method='dtw_relation' requires joint_i and joint_j")
        if relation and self.representation != "coords":
            raise ValueError(
                "method='dtw_relation' only supports representation='coords' "
                "(a two-joint displacement is not a joint-angle signal)"
            )

        if self.representation == "angles":
            if not self.angle_triplets:
                raise ValueError("representation='angles' requires a non-empty angle_triplets list")
            if relation:
                # Unreachable given the check above, but kept so the
                # invariant survives refactors of the earlier branch.
                raise ValueError(
                    "representation='angles' is incompatible with method='dtw_relation'"
                )
        return self
|
||||
|
||||
|
||||
class StatsAnalysis(BaseModel):
    """Summary statistics over a scalar signal extracted from the primary input.

    Runs :func:`~neuropose.analyzer.segment.extract_signal` with the
    caller-supplied :class:`~neuropose.io.ExtractorSpec`, then
    computes :func:`~neuropose.analyzer.features.extract_feature_statistics`
    on each segment (or on the full trial if no segmentation stage
    runs).
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator for the AnalysisStage union.
    kind: Literal["stats"]
    # Defines the scalar signal the statistics are computed over.
    extractor: ExtractorSpec
|
||||
|
||||
|
||||
class NoAnalysis(BaseModel):
    """Terminal stage placeholder; produces no per-segment results.

    Useful when the pipeline's goal is just to segment the input and
    persist the :class:`~neuropose.io.Segmentation` plus an
    :class:`AnalysisReport` with provenance — the ``none`` analysis
    kind makes that explicit rather than requiring the absence of the
    stage.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator only; the model carries no other configuration.
    kind: Literal["none"]
|
||||
|
||||
|
||||
# Tagged union over the ``kind`` field, mirroring SegmentationStage.
AnalysisStage = Annotated[
    DtwAnalysis | StatsAnalysis | NoAnalysis,
    Field(discriminator="kind"),
]
"""Discriminated-union alias for the three analysis variants.

Pydantic dispatches on ``kind``. One of the three must always be
present in a valid config.
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OutputConfig(BaseModel):
    """Where :func:`run_analysis` should write its :class:`AnalysisReport`.

    Kept as a sub-object rather than a bare path so downstream
    extensions (figure paths, supplementary distance-matrix files)
    can land here without changing the config's top-level shape.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Destination path for the serialised AnalysisReport JSON.
    report: Path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AnalysisConfig(BaseModel):
    """Declarative description of a full analysis run.

    Parsed from YAML (via :func:`load_config`) or JSON (via
    :meth:`pydantic.BaseModel.model_validate`). Every field is
    cross-validated at parse time so a typo in a nested sub-field
    fails in milliseconds rather than after a multi-minute
    predictions load.

    Attributes
    ----------
    config_version
        Schema version for the config itself. Only ``1`` is valid at
        this release. Future config-format breaks bump this and a
        sibling migration registry handles legacy YAML in place.
    inputs
        Predictions-file paths.
    preprocessing
        Per-input preprocessing (person-index selection today).
    segmentation
        Optional segmentation stage. ``None`` skips segmentation
        entirely and analysis runs over each full trial as a single
        "segment".
    analysis
        Required analysis stage. Exactly one of
        :class:`DtwAnalysis` / :class:`StatsAnalysis` / :class:`NoAnalysis`.
    output
        Output paths.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    config_version: Literal[1] = 1
    inputs: InputsConfig
    preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
    segmentation: SegmentationStage | None = None
    analysis: AnalysisStage
    output: OutputConfig

    @model_validator(mode="after")
    def _check_cross_stage_invariants(self) -> AnalysisConfig:
        # Invariants spanning the inputs and analysis sub-objects,
        # which neither sub-model can check on its own.
        has_reference = self.inputs.reference is not None

        if isinstance(self.analysis, DtwAnalysis):
            # DTW is comparative: without a reference there is nothing
            # to warp against.
            if not has_reference:
                raise ValueError("analysis.kind='dtw' requires inputs.reference to be set")
        elif isinstance(self.analysis, StatsAnalysis) and has_reference:
            # A reference that nothing consumes is almost certainly an
            # operator error — refuse it rather than silently ignore it.
            raise ValueError(
                "analysis.kind='stats' operates on inputs.primary only; "
                "remove inputs.reference or switch analysis.kind to 'dtw'"
            )
        return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report pieces
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InputSummary(BaseModel):
    """Capsule of an input predictions file's headline metadata.

    Stored in the :class:`AnalysisReport` so a reader of the report
    can tell at a glance what was analysed without having to load the
    underlying predictions JSONs.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Path of the predictions file that was summarised.
    path: Path
    # Non-negative counts/rates; ge constraints reject corrupt metadata.
    frame_count: int = Field(ge=0)
    fps: float = Field(ge=0.0)
    # Provenance of the input itself, when the predictions file carried one.
    provenance: Provenance | None = None
|
||||
|
||||
|
||||
class FeatureSummary(BaseModel):
    """Pydantic twin of :class:`~neuropose.analyzer.features.FeatureStatistics`.

    The dataclass is used throughout the analyzer for ad-hoc Python
    consumption; the report path needs a pydantic model for
    round-tripping through JSON.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Field names mirror FeatureStatistics one-for-one.
    mean: float
    std: float
    min: float
    max: float
    range: float
|
||||
|
||||
|
||||
class DtwResults(BaseModel):
    """DTW results attached to an :class:`AnalysisReport`.

    ``distances`` is parallel to ``segment_labels``. For an
    unsegmented run the lists have length 1 and the label is
    ``"full_trial"``. For a segmented run each label takes the form
    ``"<segmentation_key>[<index>]"`` (e.g. ``"left_heel_strikes[3]"``).
    ``per_joint_distances`` carries a per-unit breakdown for
    ``method="dtw_per_joint"`` only; its outer length matches
    ``distances``, inner length matches either ``num_joints`` (coords)
    or ``len(angle_triplets)`` (angles).
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator for the AnalysisResults union.
    kind: Literal["dtw"]
    # Echo of the DtwAnalysis method that produced these results.
    method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"]
    distances: list[float]
    # One warping path per segment; each path is a list of (i, j) index pairs.
    paths: list[list[tuple[int, int]]]
    per_joint_distances: list[list[float]] | None = None
    segment_labels: list[str]
    # Aggregate scalars over the per-segment distances.
    summary: dict[str, float]
|
||||
|
||||
|
||||
class StatsResults(BaseModel):
    """Feature-statistics results attached to an :class:`AnalysisReport`.

    ``statistics`` is parallel to ``segment_labels``; see
    :class:`DtwResults` for the labelling convention.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator for the AnalysisResults union.
    kind: Literal["stats"]
    statistics: list[FeatureSummary]
    segment_labels: list[str]
|
||||
|
||||
|
||||
class NoResults(BaseModel):
    """Empty results payload for ``analysis.kind='none'`` runs."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Discriminator only; a 'none' analysis produces no data.
    kind: Literal["none"]
|
||||
|
||||
|
||||
# Tagged union over ``kind``; same dispatch mechanism as AnalysisStage.
AnalysisResults = Annotated[
    DtwResults | StatsResults | NoResults,
    Field(discriminator="kind"),
]
"""Discriminated-union alias for the three analysis-result shapes.

Mirrors :data:`AnalysisStage` one-for-one: ``DtwAnalysis`` produces
:class:`DtwResults`, etc.
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AnalysisReport(BaseModel):
    """Self-describing output artifact of :func:`run_analysis`.

    Serialised to JSON on disk. Carries the originating config, the
    :class:`~neuropose.io.Provenance` envelope (with the config
    serialised into :attr:`~neuropose.io.Provenance.analysis_config`
    so the report is self-describing even if the YAML is lost), each
    input's headline metadata plus its own provenance if available,
    any segmentations produced, and the analysis results themselves.

    Lives in the schema-migration registry under ``"AnalysisReport"``
    at ``CURRENT_VERSION``; see :mod:`neuropose.migrations`.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Defaults to the package-wide CURRENT_VERSION; older payloads are
    # lifted by migrate_analysis_report before validation.
    schema_version: int = Field(default=CURRENT_VERSION, ge=1)
    config: AnalysisConfig
    provenance: Provenance | None = None
    primary: InputSummary
    # Present only for comparative (DTW) runs.
    reference: InputSummary | None = None
    # Keyed by segmentation label (e.g. "left_heel_strikes").
    segmentations: dict[str, Segmentation] = Field(default_factory=dict)
    results: AnalysisResults
|
||||
|
||||
|
||||
def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]:
    """Serialise an :class:`AnalysisConfig` to a JSON-safe dict.

    Returned shape is identical to what pydantic would produce via
    :meth:`~pydantic.BaseModel.model_dump` in ``mode="json"`` — paths
    become strings, tuples become lists, enums become their values.
    Useful for stamping
    :attr:`~neuropose.io.Provenance.analysis_config` on the
    :class:`AnalysisReport`'s provenance envelope.
    """
    # mode="json" is what guarantees JSON-safety: python-only types
    # (Path, tuple) are coerced to their JSON representations.
    dumped: dict[str, Any] = config.model_dump(mode="json")
    return dumped
|
||||
|
|
@ -60,7 +60,9 @@ Version history
|
|||
---------------
|
||||
- **v1:** initial schema, pre-Phase-0.
|
||||
- **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions`
|
||||
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope)."""
|
||||
and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope).
|
||||
:class:`~neuropose.analyzer.pipeline.AnalysisReport` also enters the registry at v2
|
||||
(no legacy v1 payloads ever existed for it, so no migration is registered)."""
|
||||
|
||||
|
||||
class MigrationError(Exception):
|
||||
|
|
@ -90,6 +92,7 @@ class MigrationNotFoundError(MigrationError):
|
|||
# dict at ``source + 1``.
|
||||
# Per-schema migration registries. Each maps a source schema version to
# the function that lifts a raw payload dict one step (presumably to
# ``source + 1`` per the surrounding convention — see the drivers below).
_VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
_BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
_ANALYSIS_REPORT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {}
|
||||
|
||||
|
||||
def register_video_predictions_migration(
|
||||
|
|
@ -142,6 +145,29 @@ def register_benchmark_result_migration(
|
|||
return wrap
|
||||
|
||||
|
||||
def register_analysis_report_migration(
    from_version: int,
) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]:
    """Register a :class:`~neuropose.analyzer.pipeline.AnalysisReport` migration.

    See :func:`register_video_predictions_migration` for usage — this
    is the same pattern for the analysis-report registry. Unlike the
    other two schemas, :class:`AnalysisReport` first appeared at
    :data:`CURRENT_VERSION = 2`, so no ``from_version=1`` migration
    exists (and none is expected).
    """

    def _register(fn: Callable[[dict], dict]) -> Callable[[dict], dict]:
        # Exactly one migration per source version: a duplicate
        # registration is a programming error, hence RuntimeError
        # rather than a validation error.
        if from_version in _ANALYSIS_REPORT_MIGRATIONS:
            raise RuntimeError(
                f"analysis-report migration already registered from version {from_version}"
            )
        _ANALYSIS_REPORT_MIGRATIONS[from_version] = fn
        # Return fn unchanged so the decorated function stays usable.
        return fn

    return _register
|
||||
|
||||
|
||||
def migrate_video_predictions(payload: dict) -> dict:
|
||||
"""Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current.
|
||||
|
||||
|
|
@ -179,6 +205,18 @@ def migrate_benchmark_result(payload: dict) -> dict:
|
|||
return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult")
|
||||
|
||||
|
||||
def migrate_analysis_report(payload: dict) -> dict:
    """Migrate a raw :class:`~neuropose.analyzer.pipeline.AnalysisReport` dict.

    See :func:`migrate_video_predictions` for semantics. Because
    :class:`AnalysisReport` first shipped at schema_version 2, a
    payload missing the key still defaults to 1 (and would require a
    not-yet-registered v1 → v2 migration); this is only reachable for
    deliberately malformed inputs.
    """
    # Delegates to the shared driver with this schema's registry and name.
    migrated = _migrate(payload, _ANALYSIS_REPORT_MIGRATIONS, schema_name="AnalysisReport")
    return migrated
|
||||
|
||||
|
||||
def migrate_job_results(payload: dict) -> dict:
|
||||
"""Migrate a :class:`~neuropose.io.JobResults` root dict to current.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,434 @@
|
|||
"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
||||
|
||||
This file covers the schema half of the pipeline:
|
||||
|
||||
- :class:`AnalysisConfig` parsing, including the discriminated unions
|
||||
for the segmentation and analysis stages, and the cross-field
|
||||
invariants enforced at parse time.
|
||||
- :class:`AnalysisReport` construction + JSON round-trip, including
|
||||
the migration hook (schema_version defaults to CURRENT_VERSION).
|
||||
- :func:`analysis_config_to_dict` JSON-safety.
|
||||
|
||||
The executor (``run_analysis``) gets its own test module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from neuropose.analyzer.pipeline import (
|
||||
AnalysisConfig,
|
||||
AnalysisReport,
|
||||
DtwAnalysis,
|
||||
DtwResults,
|
||||
ExtractorSegmentation,
|
||||
FeatureSummary,
|
||||
GaitCyclesBilateralSegmentation,
|
||||
GaitCyclesSegmentation,
|
||||
InputsConfig,
|
||||
InputSummary,
|
||||
NoAnalysis,
|
||||
NoResults,
|
||||
OutputConfig,
|
||||
PreprocessingConfig,
|
||||
StatsAnalysis,
|
||||
StatsResults,
|
||||
analysis_config_to_dict,
|
||||
)
|
||||
from neuropose.migrations import CURRENT_VERSION
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
|
||||
"""A minimal AnalysisConfig dict with dtw_all + reference."""
|
||||
return {
|
||||
"inputs": {
|
||||
"primary": str(tmp_path / "primary.json"),
|
||||
"reference": str(tmp_path / "reference.json"),
|
||||
},
|
||||
"analysis": {"kind": "dtw", "method": "dtw_all"},
|
||||
"output": {"report": str(tmp_path / "report.json")},
|
||||
}
|
||||
|
||||
|
||||
def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
|
||||
return {
|
||||
"inputs": {"primary": str(tmp_path / "primary.json")},
|
||||
"analysis": {
|
||||
"kind": "stats",
|
||||
"extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False},
|
||||
},
|
||||
"output": {"report": str(tmp_path / "report.json")},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InputsConfig / PreprocessingConfig / OutputConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInputsConfig:
    """Field-level behaviour of InputsConfig: optional reference, strictness, immutability."""

    def test_primary_only(self, tmp_path: Path) -> None:
        # reference is optional and defaults to None.
        cfg = InputsConfig(primary=tmp_path / "a.json")
        assert cfg.reference is None

    def test_primary_and_reference(self, tmp_path: Path) -> None:
        cfg = InputsConfig(
            primary=tmp_path / "a.json",
            reference=tmp_path / "b.json",
        )
        assert cfg.reference == tmp_path / "b.json"

    def test_extra_field_rejected(self, tmp_path: Path) -> None:
        # extra="forbid" surfaces unknown keys as pydantic "Extra inputs" errors.
        with pytest.raises(ValidationError, match="Extra inputs"):
            InputsConfig.model_validate(
                {
                    "primary": str(tmp_path / "a.json"),
                    "extra": "field",
                }
            )

    def test_frozen(self, tmp_path: Path) -> None:
        # frozen=True: assignment after construction raises a ValidationError.
        cfg = InputsConfig(primary=tmp_path / "a.json")
        with pytest.raises(ValidationError, match="frozen"):
            cfg.primary = tmp_path / "b.json"  # type: ignore[misc]
|
||||
|
||||
|
||||
class TestPreprocessingConfig:
    """Default and bounds behaviour of the person_index field."""

    def test_default_person_index_zero(self) -> None:
        # No arguments: the default single-subject case.
        assert PreprocessingConfig().person_index == 0

    def test_negative_person_index_rejected(self) -> None:
        # ge=0 constraint rejects negative indices at parse time.
        with pytest.raises(ValidationError):
            PreprocessingConfig(person_index=-1)
|
||||
|
||||
|
||||
class TestOutputConfig:
    """OutputConfig stores the report path verbatim."""

    def test_report_path(self, tmp_path: Path) -> None:
        expected = tmp_path / "out.json"
        cfg = OutputConfig(report=expected)
        assert cfg.report == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Segmentation stage discriminated union
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSegmentationStage:
    """Discriminated-union dispatch and validation for the optional segmentation stage."""

    def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
        # kind="gait_cycles" dispatches to GaitCyclesSegmentation and
        # carries the per-variant fields through.
        config_dict = _minimal_dtw_config(tmp_path)
        config_dict["segmentation"] = {
            "kind": "gait_cycles",
            "joint": "lhee",
            "axis": "y",
            "min_cycle_seconds": 0.5,
        }
        cfg = AnalysisConfig.model_validate(config_dict)
        assert isinstance(cfg.segmentation, GaitCyclesSegmentation)
        assert cfg.segmentation.joint == "lhee"
        assert cfg.segmentation.min_cycle_seconds == 0.5

    def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
        # The bilateral variant needs only its discriminator; all other
        # fields have defaults.
        config_dict = _minimal_dtw_config(tmp_path)
        config_dict["segmentation"] = {"kind": "gait_cycles_bilateral"}
        cfg = AnalysisConfig.model_validate(config_dict)
        assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation)

    def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
        # kind="extractor" wraps a nested ExtractorSpec payload.
        config_dict = _minimal_dtw_config(tmp_path)
        config_dict["segmentation"] = {
            "kind": "extractor",
            "extractor": {
                "kind": "joint_axis",
                "joint": 15,
                "axis": 1,
                "invert": False,
            },
            "label": "wrist_cycles",
            "min_distance_seconds": 0.5,
        }
        cfg = AnalysisConfig.model_validate(config_dict)
        assert isinstance(cfg.segmentation, ExtractorSegmentation)
        assert cfg.segmentation.label == "wrist_cycles"

    def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
        # A kind outside the union's discriminator values fails validation.
        config_dict = _minimal_dtw_config(tmp_path)
        config_dict["segmentation"] = {"kind": "unknown_method"}
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(config_dict)

    def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
        # The stage is optional: omitting the key leaves it None.
        cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        assert cfg.segmentation is None

    def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
        # gt=0.0 constraint: zero is out of range.
        config_dict = _minimal_dtw_config(tmp_path)
        config_dict["segmentation"] = {
            "kind": "gait_cycles",
            "min_cycle_seconds": 0.0,  # must be > 0
        }
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(config_dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Analysis stage discriminated union + cross-field invariants
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDtwAnalysisValidation:
    """Cross-field invariants enforced by DtwAnalysis at parse time."""

    def test_dtw_relation_requires_joints(self) -> None:
        with pytest.raises(ValidationError, match="joint_i and joint_j"):
            DtwAnalysis(kind="dtw", method="dtw_relation")

    def test_dtw_relation_rejects_angles_representation(self) -> None:
        with pytest.raises(
            ValidationError, match="only supports representation='coords'"
        ):
            DtwAnalysis(
                kind="dtw",
                method="dtw_relation",
                joint_i=0,
                joint_j=1,
                representation="angles",
                angle_triplets=[(0, 1, 2)],
            )

    def test_angles_requires_triplets(self) -> None:
        with pytest.raises(ValidationError, match="angle_triplets"):
            DtwAnalysis(kind="dtw", method="dtw_all", representation="angles")

    def test_angles_with_empty_triplets_rejected(self) -> None:
        # An explicitly empty list is just as invalid as omitting it.
        with pytest.raises(ValidationError, match="angle_triplets"):
            DtwAnalysis(
                kind="dtw",
                method="dtw_all",
                representation="angles",
                angle_triplets=[],
            )

    def test_happy_path_dtw_all_coords(self) -> None:
        analysis = DtwAnalysis(
            kind="dtw",
            method="dtw_all",
            align="procrustes_per_sequence",
            nan_policy="interpolate",
        )
        assert analysis.align == "procrustes_per_sequence"

    def test_happy_path_dtw_all_angles(self) -> None:
        triplets = [(0, 1, 2), (3, 4, 5)]
        analysis = DtwAnalysis(
            kind="dtw",
            method="dtw_all",
            representation="angles",
            angle_triplets=triplets,
        )
        assert len(analysis.angle_triplets or []) == 2

    def test_happy_path_dtw_relation(self) -> None:
        analysis = DtwAnalysis(
            kind="dtw",
            method="dtw_relation",
            joint_i=15,
            joint_j=23,
        )
        assert analysis.joint_i == 15
|
||||
|
||||
|
||||
class TestAnalysisCrossStage:
    """Invariants coupling the analysis stage to the inputs block."""

    def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
        payload = {
            "inputs": {"primary": str(tmp_path / "a.json")},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "out.json")},
        }
        with pytest.raises(ValidationError, match=r"inputs\.reference"):
            AnalysisConfig.model_validate(payload)

    def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
        payload = _minimal_stats_config(tmp_path)
        payload["inputs"]["reference"] = str(tmp_path / "b.json")
        with pytest.raises(ValidationError, match="primary only"):
            AnalysisConfig.model_validate(payload)

    def test_none_analysis_requires_no_reference(self, tmp_path: Path) -> None:
        # NoAnalysis is fine with either reference present or absent;
        # this exercises the reference-absent case.
        payload = {
            "inputs": {"primary": str(tmp_path / "a.json")},
            "analysis": {"kind": "none"},
            "output": {"report": str(tmp_path / "out.json")},
        }
        parsed = AnalysisConfig.model_validate(payload)
        assert isinstance(parsed.analysis, NoAnalysis)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Top-level AnalysisConfig
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnalysisConfig:
    """Top-level AnalysisConfig parsing, defaults, and round-tripping."""

    def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
        parsed = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        assert parsed.config_version == 1
        assert isinstance(parsed.analysis, DtwAnalysis)
        # person_index defaults to 0 when preprocessing is omitted.
        assert parsed.preprocessing.person_index == 0

    def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
        parsed = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
        assert isinstance(parsed.analysis, StatsAnalysis)

    def test_config_version_must_be_1(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        payload["config_version"] = 99
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)

    def test_round_trip_json(self, tmp_path: Path) -> None:
        original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        serialised = original.model_dump_json()
        assert AnalysisConfig.model_validate_json(serialised) == original

    def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        payload["unknown_key"] = "typo"
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# analysis_config_to_dict
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnalysisConfigToDict:
    """analysis_config_to_dict yields a JSON-safe, lossless dump."""

    def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
        config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        dumped = analysis_config_to_dict(config)
        # Path fields must have been serialised to plain strings.
        assert isinstance(dumped["inputs"]["primary"], str)
        assert isinstance(dumped["output"]["report"], str)

    def test_round_trips_through_dict(self, tmp_path: Path) -> None:
        original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        dumped = analysis_config_to_dict(original)
        assert AnalysisConfig.model_validate(dumped) == original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Result sub-schemas
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDtwResults:
    """Construction of the DTW results payload."""

    def test_minimal_construction(self) -> None:
        result = DtwResults(
            kind="dtw",
            method="dtw_all",
            distances=[0.5],
            paths=[[(0, 0), (1, 1)]],
            segment_labels=["full_trial"],
            summary={"mean": 0.5},
        )
        assert result.kind == "dtw"

    def test_per_joint_distances_shape_is_free(self) -> None:
        # The schema deliberately does not couple the outer length of
        # per_joint_distances to distances — that consistency is run-time
        # semantics of the executor.  Just check the field round-trips.
        result = DtwResults(
            kind="dtw",
            method="dtw_per_joint",
            distances=[0.1, 0.2],
            paths=[[(0, 0)], [(0, 0)]],
            per_joint_distances=[[0.05, 0.05], [0.1, 0.1]],
            segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
            summary={"mean": 0.15},
        )
        assert result.per_joint_distances is not None
|
||||
|
||||
|
||||
class TestStatsResults:
    """StatsResults serialises and deserialises losslessly."""

    def test_round_trip(self) -> None:
        result = StatsResults(
            kind="stats",
            statistics=[
                FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
                FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
            ],
            segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
        )
        serialised = result.model_dump_json()
        assert StatsResults.model_validate_json(serialised) == result
|
||||
|
||||
|
||||
class TestNoResults:
    """NoResults is a trivial tagged placeholder."""

    def test_construction(self) -> None:
        assert NoResults(kind="none").kind == "none"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# AnalysisReport
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_report(tmp_path: Path) -> AnalysisReport:
    """Build a small but fully-populated AnalysisReport for the tests below."""
    parsed_config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
    results = DtwResults(
        kind="dtw",
        method="dtw_all",
        distances=[0.42],
        paths=[[(0, 0), (1, 1)]],
        segment_labels=["full_trial"],
        summary={"mean": 0.42, "p50": 0.42},
    )
    return AnalysisReport(
        config=parsed_config,
        primary=InputSummary(
            path=tmp_path / "primary.json",
            frame_count=300,
            fps=30.0,
        ),
        reference=InputSummary(
            path=tmp_path / "reference.json",
            frame_count=300,
            fps=30.0,
        ),
        results=results,
    )
|
||||
|
||||
|
||||
class TestAnalysisReport:
    """Defaults, round-tripping, and strictness of AnalysisReport."""

    def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
        assert _make_report(tmp_path).schema_version == CURRENT_VERSION

    def test_round_trip_json(self, tmp_path: Path) -> None:
        report = _make_report(tmp_path)
        serialised = report.model_dump_json()
        assert AnalysisReport.model_validate_json(serialised) == report

    def test_empty_segmentations_default(self, tmp_path: Path) -> None:
        assert _make_report(tmp_path).segmentations == {}

    def test_extra_field_rejected(self, tmp_path: Path) -> None:
        payload = _make_report(tmp_path).model_dump(mode="json")
        payload["mystery_field"] = 1
        with pytest.raises(ValidationError):
            AnalysisReport.model_validate(payload)
|
||||
|
|
@ -38,6 +38,7 @@ from neuropose.migrations import (
|
|||
FutureSchemaError,
|
||||
MigrationError,
|
||||
MigrationNotFoundError,
|
||||
migrate_analysis_report,
|
||||
migrate_benchmark_result,
|
||||
migrate_job_results,
|
||||
migrate_video_predictions,
|
||||
|
|
@ -282,8 +283,32 @@ class TestMigrateJobResults:
|
|||
assert result["video_b.mp4"]["content_b"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# migrate_analysis_report
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigrateAnalysisReport:
    """Behaviour of the migrate_analysis_report entry point."""

    def test_current_version_is_noop(self) -> None:
        """A payload already at CURRENT_VERSION passes through unchanged."""
        payload = {"schema_version": CURRENT_VERSION, "foo": "bar"}
        assert migrate_analysis_report(payload) == payload

    def test_missing_version_defaults_to_v1_and_fails(self) -> None:
        """AnalysisReport first shipped at CURRENT_VERSION, so a payload
        without schema_version (defaulting to 1) would require a
        non-existent v1->v2 migration and must fail with a clear error."""
        with pytest.raises(MigrationNotFoundError, match="AnalysisReport"):
            migrate_analysis_report({})

    def test_future_version_rejected(self) -> None:
        with pytest.raises(FutureSchemaError, match="AnalysisReport"):
            migrate_analysis_report({"schema_version": CURRENT_VERSION + 5})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# register_video_predictions_migration / register_benchmark_result_migration
# / register_analysis_report_migration
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -312,6 +337,21 @@ class TestRegistration:
|
|||
assert _fn.__name__ == "_fn"
|
||||
assert _fn({"x": 1}) == {"x": 1}
|
||||
|
||||
def test_analysis_report_duplicate_registration_raises(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Registering two migrations for the same from_version is an error."""
    # Swap in a fresh registry so this test cannot pollute others.
    monkeypatch.setattr(migrations, "_ANALYSIS_REPORT_MIGRATIONS", {})

    @migrations.register_analysis_report_migration(from_version=2)
    def _first(p: dict) -> dict:
        return p

    with pytest.raises(RuntimeError, match="already registered"):

        @migrations.register_analysis_report_migration(from_version=2)
        def _second(p: dict) -> dict:
            return p
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Integration: load_* functions run migrations before validation
|
||||
|
|
|
|||
Loading…
Reference in New Issue