"""Tests for :mod:`neuropose.analyzer.pipeline`.
|
|
|
|
This file covers the schema half of the pipeline:
|
|
|
|
- :class:`AnalysisConfig` parsing, including the discriminated unions
|
|
for the segmentation and analysis stages, and the cross-field
|
|
invariants enforced at parse time.
|
|
- :class:`AnalysisReport` construction + JSON round-trip, including
|
|
the migration hook (schema_version defaults to CURRENT_VERSION).
|
|
- :func:`analysis_config_to_dict` JSON-safety.
|
|
|
|
The executor (``run_analysis``) gets its own test module.
|
|
"""
|
|
|
|
from __future__ import annotations

from pathlib import Path
from typing import Any

import pytest
from pydantic import ValidationError

from neuropose.analyzer.pipeline import (
    AnalysisConfig,
    AnalysisReport,
    DtwAnalysis,
    DtwResults,
    ExtractorSegmentation,
    FeatureSummary,
    GaitCyclesBilateralSegmentation,
    GaitCyclesSegmentation,
    InputsConfig,
    InputSummary,
    NoAnalysis,
    NoResults,
    OutputConfig,
    PreprocessingConfig,
    StatsAnalysis,
    StatsResults,
    analysis_config_to_dict,
)
from neuropose.migrations import CURRENT_VERSION

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _minimal_dtw_config(tmp_path: Path) -> dict[str, Any]:
    """A minimal AnalysisConfig dict with dtw_all + reference."""
    # DTW needs both recordings, so the inputs section carries a reference.
    inputs = {
        "primary": str(tmp_path / "primary.json"),
        "reference": str(tmp_path / "reference.json"),
    }
    return {
        "inputs": inputs,
        "analysis": {"kind": "dtw", "method": "dtw_all"},
        "output": {"report": str(tmp_path / "report.json")},
    }


def _minimal_stats_config(tmp_path: Path) -> dict[str, Any]:
    """A minimal AnalysisConfig dict for the stats analysis kind."""
    # Stats runs on the primary input only — no reference entry here.
    extractor = {"kind": "joint_axis", "joint": 32, "axis": 1, "invert": False}
    return {
        "inputs": {"primary": str(tmp_path / "primary.json")},
        "analysis": {"kind": "stats", "extractor": extractor},
        "output": {"report": str(tmp_path / "report.json")},
    }


# ---------------------------------------------------------------------------
# InputsConfig / PreprocessingConfig / OutputConfig
# ---------------------------------------------------------------------------


class TestInputsConfig:
    """Parsing rules for the ``inputs`` section."""

    def test_primary_only(self, tmp_path: Path) -> None:
        # reference is optional and defaults to None.
        cfg = InputsConfig(primary=tmp_path / "a.json")
        assert cfg.reference is None

    def test_primary_and_reference(self, tmp_path: Path) -> None:
        cfg = InputsConfig(primary=tmp_path / "a.json", reference=tmp_path / "b.json")
        assert cfg.reference == tmp_path / "b.json"

    def test_extra_field_rejected(self, tmp_path: Path) -> None:
        payload = {"primary": str(tmp_path / "a.json"), "extra": "field"}
        with pytest.raises(ValidationError, match="Extra inputs"):
            InputsConfig.model_validate(payload)

    def test_frozen(self, tmp_path: Path) -> None:
        # The model is immutable after construction.
        cfg = InputsConfig(primary=tmp_path / "a.json")
        with pytest.raises(ValidationError, match="frozen"):
            cfg.primary = tmp_path / "b.json"  # type: ignore[misc]


class TestPreprocessingConfig:
    """Defaults and bounds for the preprocessing stage."""

    def test_default_person_index_zero(self) -> None:
        assert PreprocessingConfig().person_index == 0

    def test_negative_person_index_rejected(self) -> None:
        # person_index indexes into detected people, so it cannot be negative.
        with pytest.raises(ValidationError):
            PreprocessingConfig(person_index=-1)


