neuropose/tests/integration/test_joint_names_drift.py

61 lines
2.4 KiB
Python

"""Drift test for the hardcoded berkeley_mhad_43 joint-name constant.
:mod:`neuropose.analyzer.segment` ships ``JOINT_NAMES`` as a frozen
tuple of 43 strings so that post-processing callers can resolve
``"rwri"`` → index without loading a multi-gigabyte TensorFlow model.
That is only safe while the hardcoded tuple actually matches what the
pinned MeTRAbs SavedModel reports.
This test loads the real model via :func:`neuropose._model.load_metrabs_model`
and asserts that ``JOINT_NAMES`` is byte-identical to
``model.per_skeleton_joint_names["berkeley_mhad_43"].numpy().astype(str)``.
If MeTRAbs ever ships a new skeleton under the same name — or if we bump
the model pin to one whose ``berkeley_mhad_43`` skeleton is spelled
differently — this test is the drift detector.
Like every test under ``tests/integration/`` the file is marked
``@pytest.mark.slow`` and only runs under ``pytest --runslow``.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from neuropose._model import load_metrabs_model
from neuropose.analyzer.segment import JOINT_NAMES
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
def metrabs_model_cache_dir(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Module-scoped cache dir so the model downloads at most once per run.
Function scope would re-download on every test; session scope would
collide with the estimator smoke-test cache. Module scope is the
right middle ground for a file that only needs the model loaded
once.
"""
return tmp_path_factory.mktemp("neuropose_joint_names_model_cache")
def test_joint_names_match_pinned_model(metrabs_model_cache_dir: Path) -> None:
"""Hardcoded ``JOINT_NAMES`` must match the loaded MeTRAbs skeleton.
If this fails, the expected fix is:
1. Update :data:`neuropose.analyzer.segment.JOINT_NAMES` in the same
commit that bumps the model pin in :mod:`neuropose._model`.
2. Cross-check any CLI or docs that embed hardcoded joint names.
"""
model = load_metrabs_model(cache_dir=metrabs_model_cache_dir)
tensor = model.per_skeleton_joint_names["berkeley_mhad_43"]
model_names = tuple(tensor.numpy().astype(str).tolist())
assert model_names == JOINT_NAMES, (
"JOINT_NAMES drift detected — the hardcoded tuple in "
"neuropose.analyzer.segment no longer matches the MeTRAbs model. "
f"Model reports: {model_names}"
)