From 979beb1078216fcc4df61a62dffca8c8155e3ac2 Mon Sep 17 00:00:00 2001 From: Levi Neuwirth Date: Wed, 22 Apr 2026 11:13:36 -0400 Subject: [PATCH] add AnalysisConfig and AnalysisReport schemas MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit neuropose.analyzer.pipeline ships two top-level pydantic schemas: AnalysisConfig — what a user writes in YAML. Inputs (primary plus optional reference), preprocessing (person_index, room to grow), optional segmentation as a discriminated union of gait_cycles, gait_cycles_bilateral, and extractor, and a required analysis stage as a discriminated union of dtw, stats, none. AnalysisReport — runtime output with config, Provenance envelope, per-input summaries, produced segmentations, and a results payload whose shape mirrors the stage (DtwResults, StatsResults, NoResults). schema_version defaults to CURRENT_VERSION. Cross-field invariants enforced at parse time via model_validator: method='dtw_relation' requires joint_i/joint_j and refuses representation='angles'; representation='angles' requires non-empty angle_triplets; analysis.kind='dtw' requires inputs.reference; analysis.kind='stats' refuses a reference. Typos fail in milliseconds instead of after a multi-minute predictions load. neuropose.migrations gains a third registry for AnalysisReport (_ANALYSIS_REPORT_MIGRATIONS + register_analysis_report_migration + migrate_analysis_report), ready for future schema changes. No v1→v2 migration is registered because AnalysisReport first shipped at v2. Execution, CLI wiring, and example configs land in follow-up commits. 
--- CHANGELOG.md | 24 ++ src/neuropose/analyzer/__init__.py | 42 +++ src/neuropose/analyzer/pipeline.py | 478 +++++++++++++++++++++++++++ src/neuropose/migrations.py | 40 ++- tests/unit/test_analyzer_pipeline.py | 434 ++++++++++++++++++++++++ tests/unit/test_migrations.py | 40 +++ 6 files changed, 1057 insertions(+), 1 deletion(-) create mode 100644 src/neuropose/analyzer/pipeline.py create mode 100644 tests/unit/test_analyzer_pipeline.py diff --git a/CHANGELOG.md b/CHANGELOG.md index adde901..9e8278d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -257,6 +257,30 @@ be split into per-release sections once tagging begins. two-joint displacement DTW; users who prefer a unified API can express the same computation via `dtw_all` with an appropriate pair of angle triplets or run `dtw_relation` directly. +- **`neuropose.analyzer.pipeline`** (schemas) — declarative + analysis-pipeline configuration and output artifact, parseable from + YAML or JSON via pydantic. `AnalysisConfig` captures a full + experiment: inputs (primary + optional reference predictions + files), preprocessing (person index, with room to grow), + optional segmentation (`gait_cycles` / `gait_cycles_bilateral` / + `extractor` discriminated union), and a required analysis stage + (`dtw` / `stats` / `none` discriminated union). `AnalysisReport` + is the runtime output: carries the originating config, a + `Provenance` envelope with `analysis_config` populated, per-input + summaries, produced segmentations, and an analysis-result payload + that mirrors the stage choice (`DtwResults`, `StatsResults`, or + `NoResults`). Cross-field invariants — `method="dtw_relation"` + requires `joint_i`/`joint_j`, `representation="angles"` requires + non-empty `angle_triplets`, `analysis.kind="dtw"` requires + `inputs.reference`, `analysis.kind="stats"` refuses a reference — + are enforced at parse time via `model_validator` so typos fail in + milliseconds instead of after a multi-minute predictions load. 
+ `AnalysisReport` carries a `schema_version` field defaulting to + `CURRENT_VERSION = 2`, with a new + `register_analysis_report_migration` decorator and + `migrate_analysis_report` driver in `neuropose.migrations` ready + for future schema changes. Pipeline execution lands in a + follow-up commit. - **`neuropose.analyzer.segment.segment_gait_cycles`** and **`segment_gait_cycles_bilateral`** — clinical convenience wrappers over `segment_predictions` that pre-fill a `joint_axis` diff --git a/src/neuropose/analyzer/__init__.py b/src/neuropose/analyzer/__init__.py index 2d04464..9b8c5a9 100644 --- a/src/neuropose/analyzer/__init__.py +++ b/src/neuropose/analyzer/__init__.py @@ -48,6 +48,28 @@ from neuropose.analyzer.features import ( predictions_to_numpy, procrustes_align, ) +from neuropose.analyzer.pipeline import ( + AnalysisConfig, + AnalysisReport, + AnalysisResults, + AnalysisStage, + DtwAnalysis, + DtwResults, + ExtractorSegmentation, + FeatureSummary, + GaitCyclesBilateralSegmentation, + GaitCyclesSegmentation, + InputsConfig, + InputSummary, + NoAnalysis, + NoResults, + OutputConfig, + PreprocessingConfig, + SegmentationStage, + StatsAnalysis, + StatsResults, + analysis_config_to_dict, +) from neuropose.analyzer.segment import ( JOINT_INDEX, JOINT_NAMES, @@ -70,12 +92,32 @@ __all__ = [ "JOINT_NAMES", "AlignMode", "AlignmentDiagnostics", + "AnalysisConfig", + "AnalysisReport", + "AnalysisResults", + "AnalysisStage", "AxisLetter", "DTWResult", + "DtwAnalysis", + "DtwResults", + "ExtractorSegmentation", "FeatureStatistics", + "FeatureSummary", + "GaitCyclesBilateralSegmentation", + "GaitCyclesSegmentation", + "InputSummary", + "InputsConfig", "NanPolicy", + "NoAnalysis", + "NoResults", + "OutputConfig", + "PreprocessingConfig", "ProcrustesMode", "Representation", + "SegmentationStage", + "StatsAnalysis", + "StatsResults", + "analysis_config_to_dict", "dtw_all", "dtw_per_joint", "dtw_relation", diff --git a/src/neuropose/analyzer/pipeline.py 
b/src/neuropose/analyzer/pipeline.py new file mode 100644 index 0000000..30a9b10 --- /dev/null +++ b/src/neuropose/analyzer/pipeline.py @@ -0,0 +1,478 @@ +"""YAML-configurable analysis pipeline. + +This module unifies the analyzer's individual primitives (Procrustes +alignment, gait-cycle segmentation, DTW on coords or angles, feature +statistics) behind a single declarative configuration object so an +experiment can be reproduced from one file that lives in a git +repository, carries through to the :class:`~neuropose.io.Provenance` +envelope on the output artifact, and can be cited unambiguously in +accompanying papers. + +Two top-level schemas live here: + +- :class:`AnalysisConfig` — what the user writes in YAML. Describes + the full pipeline: inputs, preprocessing, optional segmentation, + required analysis stage, and output path. +- :class:`AnalysisReport` — what :func:`run_analysis` emits. Carries + the config, a :class:`~neuropose.io.Provenance` envelope (with the + config serialised into :attr:`~neuropose.io.Provenance.analysis_config`), + per-input summaries, segmentation results, and the analysis results + themselves. + +Both schemas parse from (and serialise to) JSON via pydantic; the +config additionally parses from YAML via :func:`load_config`. Cross-field +invariants (for example, ``method="dtw_relation"`` requires ``joint_i`` +and ``joint_j``) are enforced at parse time so typo-laden configs fail +fast rather than after an expensive multi-minute load. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from neuropose.analyzer.dtw import AlignMode, NanPolicy, Representation +from neuropose.analyzer.segment import AxisLetter +from neuropose.io import ExtractorSpec, Provenance, Segmentation +from neuropose.migrations import CURRENT_VERSION + +# --------------------------------------------------------------------------- +# Inputs +# --------------------------------------------------------------------------- + + +class InputsConfig(BaseModel): + """Predictions files consumed by the pipeline. + + Attributes + ---------- + primary + Path to a :class:`~neuropose.io.VideoPredictions` JSON file. + Always required. + reference + Optional second predictions file. When provided, + :class:`DtwAnalysis` runs comparative DTW between primary and + reference; when absent, analysis stages that require a + reference (i.e. DTW) raise a validation error at parse time. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + primary: Path + reference: Path | None = None + + +# --------------------------------------------------------------------------- +# Preprocessing +# --------------------------------------------------------------------------- + + +class PreprocessingConfig(BaseModel): + """Per-input preprocessing applied before segmentation and analysis. + + Minimal today — just picks which detected person to extract from + each frame. Left as a named stage so future extensions (coordinate + normalisation, smoothing) can land here without reshuffling the + config shape. + + Attributes + ---------- + person_index + Which detected person to extract per frame. Defaults to ``0`` + (the first detected person), matching the single-subject + clinical case. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + person_index: int = Field(default=0, ge=0) + + +# --------------------------------------------------------------------------- +# Segmentation stage +# --------------------------------------------------------------------------- + + +class GaitCyclesSegmentation(BaseModel): + """Single-heel gait-cycle segmentation via peak detection. + + Produces one :class:`~neuropose.io.Segmentation` keyed under the + joint name (e.g. ``"rhee_cycles"``). See + :func:`~neuropose.analyzer.segment.segment_gait_cycles` for the + underlying implementation. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["gait_cycles"] + joint: str = "rhee" + axis: AxisLetter = "y" + invert: bool = False + min_cycle_seconds: float = Field(default=0.4, gt=0.0) + min_prominence: float | None = None + + +class GaitCyclesBilateralSegmentation(BaseModel): + """Bilateral (both heels) gait-cycle segmentation. + + Produces two :class:`~neuropose.io.Segmentation` objects keyed as + ``"left_heel_strikes"`` and ``"right_heel_strikes"``. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["gait_cycles_bilateral"] + axis: AxisLetter = "y" + invert: bool = False + min_cycle_seconds: float = Field(default=0.4, gt=0.0) + min_prominence: float | None = None + + +class ExtractorSegmentation(BaseModel): + """Generic extractor-driven segmentation. + + Wraps :func:`~neuropose.analyzer.segment.segment_predictions` with + a caller-supplied :class:`~neuropose.io.ExtractorSpec`. Use this + when the signal of interest is not the vertical heel trace — e.g. + wrist-hip distance for a reach-and-grasp task, or elbow flexion + angle for a range-of-motion trial. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["extractor"] + extractor: ExtractorSpec + label: str = Field( + default="segmentation", + description="Key under which the resulting Segmentation is stored.", + ) + person_index: int | None = Field( + default=None, + description=( + "Overrides preprocessing.person_index for this stage. " + "None defers to the global preprocessing setting." + ), + ) + min_distance_seconds: float | None = Field(default=None, ge=0.0) + min_prominence: float | None = None + min_height: float | None = None + pad_seconds: float = Field(default=0.0, ge=0.0) + + +SegmentationStage = Annotated[ + GaitCyclesSegmentation | GaitCyclesBilateralSegmentation | ExtractorSegmentation, + Field(discriminator="kind"), +] +"""Discriminated-union alias for the three segmentation variants. + +Pydantic dispatches on the ``kind`` field. A config without a +``segmentation`` key at all skips this stage entirely +(see :class:`AnalysisConfig.segmentation`). +""" + + +# --------------------------------------------------------------------------- +# Analysis stage +# --------------------------------------------------------------------------- + + +class DtwAnalysis(BaseModel): + """Dynamic Time Warping between the primary and reference inputs. + + Dispatches to one of :func:`~neuropose.analyzer.dtw.dtw_all`, + :func:`~neuropose.analyzer.dtw.dtw_per_joint`, or + :func:`~neuropose.analyzer.dtw.dtw_relation` per the ``method`` + field. Cross-field invariants — ``method="dtw_relation"`` requires + ``joint_i`` and ``joint_j``, ``representation="angles"`` requires + a non-empty ``angle_triplets`` — are enforced at parse time. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["dtw"] + method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] = "dtw_all" + align: AlignMode = "none" + representation: Representation = "coords" + angle_triplets: list[tuple[int, int, int]] | None = None + joint_i: int | None = None + joint_j: int | None = None + nan_policy: NanPolicy = "propagate" + + @model_validator(mode="after") + def _check_method_fields(self) -> DtwAnalysis: + if self.method == "dtw_relation": + if self.joint_i is None or self.joint_j is None: + raise ValueError("method='dtw_relation' requires joint_i and joint_j") + if self.representation != "coords": + raise ValueError( + "method='dtw_relation' only supports representation='coords' " + "(a two-joint displacement is not a joint-angle signal)" + ) + if self.representation == "angles": + if not self.angle_triplets: + raise ValueError("representation='angles' requires a non-empty angle_triplets list") + if self.method == "dtw_relation": + # Guarded by the earlier branch, but make the invariant explicit. + raise ValueError( + "representation='angles' is incompatible with method='dtw_relation'" + ) + return self + + +class StatsAnalysis(BaseModel): + """Summary statistics over a scalar signal extracted from the primary input. + + Runs :func:`~neuropose.analyzer.segment.extract_signal` with the + caller-supplied :class:`~neuropose.io.ExtractorSpec`, then + computes :func:`~neuropose.analyzer.features.extract_feature_statistics` + on each segment (or on the full trial if no segmentation stage + runs). + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["stats"] + extractor: ExtractorSpec + + +class NoAnalysis(BaseModel): + """Terminal stage placeholder; produces no per-segment results. 
+ + Useful when the pipeline's goal is just to segment the input and + persist the :class:`~neuropose.io.Segmentation` plus an + :class:`AnalysisReport` with provenance — the ``none`` analysis + kind makes that explicit rather than requiring the absence of the + stage. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["none"] + + +AnalysisStage = Annotated[ + DtwAnalysis | StatsAnalysis | NoAnalysis, + Field(discriminator="kind"), +] +"""Discriminated-union alias for the three analysis variants. + +Pydantic dispatches on ``kind``. One of the three must always be +present in a valid config. +""" + + +# --------------------------------------------------------------------------- +# Output +# --------------------------------------------------------------------------- + + +class OutputConfig(BaseModel): + """Where :func:`run_analysis` should write its :class:`AnalysisReport`. + + Kept as a sub-object rather than a bare path so downstream + extensions (figure paths, supplementary distance-matrix files) + can land here without changing the config's top-level shape. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + report: Path + + +# --------------------------------------------------------------------------- +# Top-level config +# --------------------------------------------------------------------------- + + +class AnalysisConfig(BaseModel): + """Declarative description of a full analysis run. + + Parsed from YAML (via :func:`load_config`) or JSON (via + :meth:`pydantic.BaseModel.model_validate`). Every field is + cross-validated at parse time so a typo in a nested sub-field + fails in milliseconds rather than after a multi-minute + predictions load. + + Attributes + ---------- + config_version + Schema version for the config itself. Only ``1`` is valid at + this release. Future config-format breaks bump this and a + sibling migration registry handles legacy YAML in place. + inputs + Predictions-file paths. 
+ preprocessing + Per-input preprocessing (person-index selection today). + segmentation + Optional segmentation stage. ``None`` skips segmentation + entirely and analysis runs over each full trial as a single + "segment". + analysis + Required analysis stage. Exactly one of + :class:`DtwAnalysis` / :class:`StatsAnalysis` / :class:`NoAnalysis`. + output + Output paths. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + config_version: Literal[1] = 1 + inputs: InputsConfig + preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig) + segmentation: SegmentationStage | None = None + analysis: AnalysisStage + output: OutputConfig + + @model_validator(mode="after") + def _check_cross_stage_invariants(self) -> AnalysisConfig: + # DTW is comparative — it needs a reference input. + if isinstance(self.analysis, DtwAnalysis) and self.inputs.reference is None: + raise ValueError("analysis.kind='dtw' requires inputs.reference to be set") + # Stats is non-comparative — a reference without a use is + # almost certainly an operator error. + if isinstance(self.analysis, StatsAnalysis) and self.inputs.reference is not None: + raise ValueError( + "analysis.kind='stats' operates on inputs.primary only; " + "remove inputs.reference or switch analysis.kind to 'dtw'" + ) + return self + + +# --------------------------------------------------------------------------- +# Report pieces +# --------------------------------------------------------------------------- + + +class InputSummary(BaseModel): + """Capsule of an input predictions file's headline metadata. + + Stored in the :class:`AnalysisReport` so a reader of the report + can tell at a glance what was analysed without having to load the + underlying predictions JSONs. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + path: Path + frame_count: int = Field(ge=0) + fps: float = Field(ge=0.0) + provenance: Provenance | None = None + + +class FeatureSummary(BaseModel): + """Pydantic twin of :class:`~neuropose.analyzer.features.FeatureStatistics`. + + The dataclass is used throughout the analyzer for ad-hoc Python + consumption; the report path needs a pydantic model for + round-tripping through JSON. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + mean: float + std: float + min: float + max: float + range: float + + +class DtwResults(BaseModel): + """DTW results attached to an :class:`AnalysisReport`. + + ``distances`` is parallel to ``segment_labels``. For an + unsegmented run the lists have length 1 and the label is + ``"full_trial"``. For a segmented run each label takes the form + ``"[]"`` (e.g. ``"left_heel_strikes[3]"``). + ``per_joint_distances`` carries a per-unit breakdown for + ``method="dtw_per_joint"`` only; its outer length matches + ``distances``, inner length matches either ``num_joints`` (coords) + or ``len(angle_triplets)`` (angles). + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["dtw"] + method: Literal["dtw_all", "dtw_per_joint", "dtw_relation"] + distances: list[float] + paths: list[list[tuple[int, int]]] + per_joint_distances: list[list[float]] | None = None + segment_labels: list[str] + summary: dict[str, float] + + +class StatsResults(BaseModel): + """Feature-statistics results attached to an :class:`AnalysisReport`. + + ``statistics`` is parallel to ``segment_labels``; see + :class:`DtwResults` for the labelling convention. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["stats"] + statistics: list[FeatureSummary] + segment_labels: list[str] + + +class NoResults(BaseModel): + """Empty results payload for ``analysis.kind='none'`` runs.""" + + model_config = ConfigDict(extra="forbid", frozen=True) + + kind: Literal["none"] + + +AnalysisResults = Annotated[ + DtwResults | StatsResults | NoResults, + Field(discriminator="kind"), +] +"""Discriminated-union alias for the three analysis-result shapes. + +Mirrors :data:`AnalysisStage` one-for-one: ``DtwAnalysis`` produces +:class:`DtwResults`, etc. +""" + + +# --------------------------------------------------------------------------- +# Top-level report +# --------------------------------------------------------------------------- + + +class AnalysisReport(BaseModel): + """Self-describing output artifact of :func:`run_analysis`. + + Serialised to JSON on disk. Carries the originating config, the + :class:`~neuropose.io.Provenance` envelope (with the config + serialised into :attr:`~neuropose.io.Provenance.analysis_config` + so the report is self-describing even if the YAML is lost), each + input's headline metadata plus its own provenance if available, + any segmentations produced, and the analysis results themselves. + + Lives in the schema-migration registry under ``"AnalysisReport"`` + at ``CURRENT_VERSION``; see :mod:`neuropose.migrations`. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + schema_version: int = Field(default=CURRENT_VERSION, ge=1) + config: AnalysisConfig + provenance: Provenance | None = None + primary: InputSummary + reference: InputSummary | None = None + segmentations: dict[str, Segmentation] = Field(default_factory=dict) + results: AnalysisResults + + +def analysis_config_to_dict(config: AnalysisConfig) -> dict[str, Any]: + """Serialise an :class:`AnalysisConfig` to a JSON-safe dict. 
+ + Returned shape is identical to what pydantic would produce via + :meth:`~pydantic.BaseModel.model_dump` in ``mode="json"`` — paths + become strings, tuples become lists, enums become their values. + Useful for stamping + :attr:`~neuropose.io.Provenance.analysis_config` on the + :class:`AnalysisReport`'s provenance envelope. + """ + return config.model_dump(mode="json") diff --git a/src/neuropose/migrations.py b/src/neuropose/migrations.py index 2367c9a..fb777ab 100644 --- a/src/neuropose/migrations.py +++ b/src/neuropose/migrations.py @@ -60,7 +60,9 @@ Version history --------------- - **v1:** initial schema, pre-Phase-0. - **v2:** added optional ``provenance`` field to :class:`~neuropose.io.VideoPredictions` - and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope).""" + and :class:`~neuropose.io.BenchmarkResult` (Phase 0, Paper C reproducibility envelope). + :class:`~neuropose.analyzer.pipeline.AnalysisReport` also enters the registry at v2 + (no legacy v1 payloads ever existed for it, so no migration is registered).""" class MigrationError(Exception): @@ -90,6 +92,7 @@ class MigrationNotFoundError(MigrationError): # dict at ``source + 1``. _VIDEO_PREDICTIONS_MIGRATIONS: dict[int, Callable[[dict], dict]] = {} _BENCHMARK_RESULT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {} +_ANALYSIS_REPORT_MIGRATIONS: dict[int, Callable[[dict], dict]] = {} def register_video_predictions_migration( @@ -142,6 +145,29 @@ def register_benchmark_result_migration( return wrap +def register_analysis_report_migration( + from_version: int, +) -> Callable[[Callable[[dict], dict]], Callable[[dict], dict]]: + """Register a :class:`~neuropose.analyzer.pipeline.AnalysisReport` migration. + + See :func:`register_video_predictions_migration` for usage — this + is the same pattern for the analysis-report registry. 
Unlike the + other two schemas, :class:`AnalysisReport` first appeared at + :data:`CURRENT_VERSION = 2`, so no ``from_version=1`` migration + exists (and none is expected). + """ + + def wrap(fn: Callable[[dict], dict]) -> Callable[[dict], dict]: + if from_version in _ANALYSIS_REPORT_MIGRATIONS: + raise RuntimeError( + f"analysis-report migration already registered from version {from_version}" + ) + _ANALYSIS_REPORT_MIGRATIONS[from_version] = fn + return fn + + return wrap + + def migrate_video_predictions(payload: dict) -> dict: """Migrate a raw :class:`~neuropose.io.VideoPredictions` dict to current. @@ -179,6 +205,18 @@ def migrate_benchmark_result(payload: dict) -> dict: return _migrate(payload, _BENCHMARK_RESULT_MIGRATIONS, schema_name="BenchmarkResult") +def migrate_analysis_report(payload: dict) -> dict: + """Migrate a raw :class:`~neuropose.analyzer.pipeline.AnalysisReport` dict. + + See :func:`migrate_video_predictions` for semantics. Because + :class:`AnalysisReport` first shipped at schema_version 2, a + payload missing the key still defaults to 1 (and would require a + not-yet-registered v1 → v2 migration); this is only reachable for + deliberately malformed inputs. + """ + return _migrate(payload, _ANALYSIS_REPORT_MIGRATIONS, schema_name="AnalysisReport") + + def migrate_job_results(payload: dict) -> dict: """Migrate a :class:`~neuropose.io.JobResults` root dict to current. diff --git a/tests/unit/test_analyzer_pipeline.py b/tests/unit/test_analyzer_pipeline.py new file mode 100644 index 0000000..987cd18 --- /dev/null +++ b/tests/unit/test_analyzer_pipeline.py @@ -0,0 +1,434 @@ +"""Tests for :mod:`neuropose.analyzer.pipeline`. + +This file covers the schema half of the pipeline: + +- :class:`AnalysisConfig` parsing, including the discriminated unions + for the segmentation and analysis stages, and the cross-field + invariants enforced at parse time. 
+- :class:`AnalysisReport` construction + JSON round-trip, including + the migration hook (schema_version defaults to CURRENT_VERSION). +- :func:`analysis_config_to_dict` JSON-safety. + +The executor (``run_analysis``) gets its own test module. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest +from pydantic import ValidationError + +from neuropose.analyzer.pipeline import ( + AnalysisConfig, + AnalysisReport, + DtwAnalysis, + DtwResults, + ExtractorSegmentation, + FeatureSummary, + GaitCyclesBilateralSegmentation, + GaitCyclesSegmentation, + InputsConfig, + InputSummary, + NoAnalysis, + NoResults, + OutputConfig, + PreprocessingConfig, + StatsAnalysis, + StatsResults, + analysis_config_to_dict, +) +from neuropose.migrations import CURRENT_VERSION + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]: + """A minimal AnalysisConfig dict with dtw_all + reference.""" + return { + "inputs": { + "primary": str(tmp_path / "primary.json"), + "reference": str(tmp_path / "reference.json"), + }, + "analysis": {"kind": "dtw", "method": "dtw_all"}, + "output": {"report": str(tmp_path / "report.json")}, + } + + +def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]: + return { + "inputs": {"primary": str(tmp_path / "primary.json")}, + "analysis": { + "kind": "stats", + "extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False}, + }, + "output": {"report": str(tmp_path / "report.json")}, + } + + +# --------------------------------------------------------------------------- +# InputsConfig / PreprocessingConfig / OutputConfig +# --------------------------------------------------------------------------- + + +class TestInputsConfig: + def test_primary_only(self, tmp_path: Path) -> None: + cfg = 
InputsConfig(primary=tmp_path / "a.json") + assert cfg.reference is None + + def test_primary_and_reference(self, tmp_path: Path) -> None: + cfg = InputsConfig( + primary=tmp_path / "a.json", + reference=tmp_path / "b.json", + ) + assert cfg.reference == tmp_path / "b.json" + + def test_extra_field_rejected(self, tmp_path: Path) -> None: + with pytest.raises(ValidationError, match="Extra inputs"): + InputsConfig.model_validate( + { + "primary": str(tmp_path / "a.json"), + "extra": "field", + } + ) + + def test_frozen(self, tmp_path: Path) -> None: + cfg = InputsConfig(primary=tmp_path / "a.json") + with pytest.raises(ValidationError, match="frozen"): + cfg.primary = tmp_path / "b.json" # type: ignore[misc] + + +class TestPreprocessingConfig: + def test_default_person_index_zero(self) -> None: + cfg = PreprocessingConfig() + assert cfg.person_index == 0 + + def test_negative_person_index_rejected(self) -> None: + with pytest.raises(ValidationError): + PreprocessingConfig(person_index=-1) + + +class TestOutputConfig: + def test_report_path(self, tmp_path: Path) -> None: + cfg = OutputConfig(report=tmp_path / "out.json") + assert cfg.report == tmp_path / "out.json" + + +# --------------------------------------------------------------------------- +# Segmentation stage discriminated union +# --------------------------------------------------------------------------- + + +class TestSegmentationStage: + def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + config_dict["segmentation"] = { + "kind": "gait_cycles", + "joint": "lhee", + "axis": "y", + "min_cycle_seconds": 0.5, + } + cfg = AnalysisConfig.model_validate(config_dict) + assert isinstance(cfg.segmentation, GaitCyclesSegmentation) + assert cfg.segmentation.joint == "lhee" + assert cfg.segmentation.min_cycle_seconds == 0.5 + + def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + 
config_dict["segmentation"] = {"kind": "gait_cycles_bilateral"} + cfg = AnalysisConfig.model_validate(config_dict) + assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation) + + def test_extractor_parses_from_dict(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + config_dict["segmentation"] = { + "kind": "extractor", + "extractor": { + "kind": "joint_axis", + "joint": 15, + "axis": 1, + "invert": False, + }, + "label": "wrist_cycles", + "min_distance_seconds": 0.5, + } + cfg = AnalysisConfig.model_validate(config_dict) + assert isinstance(cfg.segmentation, ExtractorSegmentation) + assert cfg.segmentation.label == "wrist_cycles" + + def test_unknown_kind_rejected(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + config_dict["segmentation"] = {"kind": "unknown_method"} + with pytest.raises(ValidationError): + AnalysisConfig.model_validate(config_dict) + + def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None: + cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path)) + assert cfg.segmentation is None + + def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + config_dict["segmentation"] = { + "kind": "gait_cycles", + "min_cycle_seconds": 0.0, # must be > 0 + } + with pytest.raises(ValidationError): + AnalysisConfig.model_validate(config_dict) + + +# --------------------------------------------------------------------------- +# Analysis stage discriminated union + cross-field invariants +# --------------------------------------------------------------------------- + + +class TestDtwAnalysisValidation: + def test_dtw_relation_requires_joints(self) -> None: + with pytest.raises(ValidationError, match="joint_i and joint_j"): + DtwAnalysis(kind="dtw", method="dtw_relation") + + def test_dtw_relation_rejects_angles_representation(self) -> None: + with pytest.raises(ValidationError, match="only supports 
representation='coords'"): + DtwAnalysis( + kind="dtw", + method="dtw_relation", + joint_i=0, + joint_j=1, + representation="angles", + angle_triplets=[(0, 1, 2)], + ) + + def test_angles_requires_triplets(self) -> None: + with pytest.raises(ValidationError, match="angle_triplets"): + DtwAnalysis( + kind="dtw", + method="dtw_all", + representation="angles", + ) + + def test_angles_with_empty_triplets_rejected(self) -> None: + with pytest.raises(ValidationError, match="angle_triplets"): + DtwAnalysis( + kind="dtw", + method="dtw_all", + representation="angles", + angle_triplets=[], + ) + + def test_happy_path_dtw_all_coords(self) -> None: + analysis = DtwAnalysis( + kind="dtw", + method="dtw_all", + align="procrustes_per_sequence", + nan_policy="interpolate", + ) + assert analysis.align == "procrustes_per_sequence" + + def test_happy_path_dtw_all_angles(self) -> None: + analysis = DtwAnalysis( + kind="dtw", + method="dtw_all", + representation="angles", + angle_triplets=[(0, 1, 2), (3, 4, 5)], + ) + assert len(analysis.angle_triplets or []) == 2 + + def test_happy_path_dtw_relation(self) -> None: + analysis = DtwAnalysis( + kind="dtw", + method="dtw_relation", + joint_i=15, + joint_j=23, + ) + assert analysis.joint_i == 15 + + +class TestAnalysisCrossStage: + def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None: + config_dict = { + "inputs": {"primary": str(tmp_path / "a.json")}, + "analysis": {"kind": "dtw", "method": "dtw_all"}, + "output": {"report": str(tmp_path / "out.json")}, + } + with pytest.raises(ValidationError, match=r"inputs\.reference"): + AnalysisConfig.model_validate(config_dict) + + def test_stats_with_reference_rejected(self, tmp_path: Path) -> None: + config_dict = _minimal_stats_config(tmp_path) + config_dict["inputs"]["reference"] = str(tmp_path / "b.json") + with pytest.raises(ValidationError, match="primary only"): + AnalysisConfig.model_validate(config_dict) + + def test_none_analysis_requires_no_reference(self, tmp_path: 
Path) -> None: + # A "none" analysis stage parses cleanly without a reference input. + config_dict = { + "inputs": {"primary": str(tmp_path / "a.json")}, + "analysis": {"kind": "none"}, + "output": {"report": str(tmp_path / "out.json")}, + } + cfg = AnalysisConfig.model_validate(config_dict) + assert isinstance(cfg.analysis, NoAnalysis) + + +# --------------------------------------------------------------------------- +# Top-level AnalysisConfig +# --------------------------------------------------------------------------- + + +class TestAnalysisConfig: + def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None: + cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path)) + assert cfg.config_version == 1 + assert isinstance(cfg.analysis, DtwAnalysis) + assert cfg.preprocessing.person_index == 0 # default + + def test_minimal_stats_config_parses(self, tmp_path: Path) -> None: + cfg = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path)) + assert isinstance(cfg.analysis, StatsAnalysis) + + def test_config_version_must_be_1(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + config_dict["config_version"] = 99 + with pytest.raises(ValidationError): + AnalysisConfig.model_validate(config_dict) + + def test_round_trip_json(self, tmp_path: Path) -> None: + original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path)) + serialised = original.model_dump_json() + restored = AnalysisConfig.model_validate_json(serialised) + assert restored == original + + def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None: + config_dict = _minimal_dtw_config(tmp_path) + config_dict["unknown_key"] = "typo" + with pytest.raises(ValidationError): + AnalysisConfig.model_validate(config_dict) + + +# --------------------------------------------------------------------------- +# analysis_config_to_dict +# --------------------------------------------------------------------------- + + +class TestAnalysisConfigToDict: + def 
test_returns_json_safe_dict(self, tmp_path: Path) -> None: + cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path)) + dumped = analysis_config_to_dict(cfg) + # Paths must have become strings. + assert isinstance(dumped["inputs"]["primary"], str) + assert isinstance(dumped["output"]["report"], str) + + def test_round_trips_through_dict(self, tmp_path: Path) -> None: + original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path)) + dumped = analysis_config_to_dict(original) + restored = AnalysisConfig.model_validate(dumped) + assert restored == original + + +# --------------------------------------------------------------------------- +# Result sub-schemas +# --------------------------------------------------------------------------- + + +class TestDtwResults: + def test_minimal_construction(self) -> None: + res = DtwResults( + kind="dtw", + method="dtw_all", + distances=[0.5], + paths=[[(0, 0), (1, 1)]], + segment_labels=["full_trial"], + summary={"mean": 0.5}, + ) + assert res.kind == "dtw" + + def test_per_joint_distances_shape_is_free(self) -> None: + # No validator enforces that per_joint_distances outer length + # matches distances — that's run-time semantics of the + # executor. Still, verify the field is accepted at construction. 
+ res = DtwResults( + kind="dtw", + method="dtw_per_joint", + distances=[0.1, 0.2], + paths=[[(0, 0)], [(0, 0)]], + per_joint_distances=[[0.05, 0.05], [0.1, 0.1]], + segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"], + summary={"mean": 0.15}, + ) + assert res.per_joint_distances is not None + + +class TestStatsResults: + def test_round_trip(self) -> None: + res = StatsResults( + kind="stats", + statistics=[ + FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4), + FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8), + ], + segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"], + ) + dumped = res.model_dump_json() + restored = StatsResults.model_validate_json(dumped) + assert restored == res + + +class TestNoResults: + def test_construction(self) -> None: + res = NoResults(kind="none") + assert res.kind == "none" + + +# --------------------------------------------------------------------------- +# AnalysisReport +# --------------------------------------------------------------------------- + + +def _make_report(tmp_path: Path) -> AnalysisReport: + config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path)) + return AnalysisReport( + config=config, + primary=InputSummary( + path=tmp_path / "primary.json", + frame_count=300, + fps=30.0, + ), + reference=InputSummary( + path=tmp_path / "reference.json", + frame_count=300, + fps=30.0, + ), + results=DtwResults( + kind="dtw", + method="dtw_all", + distances=[0.42], + paths=[[(0, 0), (1, 1)]], + segment_labels=["full_trial"], + summary={"mean": 0.42, "p50": 0.42}, + ), + ) + + +class TestAnalysisReport: + def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None: + report = _make_report(tmp_path) + assert report.schema_version == CURRENT_VERSION + + def test_round_trip_json(self, tmp_path: Path) -> None: + report = _make_report(tmp_path) + serialised = report.model_dump_json() + restored = AnalysisReport.model_validate_json(serialised) + assert restored == report + + def 
test_empty_segmentations_default(self, tmp_path: Path) -> None: + report = _make_report(tmp_path) + assert report.segmentations == {} + + def test_extra_field_rejected(self, tmp_path: Path) -> None: + report = _make_report(tmp_path) + dumped = report.model_dump(mode="json") + dumped["mystery_field"] = 1 + with pytest.raises(ValidationError): + AnalysisReport.model_validate(dumped) diff --git a/tests/unit/test_migrations.py b/tests/unit/test_migrations.py index a7d8fef..64be5a5 100644 --- a/tests/unit/test_migrations.py +++ b/tests/unit/test_migrations.py @@ -38,6 +38,7 @@ from neuropose.migrations import ( FutureSchemaError, MigrationError, MigrationNotFoundError, + migrate_analysis_report, migrate_benchmark_result, migrate_job_results, migrate_video_predictions, @@ -282,8 +283,32 @@ class TestMigrateJobResults: assert result["video_b.mp4"]["content_b"] is True +# --------------------------------------------------------------------------- +# migrate_analysis_report +# --------------------------------------------------------------------------- + + +class TestMigrateAnalysisReport: + def test_current_version_is_noop(self) -> None: + """A payload already at CURRENT_VERSION passes through unchanged.""" + payload = {"schema_version": CURRENT_VERSION, "foo": "bar"} + assert migrate_analysis_report(payload) == payload + + def test_missing_version_defaults_to_v1_and_fails(self) -> None: + """AnalysisReport first shipped at CURRENT_VERSION, so a payload + without schema_version (defaulting to 1) would require a + non-existent v1→v2 migration and fail with a clear error.""" + with pytest.raises(MigrationNotFoundError, match="AnalysisReport"): + migrate_analysis_report({}) + + def test_future_version_rejected(self) -> None: + with pytest.raises(FutureSchemaError, match="AnalysisReport"): + migrate_analysis_report({"schema_version": CURRENT_VERSION + 5}) + + # --------------------------------------------------------------------------- # register_video_predictions_migration / 
register_benchmark_result_migration +# / register_analysis_report_migration # --------------------------------------------------------------------------- @@ -312,6 +337,21 @@ class TestRegistration: assert _fn.__name__ == "_fn" assert _fn({"x": 1}) == {"x": 1} + def test_analysis_report_duplicate_registration_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(migrations, "_ANALYSIS_REPORT_MIGRATIONS", {}) + + @migrations.register_analysis_report_migration(from_version=2) + def _first(p: dict) -> dict: + return p + + with pytest.raises(RuntimeError, match="already registered"): + + @migrations.register_analysis_report_migration(from_version=2) + def _second(p: dict) -> dict: + return p + # --------------------------------------------------------------------------- # Integration: load_* functions run migrations before validation