class TestOutputConfig:
    """The ``output`` section carries a single report path."""

    def test_report_path(self, tmp_path: Path) -> None:
        expected = tmp_path / "out.json"
        cfg = OutputConfig(report=expected)
        assert cfg.report == expected


# ---------------------------------------------------------------------------
# Segmentation stage discriminated union
# ---------------------------------------------------------------------------


class TestSegmentationStage:
    """Discriminated-union dispatch for the optional segmentation stage."""

    def test_gait_cycles_parses_from_dict(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        payload["segmentation"] = {
            "kind": "gait_cycles",
            "joint": "lhee",
            "axis": "y",
            "min_cycle_seconds": 0.5,
        }
        cfg = AnalysisConfig.model_validate(payload)
        seg = cfg.segmentation
        assert isinstance(seg, GaitCyclesSegmentation)
        assert seg.joint == "lhee"
        assert seg.min_cycle_seconds == 0.5

    def test_bilateral_parses_from_dict(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        payload["segmentation"] = {"kind": "gait_cycles_bilateral"}
        cfg = AnalysisConfig.model_validate(payload)
        assert isinstance(cfg.segmentation, GaitCyclesBilateralSegmentation)

    def test_extractor_parses_from_dict(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        payload["segmentation"] = {
            "kind": "extractor",
            "extractor": {"kind": "joint_axis", "joint": 15, "axis": 1, "invert": False},
            "label": "wrist_cycles",
            "min_distance_seconds": 0.5,
        }
        cfg = AnalysisConfig.model_validate(payload)
        assert isinstance(cfg.segmentation, ExtractorSegmentation)
        assert cfg.segmentation.label == "wrist_cycles"

    def test_unknown_kind_rejected(self, tmp_path: Path) -> None:
        # An unrecognised discriminator value must fail validation.
        payload = _minimal_dtw_config(tmp_path)
        payload["segmentation"] = {"kind": "unknown_method"}
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)

    def test_segmentation_omitted_is_none(self, tmp_path: Path) -> None:
        cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        assert cfg.segmentation is None

    def test_invalid_min_cycle_seconds_rejected(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        # min_cycle_seconds must be strictly positive.
        payload["segmentation"] = {"kind": "gait_cycles", "min_cycle_seconds": 0.0}
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)


# ---------------------------------------------------------------------------
# Analysis stage discriminated union + cross-field invariants
# ---------------------------------------------------------------------------


class TestDtwAnalysisValidation:
    """Cross-field invariants inside the DTW analysis model."""

    def test_dtw_relation_requires_joints(self) -> None:
        # dtw_relation compares a joint pair, so both indices are mandatory.
        with pytest.raises(ValidationError, match="joint_i and joint_j"):
            DtwAnalysis(kind="dtw", method="dtw_relation")

    def test_dtw_relation_rejects_angles_representation(self) -> None:
        with pytest.raises(ValidationError, match="only supports representation='coords'"):
            DtwAnalysis(
                kind="dtw",
                method="dtw_relation",
                joint_i=0,
                joint_j=1,
                representation="angles",
                angle_triplets=[(0, 1, 2)],
            )

    def test_angles_requires_triplets(self) -> None:
        with pytest.raises(ValidationError, match="angle_triplets"):
            DtwAnalysis(kind="dtw", method="dtw_all", representation="angles")

    def test_angles_with_empty_triplets_rejected(self) -> None:
        # An empty triplet list is as invalid as an absent one.
        with pytest.raises(ValidationError, match="angle_triplets"):
            DtwAnalysis(
                kind="dtw",
                method="dtw_all",
                representation="angles",
                angle_triplets=[],
            )

    def test_happy_path_dtw_all_coords(self) -> None:
        analysis = DtwAnalysis(
            kind="dtw",
            method="dtw_all",
            align="procrustes_per_sequence",
            nan_policy="interpolate",
        )
        assert analysis.align == "procrustes_per_sequence"

    def test_happy_path_dtw_all_angles(self) -> None:
        triplets = [(0, 1, 2), (3, 4, 5)]
        analysis = DtwAnalysis(
            kind="dtw",
            method="dtw_all",
            representation="angles",
            angle_triplets=triplets,
        )
        assert len(analysis.angle_triplets or []) == 2

    def test_happy_path_dtw_relation(self) -> None:
        analysis = DtwAnalysis(kind="dtw", method="dtw_relation", joint_i=15, joint_j=23)
        assert analysis.joint_i == 15


class TestAnalysisCrossStage:
    """Invariants that couple the analysis kind to the inputs section."""

    def test_dtw_without_reference_rejected(self, tmp_path: Path) -> None:
        """DTW compares primary against reference, so reference is required."""
        payload = {
            "inputs": {"primary": str(tmp_path / "a.json")},
            "analysis": {"kind": "dtw", "method": "dtw_all"},
            "output": {"report": str(tmp_path / "out.json")},
        }
        with pytest.raises(ValidationError, match=r"inputs\.reference"):
            AnalysisConfig.model_validate(payload)

    def test_stats_with_reference_rejected(self, tmp_path: Path) -> None:
        """Stats analyses operate on the primary recording only."""
        payload = _minimal_stats_config(tmp_path)
        payload["inputs"]["reference"] = str(tmp_path / "b.json")
        with pytest.raises(ValidationError, match="primary only"):
            AnalysisConfig.model_validate(payload)

    def test_none_analysis_parses_without_reference(self, tmp_path: Path) -> None:
        """NoAnalysis imposes no reference requirement.

        Renamed from ``test_none_analysis_requires_no_reference``: the old
        name implied a present reference would be *rejected*, while the
        inline comment claimed either case is accepted — and only the
        reference-absent case was (and is) exercised here.
        NOTE(review): if NoAnalysis really also accepts a present
        reference, add a companion case covering that.
        """
        payload = {
            "inputs": {"primary": str(tmp_path / "a.json")},
            "analysis": {"kind": "none"},
            "output": {"report": str(tmp_path / "out.json")},
        }
        cfg = AnalysisConfig.model_validate(payload)
        assert isinstance(cfg.analysis, NoAnalysis)


# ---------------------------------------------------------------------------
# Top-level AnalysisConfig
# ---------------------------------------------------------------------------


class TestAnalysisConfig:
    """Top-level parse behaviour: defaults, versioning, round-trips."""

    def test_minimal_dtw_config_parses(self, tmp_path: Path) -> None:
        cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        assert cfg.config_version == 1
        assert isinstance(cfg.analysis, DtwAnalysis)
        # preprocessing is absent from the payload, so defaults apply.
        assert cfg.preprocessing.person_index == 0

    def test_minimal_stats_config_parses(self, tmp_path: Path) -> None:
        cfg = AnalysisConfig.model_validate(_minimal_stats_config(tmp_path))
        assert isinstance(cfg.analysis, StatsAnalysis)

    def test_config_version_must_be_1(self, tmp_path: Path) -> None:
        payload = _minimal_dtw_config(tmp_path)
        payload["config_version"] = 99
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)

    def test_round_trip_json(self, tmp_path: Path) -> None:
        original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        restored = AnalysisConfig.model_validate_json(original.model_dump_json())
        assert restored == original

    def test_extra_top_level_field_rejected(self, tmp_path: Path) -> None:
        # Typos in top-level keys must not be silently ignored.
        payload = _minimal_dtw_config(tmp_path)
        payload["unknown_key"] = "typo"
        with pytest.raises(ValidationError):
            AnalysisConfig.model_validate(payload)


# ---------------------------------------------------------------------------
# analysis_config_to_dict
# ---------------------------------------------------------------------------


class TestAnalysisConfigToDict:
    """analysis_config_to_dict must emit JSON-safe primitives."""

    def test_returns_json_safe_dict(self, tmp_path: Path) -> None:
        cfg = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        dumped = analysis_config_to_dict(cfg)
        # Path values must have been converted to plain strings.
        assert isinstance(dumped["inputs"]["primary"], str)
        assert isinstance(dumped["output"]["report"], str)

    def test_round_trips_through_dict(self, tmp_path: Path) -> None:
        original = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))
        restored = AnalysisConfig.model_validate(analysis_config_to_dict(original))
        assert restored == original


