# neuropose/tests/unit/test_analyzer_pipeline.py
#
# 435 lines
# 15 KiB
# Python

"""Tests for :mod:`neuropose.analyzer.pipeline`.

This file covers the schema half of the pipeline:

- :class:`AnalysisConfig` parsing, including the discriminated unions
  for the segmentation and analysis stages, and the cross-field
  invariants enforced at parse time.
- :class:`AnalysisReport` construction and JSON round-trip, including
  the migration hook (``schema_version`` defaults to ``CURRENT_VERSION``).
- :func:`analysis_config_to_dict` JSON-safety.

The executor (``run_analysis``) is tested in its own module.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import pytest
from pydantic import ValidationError

from neuropose.analyzer.pipeline import (
    AnalysisConfig,
    AnalysisReport,
    DtwAnalysis,
    DtwResults,
    ExtractorSegmentation,
    FeatureSummary,
    GaitCyclesBilateralSegmentation,
    GaitCyclesSegmentation,
    InputsConfig,
    InputSummary,
    NoAnalysis,
    NoResults,
    OutputConfig,
    PreprocessingConfig,
    StatsAnalysis,
    StatsResults,
    analysis_config_to_dict,
)
from neuropose.migrations import CURRENT_VERSION
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
"""A minimal AnalysisConfig dict with dtw_all + reference."""
return {
"inputs": {
"primary": str(tmp_path / "primary.json"),
"reference": str(tmp_path / "reference.json"),
},
"analysis": {"kind": "dtw", "method": "dtw_all"},
"output": {"report": str(tmp_path / "report.json")},
}
def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
return {
"inputs": {"primary": str(tmp_path / "primary.json")},
"analysis": {
"kind": "stats",
"extractor": {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False},
},
"output": {"report": str(tmp_path / "report.json")},
}
# ---------------------------------------------------------------------------
# InputsConfig / PreprocessingConfig / OutputConfig
# ---------------------------------------------------------------------------
class TestInputsConfig:
    """Parsing, strictness and immutability of :class:`InputsConfig`."""

    def test_primary_only(self, tmp_path: Path) -> None:
        inputs = InputsConfig(primary=tmp_path / "a.json")
        # With no reference supplied, the field falls back to None.
        assert inputs.reference is None

    def test_primary_and_reference(self, tmp_path: Path) -> None:
        reference_path = tmp_path / "b.json"
        inputs = InputsConfig(primary=tmp_path / "a.json", reference=reference_path)
        assert inputs.reference == reference_path

    def test_extra_field_rejected(self, tmp_path: Path) -> None:
        payload = {"primary": str(tmp_path / "a.json"), "extra": "field"}
        with pytest.raises(ValidationError, match="Extra inputs"):
            InputsConfig.model_validate(payload)

    def test_frozen(self, tmp_path: Path) -> None:
        inputs = InputsConfig(primary=tmp_path / "a.json")
        # Assigning to any field of the frozen model must raise.
        with pytest.raises(ValidationError, match="frozen"):
            inputs.primary = tmp_path / "b.json"  # type: ignore[misc]
class TestPreprocessingConfig:
    """Defaults and bounds of :class:`PreprocessingConfig`."""

    def test_default_person_index_zero(self) -> None:
        assert PreprocessingConfig().person_index == 0

    def test_negative_person_index_rejected(self) -> None:
        # person_index is an index, so negative values are invalid.
        with pytest.raises(ValidationError):
            PreprocessingConfig(person_index=-1)
class TestOutputConfig:
    """Basic construction of :class:`OutputConfig`."""

    def test_report_path(self, tmp_path: Path) -> None:
        report_path = tmp_path / "out.json"
        assert OutputConfig(report=report_path).report == report_path
# ---------------------------------------------------------------------------
# Segmentation stage discriminated union
# ---------------------------------------------------------------------------
class TestSegmentationStage:
    """The ``segmentation`` discriminated union on :class:`AnalysisConfig`."""

    @staticmethod
    def _config_with(tmp_path: Path, segmentation: dict[str, Any]) -> dict[str, Any]:
        # Start from the minimal DTW payload and graft a segmentation stage on.
        payload = _minimal_dtw_config(tmp_path)
        payload["segmentation"] = segmentation
        return payload

    def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
        payload = self._config_with(
            tmp_path,
            {
                "kind": "gait_cycles",
                "joint": "lhee",
                "axis": "y",
                "min_cycle_seconds": 0.5,
            },
        )
        stage = AnalysisConfig.model_validate(payload).segmentation
        assert isinstance(stage, GaitCyclesSegmentation)
        assert stage.joint == "lhee"
        assert stage.min_cycle_seconds == 0.5

    def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
        payload = self._config_with(tmp_path, {"kind": "gait_cycles_bilateral"})
        stage = AnalysisConfig.model_validate(payload).segmentation
        assert isinstance(stage, GaitCyclesBilateralSegmentation)

    def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
        wrist_extractor = {"kind": "joint_axis", "joint": 15, "axis": 1, "invert": False}
        payload = self._config_with(
            tmp_path,
            {
                "kind": "extractor",
                "extractor": wrist_extractor,
                "label": "wrist_cycles",
                "min_distance_seconds": 0.5,
            },
        )
        stage = AnalysisConfig.model_validate(payload).segmentation
        assert isinstance(stage, ExtractorSegmentation)
        assert stage.label == "wrist_cycles"

    def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
        payload = self._config_with(tmp_path, {"kind": "unknown_method"})
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)

    def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
        parsed = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        assert parsed.segmentation is None

    def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
        payload = self._config_with(
            tmp_path,
            {"kind": "gait_cycles", "min_cycle_seconds": 0.0},  # must be > 0
        )
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)
# ---------------------------------------------------------------------------
# Analysis stage discriminated union + cross-field invariants
# ---------------------------------------------------------------------------
class TestDtwAnalysisValidation:
    """Field-level invariants enforced by :class:`DtwAnalysis` itself."""

    def test_dtw_relation_requires_joints(self) -> None:
        # dtw_relation without joint_i / joint_j must be refused.
        with pytest.raises(ValidationError, match="joint_i and joint_j"):
            DtwAnalysis(kind="dtw", method="dtw_relation")

    def test_dtw_relation_rejects_angles_representation(self) -> None:
        kwargs: dict[str, Any] = {
            "kind": "dtw",
            "method": "dtw_relation",
            "joint_i": 0,
            "joint_j": 1,
            "representation": "angles",
            "angle_triplets": [(0, 1, 2)],
        }
        with pytest.raises(ValidationError, match="only supports representation='coords'"):
            DtwAnalysis(**kwargs)

    def test_angles_requires_triplets(self) -> None:
        kwargs: dict[str, Any] = {
            "kind": "dtw",
            "method": "dtw_all",
            "representation": "angles",
        }
        with pytest.raises(ValidationError, match="angle_triplets"):
            DtwAnalysis(**kwargs)

    def test_angles_with_empty_triplets_rejected(self) -> None:
        kwargs: dict[str, Any] = {
            "kind": "dtw",
            "method": "dtw_all",
            "representation": "angles",
            "angle_triplets": [],
        }
        with pytest.raises(ValidationError, match="angle_triplets"):
            DtwAnalysis(**kwargs)

    def test_happy_path_dtw_all_coords(self) -> None:
        parsed = DtwAnalysis(
            kind="dtw",
            method="dtw_all",
            align="procrustes_per_sequence",
            nan_policy="interpolate",
        )
        assert parsed.align == "procrustes_per_sequence"

    def test_happy_path_dtw_all_angles(self) -> None:
        parsed = DtwAnalysis(
            kind="dtw",
            method="dtw_all",
            representation="angles",
            angle_triplets=[(0, 1, 2), (3, 4, 5)],
        )
        triplets = parsed.angle_triplets or []
        assert len(triplets) == 2

    def test_happy_path_dtw_relation(self) -> None:
        parsed = DtwAnalysis(kind="dtw", method="dtw_relation", joint_i=15, joint_j=23)
        assert parsed.joint_i == 15
class TestAnalysisCrossStage:
    """Parse-time invariants tying the ``analysis`` stage to ``inputs.reference``."""

    def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
        # A DTW analysis with only a primary input must fail validation, and
        # the error must point at inputs.reference.
        config_dict = {
            "inputs": {"primary": str(tmp_path / "a.json")},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "out.json")},
        }
        with pytest.raises(ValidationError, match=r"inputs\.reference"):
            AnalysisConfig.model_validate(config_dict)

    def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
        # Stats analyses take the primary input only.
        config_dict = _minimal_stats_config(tmp_path)
        config_dict["inputs"]["reference"] = str(tmp_path / "b.json")
        with pytest.raises(ValidationError, match="primary only"):
            AnalysisConfig.model_validate(config_dict)

    @staticmethod
    def _none_analysis_config(tmp_path: Path, *, with_reference: bool) -> dict[str, Any]:
        # Raw payload for a NoAnalysis run, optionally carrying a reference input.
        inputs = {"primary": str(tmp_path / "a.json")}
        if with_reference:
            inputs["reference"] = str(tmp_path / "b.json")
        return {
            "inputs": inputs,
            "analysis": {"kind": "none"},
            "output": {"report": str(tmp_path / "out.json")},
        }

    def test_none_analysis_without_reference(self, tmp_path: Path) -> None:
        # NoAnalysis places no constraint on inputs.reference; an absent
        # reference must be accepted.
        cfg = AnalysisConfig.model_validate(
            self._none_analysis_config(tmp_path, with_reference=False)
        )
        assert isinstance(cfg.analysis, NoAnalysis)
        assert cfg.inputs.reference is None

    def test_none_analysis_with_reference(self, tmp_path: Path) -> None:
        # ...and a present reference must be accepted too (unlike stats).
        cfg = AnalysisConfig.model_validate(
            self._none_analysis_config(tmp_path, with_reference=True)
        )
        assert isinstance(cfg.analysis, NoAnalysis)
        assert cfg.inputs.reference == tmp_path / "b.json"
# ---------------------------------------------------------------------------
# Top-level AnalysisConfig
# ---------------------------------------------------------------------------
class TestAnalysisConfig:
    """Top-level :class:`AnalysisConfig` parsing, versioning and strictness."""

    def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
        parsed = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        assert parsed.config_version == 1
        assert isinstance(parsed.analysis, DtwAnalysis)
        # preprocessing was omitted from the payload, so its default applies.
        assert parsed.preprocessing.person_index == 0

    def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
        parsed = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
        assert isinstance(parsed.analysis, StatsAnalysis)

    def test_config_version_must_be_1(self, tmp_path: Path) -> None:
        payload = {**_minimal_dtw_config(tmp_path), "config_version": 99}
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)

    def test_round_trip_json(self, tmp_path: Path) -> None:
        before = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        after = AnalysisConfig.model_validate_json(before.model_dump_json())
        assert after == before

    def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
        payload = {**_minimal_dtw_config(tmp_path), "unknown_key": "typo"}
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)
# ---------------------------------------------------------------------------
# analysis_config_to_dict
# ---------------------------------------------------------------------------
class TestAnalysisConfigToDict:
    """:func:`analysis_config_to_dict` must produce plain, JSON-serialisable data."""

    def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
        cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        dumped = analysis_config_to_dict(cfg)
        # Paths must have become strings.
        assert isinstance(dumped["inputs"]["primary"], str)
        assert isinstance(dumped["inputs"]["reference"], str)
        assert isinstance(dumped["output"]["report"], str)
        # The plain stdlib encoder (no ``default=`` hook) must accept the whole
        # tree -- json.dumps raises TypeError otherwise -- and the decoded
        # payload must still parse back to the same config.
        decoded = json.loads(json.dumps(dumped))
        assert AnalysisConfig.model_validate(decoded) == cfg

    def test_round_trips_through_dict(self, tmp_path: Path) -> None:
        original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        dumped = analysis_config_to_dict(original)
        restored = AnalysisConfig.model_validate(dumped)
        assert restored == original
# ---------------------------------------------------------------------------
# Result sub-schemas
# ---------------------------------------------------------------------------
class TestDtwResults:
    """Construction and serialisation of :class:`DtwResults`."""

    def test_minimal_construction(self) -> None:
        res = DtwResults(
            kind="dtw",
            method="dtw_all",
            distances=[0.5],
            paths=[[(0, 0), (1, 1)]],
            segment_labels=["full_trial"],
            summary={"mean": 0.5},
        )
        assert res.kind == "dtw"
        assert res.distances == [0.5]
        assert res.segment_labels == ["full_trial"]

    def test_per_joint_distances_shape_is_free(self) -> None:
        # No validator enforces that per_joint_distances outer length
        # matches distances -- that's run-time semantics of the
        # executor. Still, verify the field survives a JSON round-trip.
        per_joint = [[0.05, 0.05], [0.1, 0.1]]
        res = DtwResults(
            kind="dtw",
            method="dtw_per_joint",
            distances=[0.1, 0.2],
            paths=[[(0, 0)], [(0, 0)]],
            per_joint_distances=per_joint,
            segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
            summary={"mean": 0.15},
        )
        restored = DtwResults.model_validate_json(res.model_dump_json())
        assert restored == res
        assert restored.per_joint_distances == per_joint
class TestStatsResults:
    """JSON serialisation of :class:`StatsResults`."""

    def test_round_trip(self) -> None:
        per_segment = [
            FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
            FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
        ]
        labels = ["rhee_cycles[0]", "rhee_cycles[1]"]
        original = StatsResults(kind="stats", statistics=per_segment, segment_labels=labels)
        assert StatsResults.model_validate_json(original.model_dump_json()) == original
class TestNoResults:
    """Result schema for the ``"none"`` kind."""

    def test_construction(self) -> None:
        assert NoResults(kind="none").kind == "none"
# ---------------------------------------------------------------------------
# AnalysisReport
# ---------------------------------------------------------------------------
def _make_report(tmp_path: Path) -> AnalysisReport:
    """Build a small, fully populated DTW :class:`AnalysisReport` fixture.

    The config is the minimal ``dtw_all`` payload. Both inputs are described
    as 300 frames at 30.0 fps. The results hold a single ``full_trial``
    distance of 0.42. ``schema_version`` and ``segmentations`` are left at
    their defaults.
    """
    config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))

    def summary_for(filename: str) -> InputSummary:
        # Both inputs share the same frame count and frame rate.
        return InputSummary(path=tmp_path / filename, frame_count=300, fps=30.0)

    primary_summary = summary_for("primary.json")
    reference_summary = summary_for("reference.json")
    dtw_results = DtwResults(
        kind="dtw",
        method="dtw_all",
        distances=[0.42],
        paths=[[(0, 0), (1, 1)]],
        segment_labels=["full_trial"],
        summary={"mean": 0.42, "p50": 0.42},
    )
    return AnalysisReport(
        config=config,
        primary=primary_summary,
        reference=reference_summary,
        results=dtw_results,
    )
class TestAnalysisReport:
    """Defaults, serialisation and strictness of :class:`AnalysisReport`."""

    def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
        # _make_report never passes schema_version, so this checks the default.
        assert _make_report(tmp_path).schema_version == CURRENT_VERSION

    def test_round_trip_json(self, tmp_path: Path) -> None:
        before = _make_report(tmp_path)
        after = AnalysisReport.model_validate_json(before.model_dump_json())
        assert after == before

    def test_empty_segmentations_default(self, tmp_path: Path) -> None:
        assert _make_report(tmp_path).segmentations == {}

    def test_extra_field_rejected(self, tmp_path: Path) -> None:
        payload = _make_report(tmp_path).model_dump(mode="json")
        payload["mystery_field"] = 1
        with pytest.raises(ValidationError):
            AnalysisReport.model_validate(payload)