From a1c495b2fd51d9ef0f6d90dbb3a05269022fee4e Mon Sep 17 00:00:00 2001 From: Levi Neuwirth Date: Sat, 18 Apr 2026 17:50:21 -0400 Subject: [PATCH] add gait-cycle segmentation to analyzer segment_gait_cycles wraps segment_predictions with a joint_axis extractor and gait-appropriate defaults (joint="rhee", axis="y", min_cycle_seconds=0.4). The joint name resolves through joint_index; axis is a "x"/"y"/"z" string literal converted to the numeric index internally. An invert flag flips peaks and valleys for recording conventions where a heel-strike appears as a local minimum. segment_gait_cycles_bilateral composes the single-side function twice and returns a {"left_heel_strikes", "right_heel_strikes"} dict shape-compatible with VideoPredictions.segmentations, so the caller can merge it directly into a predictions object. Pathological gaits (shuffling, walker-assisted) degrade to an empty segments list rather than raising, inherited from segment_by_peaks' peak-not-found behaviour. Closes the gait-cycle segmentation item in TECHNICAL.md Phase 0. --- CHANGELOG.md | 22 +++- src/neuropose/analyzer/__init__.py | 6 ++ src/neuropose/analyzer/segment.py | 162 ++++++++++++++++++++++++++++ tests/unit/test_analyzer_segment.py | 138 ++++++++++++++++++++++++ 4 files changed, 327 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cf4d35..a887ff7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -239,6 +239,22 @@ be split into per-release sections once tagging begins. distance is rotation- and translation-invariant. Paper C's pipeline is expected to set `align="procrustes_per_sequence"`; see `TECHNICAL.md` Phase 0. +- **`neuropose.analyzer.segment.segment_gait_cycles`** and + **`segment_gait_cycles_bilateral`** — clinical convenience + wrappers over `segment_predictions` that pre-fill a `joint_axis` + extractor with gait-appropriate defaults (`joint="rhee"`, + `axis="y"`, `min_cycle_seconds=0.4`). The single-side entry point + accepts any berkeley_mhad_43 joint name and any spatial axis as a + string literal `"x" | "y" | "z"`, plus an `invert` flag for + recordings whose vertical axis runs opposite to MeTRAbs's + Y-down world-coordinate convention. The bilateral wrapper runs + the detection on both `lhee` and `rhee` and returns the two + results under `"left_heel_strikes"` / `"right_heel_strikes"` + keys — shape-compatible with `VideoPredictions.segmentations` so + the dict can be merged in directly. Degrades gracefully on + pathological gaits (shuffling, walker-assisted) by returning an + empty segments list rather than raising. Closes the gait-cycle + segmentation item in `TECHNICAL.md` Phase 0. - **`neuropose.io.Provenance`** — reproducibility envelope for every inference run. Populated automatically by `Estimator.process_video` when the model was loaded via `load_model` (the production path) @@ -316,7 +332,11 @@ be split into per-release sections once tagging begins. time-based parameters to frame counts via `metadata.fps`), and `slice_predictions` (split a `VideoPredictions` into one per detected repetition with re-keyed frame names and a rewritten - `frame_count`). Ships four extractor factories — + `frame_count`). Gait-specific convenience wrappers + `segment_gait_cycles` (single heel) and + `segment_gait_cycles_bilateral` (both heels, returning a dict + keyed by `"left_heel_strikes"` / `"right_heel_strikes"`) sit + above `segment_predictions` with clinical defaults. Ships four extractor factories — `joint_axis`, `joint_pair_distance`, `joint_speed`, and `joint_angle` — plus a `JOINT_NAMES` constant for the berkeley_mhad_43 skeleton with a `joint_index(name)` lookup, diff --git a/src/neuropose/analyzer/__init__.py b/src/neuropose/analyzer/__init__.py index 98b6dbe..fb0c97d 100644 --- a/src/neuropose/analyzer/__init__.py +++ b/src/neuropose/analyzer/__init__.py @@ -49,6 +49,7 @@ from neuropose.analyzer.features import ( from neuropose.analyzer.segment import ( JOINT_INDEX, JOINT_NAMES, + AxisLetter, extract_signal, joint_angle, joint_axis, @@ -56,6 +57,8 @@ from neuropose.analyzer.segment import ( joint_pair_distance, joint_speed, segment_by_peaks, + segment_gait_cycles, + segment_gait_cycles_bilateral, segment_predictions, slice_predictions, ) @@ -65,6 +68,7 @@ __all__ = [ "JOINT_NAMES", "AlignMode", "AlignmentDiagnostics", + "AxisLetter", "DTWResult", "FeatureStatistics", "ProcrustesMode", @@ -85,6 +89,8 @@ __all__ = [ "predictions_to_numpy", "procrustes_align", "segment_by_peaks", + "segment_gait_cycles", + "segment_gait_cycles_bilateral", "segment_predictions", "slice_predictions", ] diff --git a/src/neuropose/analyzer/segment.py b/src/neuropose/analyzer/segment.py index 232a725..c97c6ec 100644 --- a/src/neuropose/analyzer/segment.py +++ b/src/neuropose/analyzer/segment.py @@ -30,6 +30,12 @@ Three layers of API are provided, in increasing order of convenience: :class:`~neuropose.io.ExtractorSpec`, converts time-based parameters to frame counts using ``metadata.fps``, and returns a full :class:`~neuropose.io.Segmentation` ready to attach to the predictions. +- :func:`segment_gait_cycles` and :func:`segment_gait_cycles_bilateral` + — clinical convenience wrappers over :func:`segment_predictions` + that pre-fill a :func:`joint_axis` extractor with gait-appropriate + defaults (heel joint, Y axis, 0.4 s minimum cycle). The bilateral + variant returns both sides under ``"left_heel_strikes"`` and + ``"right_heel_strikes"`` keys. - :func:`slice_predictions` — split a :class:`~neuropose.io.VideoPredictions` into one per-repetition :class:`~neuropose.io.VideoPredictions`, useful when downstream code wants per-rep objects rather than windows @@ -66,6 +72,7 @@ installed; a clear :class:`ImportError` surfaces at the first call to from __future__ import annotations from collections.abc import Sequence +from typing import Literal import numpy as np @@ -83,6 +90,11 @@ from neuropose.io import ( VideoPredictions, ) +AxisLetter = Literal["x", "y", "z"] +"""Axis selector used by gait-cycle segmentation helpers.""" + +_AXIS_INDICES: dict[AxisLetter, int] = {"x": 0, "y": 1, "z": 2} + # --------------------------------------------------------------------------- # berkeley_mhad_43 joint names # --------------------------------------------------------------------------- @@ -522,6 +534,156 @@ def segment_predictions( return Segmentation(config=config, segments=segments) +# --------------------------------------------------------------------------- +# Gait-cycle segmentation +# --------------------------------------------------------------------------- + + +def segment_gait_cycles( + predictions: VideoPredictions, + *, + joint: str = "rhee", + axis: AxisLetter = "y", + invert: bool = False, + min_cycle_seconds: float = 0.4, + min_prominence: float | None = None, +) -> Segmentation: + """Segment gait cycles from a single heel's vertical trace. + + Runs valley-to-valley peak detection (the same engine used by + :func:`segment_predictions`) on the chosen joint's coordinate along + the chosen spatial axis. By default, each detected peak corresponds + to one heel-strike — the frame where the heel reaches its lowest + point on the Y-down MeTRAbs world-coordinate convention — and the + returned :class:`~neuropose.io.Segment` windows span one full gait + cycle from the preceding toe-off valley to the following toe-off + valley. + + The function is a **thin wrapper** over :func:`segment_predictions` + with a :func:`joint_axis` extractor; it exists to give clinical + callers a gait-specific entry point with meaningful defaults + (``joint="rhee"``, ``axis="y"``, ``min_cycle_seconds=0.4``) + rather than forcing them to construct the extractor by hand. + + Parameters + ---------- + predictions + Per-video predictions to segment. ``metadata.fps`` is used to + translate ``min_cycle_seconds`` into a sample-count distance + threshold. + joint + Joint name in the berkeley_mhad_43 skeleton — typically + ``"rhee"`` (right heel) or ``"lhee"`` (left heel). Resolved + via :func:`joint_index`. + axis + Spatial axis to track, as ``"x"``, ``"y"``, or ``"z"``. The + default ``"y"`` matches the vertical axis in MeTRAbs's output + (Y-down world coordinates). + invert + If ``True``, negate the extracted signal so that minima + become peaks. Needed when the recording convention makes a + heel-strike appear as a *decrease* in the chosen coordinate + — for example, a camera orientation where the vertical axis + runs bottom-to-top instead of MeTRAbs's default top-to-bottom. + min_cycle_seconds + Minimum gait-cycle duration. Used as scipy's + ``find_peaks(distance=...)`` parameter after conversion to + frame count via ``metadata.fps``. Defaults to ``0.4`` seconds, + which rejects noise peaks on even the fastest human gaits + (~120 strides/min) while retaining every real cadence. + min_prominence + Forwarded to :func:`segment_by_peaks` to filter out shallow + local maxima that aren't real heel-strikes. In MeTRAbs units + (millimetres) a threshold of 20 to 50 mm is typical for + able-bodied gait; leave ``None`` to accept every peak scipy + identifies. + + Returns + ------- + Segmentation + A :class:`~neuropose.io.Segmentation` paired with the full + :class:`~neuropose.io.SegmentationConfig` that produced it, so + the output is self-describing when persisted. The segments + list is **empty** rather than an exception when no peaks are + detected — a common outcome for shuffling gaits or + walker-assisted trials. + + Raises + ------ + KeyError + If ``joint`` is not a known berkeley_mhad_43 joint name. + ValueError + If ``axis`` is not one of ``"x"``, ``"y"``, ``"z"``, or if + ``predictions`` has zero frames, or if ``metadata.fps`` is + non-positive. + ImportError + If :mod:`scipy` is not installed. + """ + if axis not in _AXIS_INDICES: + raise ValueError(f"axis must be one of 'x', 'y', 'z'; got {axis!r}") + joint_idx = joint_index(joint) + axis_idx = _AXIS_INDICES[axis] + extractor = joint_axis(joint_idx, axis_idx, invert=invert) + return segment_predictions( + predictions, + extractor, + min_distance_seconds=min_cycle_seconds, + min_prominence=min_prominence, + ) + + +def segment_gait_cycles_bilateral( + predictions: VideoPredictions, + *, + axis: AxisLetter = "y", + invert: bool = False, + min_cycle_seconds: float = 0.4, + min_prominence: float | None = None, +) -> dict[str, Segmentation]: + """Segment gait cycles for both heels. + + Runs :func:`segment_gait_cycles` twice — once with ``joint="lhee"`` + and once with ``joint="rhee"`` — and returns the two results under + the keys ``"left_heel_strikes"`` and ``"right_heel_strikes"``. The + returned dict is shape-compatible with + :class:`~neuropose.io.VideoPredictions.segmentations` so it can be + merged directly into a predictions object and persisted to + ``results.json`` via the usual save path. + + Parameters + ---------- + predictions, axis, invert, min_cycle_seconds, min_prominence + Forwarded to :func:`segment_gait_cycles`; see that function's + docstring for details. + + Returns + ------- + dict[str, Segmentation] + Two-keyed mapping with the left and right heel segmentations + under ``"left_heel_strikes"`` and ``"right_heel_strikes"``. + Either side may carry an empty segments list if its heel's + trace contained no detectable strikes. + """ + return { + "left_heel_strikes": segment_gait_cycles( + predictions, + joint="lhee", + axis=axis, + invert=invert, + min_cycle_seconds=min_cycle_seconds, + min_prominence=min_prominence, + ), + "right_heel_strikes": segment_gait_cycles( + predictions, + joint="rhee", + axis=axis, + invert=invert, + min_cycle_seconds=min_cycle_seconds, + min_prominence=min_prominence, + ), + } + + # --------------------------------------------------------------------------- # Slicing: one VideoPredictions per segment # --------------------------------------------------------------------------- diff --git a/tests/unit/test_analyzer_segment.py b/tests/unit/test_analyzer_segment.py index 76078f5..bc48fbc 100644 --- a/tests/unit/test_analyzer_segment.py +++ b/tests/unit/test_analyzer_segment.py @@ -33,6 +33,8 @@ from neuropose.analyzer.segment import ( joint_pair_distance, joint_speed, segment_by_peaks, + segment_gait_cycles, + segment_gait_cycles_bilateral, segment_predictions, slice_predictions, ) @@ -429,3 +431,139 @@ class TestSlicePredictions: sliced = slice_predictions(preds, segments)[0] # frame_000000 of the slice must equal frame_000050 of the source assert sliced["frame_000000"].poses3d == preds["frame_000050"].poses3d + + +# --------------------------------------------------------------------------- +# segment_gait_cycles / segment_gait_cycles_bilateral +# --------------------------------------------------------------------------- + + +def _heel_signal(num_cycles: int, frames_per_cycle: int) -> np.ndarray: + """A clean sinusoid standing in for a heel's vertical trace.""" + total = num_cycles * frames_per_cycle + t = np.linspace(0.0, num_cycles * 2.0 * math.pi, total, endpoint=False) + # Amplitude chosen so min_prominence tests have a non-trivial range. + return (np.sin(t) * 100.0 + 100.0).astype(float) + + +class TestSegmentGaitCycles: + def test_detects_expected_number_of_cycles(self) -> None: + # 5 cycles at 30 fps, 30 frames per cycle = 1.0 s per stride. + # Well inside the default min_cycle_seconds=0.4 gate. + signal = _heel_signal(num_cycles=5, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + seg = segment_gait_cycles(preds, joint="rhee", axis="y") + assert len(seg.segments) == 5 + + def test_config_records_inputs(self) -> None: + signal = _heel_signal(num_cycles=3, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + seg = segment_gait_cycles( + preds, + joint="rhee", + axis="y", + min_cycle_seconds=0.5, + min_prominence=10.0, + ) + assert isinstance(seg, Segmentation) + assert isinstance(seg.config.extractor, JointAxisExtractor) + assert seg.config.extractor.joint == JOINT_INDEX["rhee"] + assert seg.config.extractor.axis == 1 # "y" → 1 + assert seg.config.extractor.invert is False + assert seg.config.min_distance_seconds == 0.5 + assert seg.config.min_prominence == 10.0 + + def test_axis_selection(self) -> None: + # Put the signal on the X axis instead of Y. + signal = _heel_signal(num_cycles=4, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"], axis=0) + seg_y = segment_gait_cycles(preds, joint="rhee", axis="y") + seg_x = segment_gait_cycles(preds, joint="rhee", axis="x") + # Y is all-zeros (flat → no peaks), X carries the signal. + assert len(seg_y.segments) == 0 + assert len(seg_x.segments) == 4 + + def test_invert_flips_peaks_and_valleys(self) -> None: + # Invert the heel trace; with invert=True, the original valleys + # become the peaks detected as heel-strikes. + signal = _heel_signal(num_cycles=4, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + seg_plain = segment_gait_cycles(preds, joint="rhee", axis="y", invert=False) + seg_inverted = segment_gait_cycles(preds, joint="rhee", axis="y", invert=True) + # Both detect four distinct events (peaks in either the signal + # or its negation). Peaks differ by roughly half a cycle. + assert len(seg_plain.segments) == 4 + assert len(seg_inverted.segments) == 4 + plain_peaks = [s.peak for s in seg_plain.segments] + inverted_peaks = [s.peak for s in seg_inverted.segments] + assert plain_peaks != inverted_peaks + + def test_pathological_flat_signal_returns_empty(self) -> None: + # A subject whose heel never leaves the ground — no peaks. + signal = np.zeros(120) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + seg = segment_gait_cycles(preds, joint="rhee", axis="y") + assert seg.segments == [] + + def test_min_cycle_seconds_rejects_close_peaks(self) -> None: + # 10 cycles in 60 frames @ 30 fps = 0.2 s per cycle. + # min_cycle_seconds=0.4 should reject all but every-other peak. + signal = _heel_signal(num_cycles=10, frames_per_cycle=6) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + seg_permissive = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.0) + seg_strict = segment_gait_cycles(preds, joint="rhee", min_cycle_seconds=0.4) + # Strict mode drops peaks that are too close together. + assert len(seg_strict.segments) < len(seg_permissive.segments) + + def test_unknown_joint_raises_key_error(self) -> None: + signal = _heel_signal(num_cycles=3, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + with pytest.raises(KeyError, match="unknown joint"): + segment_gait_cycles(preds, joint="left_heel") # wrong name + + def test_invalid_axis_raises_value_error(self) -> None: + signal = _heel_signal(num_cycles=3, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + with pytest.raises(ValueError, match="axis must be one of"): + segment_gait_cycles(preds, joint="rhee", axis="w") # type: ignore[arg-type] + + +class TestSegmentGaitCyclesBilateral: + def test_returns_both_keys(self) -> None: + signal = _heel_signal(num_cycles=3, frames_per_cycle=30) + # Put the same signal on both heels so both sides find cycles. + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + # Rebuild predictions with lhee populated too. + frames = {} + for i, value in enumerate(signal): + poses = [[[0.0, 0.0, 0.0] for _ in range(NUM_JOINTS)]] + poses[0][JOINT_INDEX["lhee"]][1] = float(value) + poses[0][JOINT_INDEX["rhee"]][1] = float(value) + frames[f"frame_{i:06d}"] = { + "boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]], + "poses3d": poses, + "poses2d": [[[0.0, 0.0]] * NUM_JOINTS], + } + preds = VideoPredictions.model_validate( + { + "metadata": { + "frame_count": len(signal), + "fps": 30.0, + "width": 640, + "height": 480, + }, + "frames": frames, + } + ) + result = segment_gait_cycles_bilateral(preds) + assert set(result.keys()) == {"left_heel_strikes", "right_heel_strikes"} + assert len(result["left_heel_strikes"].segments) == 3 + assert len(result["right_heel_strikes"].segments) == 3 + + def test_pathological_one_side_returns_empty_for_that_side(self) -> None: + # Only the right heel carries a signal; left heel is flat. + signal = _heel_signal(num_cycles=3, frames_per_cycle=30) + preds = _make_predictions(signal, joint=JOINT_INDEX["rhee"]) + result = segment_gait_cycles_bilateral(preds) + assert len(result["right_heel_strikes"].segments) == 3 + assert result["left_heel_strikes"].segments == []