# ---------------------------------------------------------------------------
# Result sub-schemas
# ---------------------------------------------------------------------------


class TestDtwResults:
    """Construction rules for the DTW results payload."""

    def test_minimal_construction(self) -> None:
        res = DtwResults(
            kind="dtw",
            method="dtw_all",
            distances=[0.5],
            paths=[[(0, 0), (1, 1)]],
            segment_labels=["full_trial"],
            summary={"mean": 0.5},
        )
        assert res.kind == "dtw"

    def test_per_joint_distances_shape_is_free(self) -> None:
        # The schema does not couple per_joint_distances' outer length to
        # distances — keeping those consistent is the executor's job.
        # Here we only verify the field survives construction.
        res = DtwResults(
            kind="dtw",
            method="dtw_per_joint",
            distances=[0.1, 0.2],
            paths=[[(0, 0)], [(0, 0)]],
            per_joint_distances=[[0.05, 0.05], [0.1, 0.1]],
            segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
            summary={"mean": 0.15},
        )
        assert res.per_joint_distances is not None


class TestStatsResults:
    """JSON round-trip for the stats results payload."""

    def test_round_trip(self) -> None:
        summaries = [
            FeatureSummary(mean=1.0, std=0.1, min=0.8, max=1.2, range=0.4),
            FeatureSummary(mean=1.1, std=0.2, min=0.7, max=1.5, range=0.8),
        ]
        res = StatsResults(
            kind="stats",
            statistics=summaries,
            segment_labels=["rhee_cycles[0]", "rhee_cycles[1]"],
        )
        restored = StatsResults.model_validate_json(res.model_dump_json())
        assert restored == res


class TestNoResults:
    """The empty results variant used when no analysis ran."""

    def test_construction(self) -> None:
        assert NoResults(kind="none").kind == "none"


# ---------------------------------------------------------------------------
# AnalysisReport
# ---------------------------------------------------------------------------


def _make_report(tmp_path: Path) -> AnalysisReport:
    """Build a small but fully populated AnalysisReport for round-trip tests."""
    config = AnalysisConfig.model_validate(_minimal_dtw_config(tmp_path))

    def summary(name: str) -> InputSummary:
        # Both inputs share the same frame count / fps in this fixture.
        return InputSummary(path=tmp_path / name, frame_count=300, fps=30.0)

    results = DtwResults(
        kind="dtw",
        method="dtw_all",
        distances=[0.42],
        paths=[[(0, 0), (1, 1)]],
        segment_labels=["full_trial"],
        summary={"mean": 0.42, "p50": 0.42},
    )
    return AnalysisReport(
        config=config,
        primary=summary("primary.json"),
        reference=summary("reference.json"),
        results=results,
    )


class TestAnalysisReport:
    """Report construction, versioning, and JSON round-trip."""

    def test_schema_version_defaults_to_current(self, tmp_path: Path) -> None:
        # The migration hook stamps new reports with the latest version.
        assert _make_report(tmp_path).schema_version == CURRENT_VERSION

    def test_round_trip_json(self, tmp_path: Path) -> None:
        report = _make_report(tmp_path)
        restored = AnalysisReport.model_validate_json(report.model_dump_json())
        assert restored == report

    def test_empty_segmentations_default(self, tmp_path: Path) -> None:
        assert _make_report(tmp_path).segmentations == {}

    def test_extra_field_rejected(self, tmp_path: Path) -> None:
        payload = _make_report(tmp_path).model_dump(mode="json")
        payload["mystery_field"] = 1
        with pytest.raises(ValidationError):
            AnalysisReport.model_validate(payload)