pin tensorflow, comprehensiveness
parent 3720911dac
commit 62d0b0789c
.dockerignore
@@ -0,0 +1,76 @@
# ---------------------------------------------------------------------------
# Files excluded from the Docker build context.
#
# The goal is to keep `docker build` fast (small context) and to avoid
# accidentally baking developer tooling, caches, test data, or
# IRB-sensitive artifacts into the image. Anything the runtime image
# actually needs (pyproject.toml, README.md, LICENSE, src/) is copied
# explicitly in the Dockerfile, so this ignore list can afford to be
# aggressive.
# ---------------------------------------------------------------------------

# Version control
.git
.gitignore
.gitattributes

# CI / developer configuration
.github
.pre-commit-config.yaml
.python-version

# Python caches and virtual environments
.venv
.venv-*
venv
env
ENV
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
*.egg-info

# Lint / type / test caches
.pytest_cache
.ruff_cache
.mypy_cache
.pyright
.coverage
.coverage.*
htmlcov
.tox

# Build outputs
build
dist
site

# Editor metadata
.vscode
.idea
*.swp
*.swo
*~

# Documentation sources (mkdocs builds these on the docs CI job, not
# inside the runtime image) and the mkdocs config itself.
docs
mkdocs.yml
notebooks

# Ancillary developer scripts — the model downloader is useful in a
# dev context but is redundant inside the image because the daemon
# pulls the model on first start.
scripts

# Research / planning artifacts — not needed at runtime.
RESEARCH.md
audit.md

# Tests — the image is a runtime artifact, not a test environment.
tests

# OS / tooling noise
.DS_Store
Thumbs.db

CHANGELOG.md
@@ -0,0 +1,336 @@
# Changelog

All notable changes to NeuroPose are recorded in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

This section covers the ground-up rewrite of NeuroPose. The entries
below describe the difference between the previous internal prototype
and the state of the repository at the first tagged release, and will
be split into per-release sections once tagging begins.

### Added

#### Package structure and tooling

- `src/neuropose/` package layout with `py.typed` marker, MIT `LICENSE`,
  policy-enforcing `.gitignore`, pinned Python 3.11 (`.python-version`),
  and `pyproject.toml` with full project metadata, classifiers, and
  URL pointers. The runtime TensorFlow dependency is pinned to
  `tensorflow>=2.16,<3.0` — see *Changed* below for the rationale.
- `[project.optional-dependencies].analysis` extra for fastdtw, scipy,
  scikit-learn, and sktime — install via `pip install 'neuropose[analysis]'`.
- `[project.optional-dependencies].metal` extra pulling
  `tensorflow-metal>=1.2,<2` under `sys_platform == 'darwin' and
  platform_machine == 'arm64'` environment markers (see the sketch at
  the end of this subsection). Opt-in only via
  `pip install 'neuropose[metal]'` or `uv sync --extra metal`; a silent
  no-op on every non-Apple-Silicon platform. The Metal path is **not**
  exercised in CI and is documented as experimental in
  `docs/getting-started.md` — users enabling it are expected to
  spot-check numerics against the CPU path before trusting results
  downstream.
- `[dependency-groups].dev` (PEP 735) with the full dev + docs + analyzer
  toolchain: pytest, pytest-cov, ruff, pyright, pre-commit,
  mkdocs-material, mkdocstrings, fastdtw, and scipy. `uv sync --group dev`
  gives contributors everything needed to run the whole suite.
- `AUTHORS.md`, `CITATION.cff` (with a MeTRAbs upstream `references:`
  entry), and an MIT-licensed `LICENSE` with an explicit MeTRAbs
  attribution paragraph.
- Pre-commit configuration (`.pre-commit-config.yaml`) running ruff,
  ruff-format, gitleaks (secret scanning), a 500 KB-limit
  large-files hook, end-of-file fixers, trailing-whitespace fixers,
  and YAML/TOML/JSON validators. Pyright is deliberately **not** in
  pre-commit — it runs in CI only, so pre-commit stays fast.
- Ruff configuration in `pyproject.toml` with a deliberately broad
  rule selection (pycodestyle, pyflakes, isort, bugbear, pyupgrade,
  simplify, ruff-specific, pep8-naming, comprehensions, pathlib,
  pytest-style, tidy-imports, numpy-specific, pydocstyle with numpy
  convention). Per-file ignores for tests and private modules.
- Pyright configuration in `standard` mode (not `strict` — TF/OpenCV
  stubs would otherwise drown the signal). Unknown-type reports are
  explicitly silenced until the TensorFlow version pin is settled.
- Pytest configuration with strict markers, an opt-in `slow` marker,
  and a `--runslow` CLI flag implemented in
  `tests/conftest.py::pytest_collection_modifyitems` so integration
  tests stay out of the default run.
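
For reference, the marker mechanics of the `metal` extra look roughly
like this in `pyproject.toml` (a sketch of the syntax, not a verbatim
copy of the repo's file):

```toml
[project.optional-dependencies]
metal = [
    # PEP 508 marker: resolves to nothing except on Apple Silicon macOS.
    "tensorflow-metal>=1.2,<2; sys_platform == 'darwin' and platform_machine == 'arm64'",
]
```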

#### CI / infrastructure

- GitHub Actions workflow `.github/workflows/ci.yml` running three
  parallel jobs — **lint** (ruff), **typecheck** (pyright), and
  **test** (pytest) — on every push and PR to `main`. Uses `uv` with
  a pinned version (`0.9.16`) and cache-enabled setup for fast reruns.
  Concurrency control cancels superseded runs on the same branch.
- GitHub Actions workflow `.github/workflows/docs.yml` that builds the
  mkdocs-material site on every relevant push and uploads the rendered
  site as a 14-day workflow artifact. GitHub Pages deployment is
  intentionally not wired up yet; the workflow header comment
  describes what to add when the repo flips public.

#### Runtime modules

- **`neuropose.config`** — `Settings` class built on
  `pydantic-settings`. Field-level validation for `device`,
  `poll_interval_seconds`, and `default_fov_degrees`; explicit
  `from_yaml()` classmethod (no implicit config-file discovery); XDG
  defaults for `data_dir` and `model_cache_dir` (`~/.local/share/neuropose/…`)
  so runtime data never lives inside the repository; `ensure_dirs()`
  as an explicit method so construction stays free of filesystem
  side effects.
- **`neuropose.io`** — validated prediction schemas:
  `FramePrediction` (frozen), `VideoMetadata` (frame count, fps,
  width, height), `VideoPredictions` (metadata envelope + frames
  mapping), `JobResults`, `JobStatus` enum, `JobStatusEntry` (with a
  structured `error` field), and `StatusFile`. Load and save helpers
  with an atomic tmp-file-then-rename pattern for every state file
  (sketched at the end of this subsection). `load_status` is
  deliberately crash-resilient: missing, corrupt, or non-mapping JSON
  returns an empty `StatusFile` rather than raising.
- **`neuropose.estimator`** — `Estimator` class that streams frames
  directly from OpenCV into the model, with no intermediate
  write-to-disk-then-read-back-as-PNG round trip. Returns a typed
  `ProcessVideoResult` containing a validated `VideoPredictions`
  object; does not touch the filesystem. Constructor accepts an
  injected model for testability; `load_model()` delegates to
  `neuropose._model.load_metrabs_model()`. Typed exception hierarchy:
  `EstimatorError`, `ModelNotLoadedError`, `VideoDecodeError`.
  Optional per-frame `progress` callback for long videos. Frame
  identifier convention is `frame_000000` (six-digit zero-pad, no
  extension — no file is implied).
- **`neuropose.visualize`** — `visualize_predictions()` for per-frame
  2D + 3D overlay rendering. `matplotlib.use("Agg")` is called inside
  the function rather than at module import, so `import neuropose.visualize`
  has no global side effect. Explicit deep-copy of `poses3d` before
  axis rotation to prevent the aliasing bug from the previous
  prototype. Supports `frame_indices` for rendering a subset of
  frames.
- **`neuropose.interfacer`** — `Interfacer` job-lifecycle daemon with
  dependency-injected `Settings` and `Estimator`. Single-instance
  enforcement via `fcntl.flock` on `data_dir/.neuropose.lock`.
  Crash-recovery `recover_stuck_jobs()` that marks any status entries
  left in `processing` state as failed with an "interrupted"
  message and quarantines their inputs. Graceful shutdown on SIGINT/
  SIGTERM with an interruptible sleep. Structured error fields on
  every failed job. `run_once()` factored out of the main loop so
  tests can drive single iterations without threading. Quarantine
  collision handling (`job_a.1`, `job_a.2`, …) and an empty-directory
  silent-skip heuristic (mid-copy directories are not marked failed).
- **`neuropose._model`** — MeTRAbs model loader. Downloads the pinned
  tarball from the upstream RWTH Aachen URL
  (`metrabs_eff2l_y4_384px_800k_28ds.tar.gz`), verifies its SHA-256
  checksum, atomically extracts to a staging directory and renames
  into place, and loads via `tf.saved_model.load`. Streams the
  download and hash computation in 1 MB chunks so memory use stays
  flat. One automatic retry on SHA-256 mismatch (in case the previous
  download was truncated). Post-load interface check for
  `detect_poses`, `per_skeleton_joint_names`, and
  `per_skeleton_joint_edges`.
- **`neuropose.analyzer`** — post-processing subpackage with lazy
  imports for the heavy dependencies:
  - `analyzer.dtw` — three DTW entry points (`dtw_all`,
    `dtw_per_joint`, `dtw_relation`) over fastdtw, with a frozen
    `DTWResult` dataclass. See `RESEARCH.md` for the ongoing
    methodology investigation.
  - `analyzer.features` — `predictions_to_numpy`,
    `normalize_pose_sequence` (uniform and axis-wise),
    `pad_sequences` (edge-padding), `extract_joint_angles` (NaN on
    degenerate vectors), `extract_feature_statistics`
    (`FeatureStatistics` frozen dataclass), and a `find_peaks` thin
    wrapper around `scipy.signal.find_peaks`.
- **`neuropose.cli`** — Typer-based command-line interface with
  three subcommands: `watch` (run the daemon), `process <video>`
  (run the estimator on a single video), and `analyze <results>`
  (stub). Global options `--config/-c`, `--verbose/-v`, `--quiet/-q`,
  `--version`. Structured error handling turns expected exceptions
  (`FileNotFoundError` on config, `ValidationError`, `AlreadyRunningError`,
  `NotImplementedError`, `KeyboardInterrupt`) into clear stderr
  messages and distinct exit codes (`EXIT_OK=0`, `EXIT_USAGE=2`,
  `EXIT_PENDING=3`, `EXIT_INTERRUPTED=130`). The CLI entry point is
  wired in `[project.scripts]` as `neuropose = "neuropose.cli:run"`.
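
The tmp-file-then-rename save pattern in `neuropose.io` has this
general shape (a minimal sketch; the helper name and signature here
are illustrative, not the real API):

```python
import json
import os
from pathlib import Path


def save_json_atomic(path: Path, payload: dict) -> None:
    """Illustrative atomic write: readers never see a half-written file."""
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(json.dumps(payload, indent=2))
    # os.replace is atomic on POSIX when tmp and path are on the same
    # filesystem, so a crash leaves either the old file or the new one.
    os.replace(tmp, path)
```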

#### Documentation

- **mkdocs-material documentation site** under `docs/` with the full
  theme configuration (light/dark toggle, tabs navigation, search),
  `mkdocstrings` Python handler set to numpy docstring style with
  source links, and a `pymdownx` extension set for admonitions,
  tabbed content, collapsible details, and syntax-highlighted code
  blocks. Nav: Home → Getting Started → Architecture → API Reference
  (auto-generated from module docstrings) → Development → Deployment.
- Prose documentation pages: `docs/index.md` (public landing page),
  `docs/getting-started.md` (install, CLI, output schema, Python API,
  visualization, troubleshooting), `docs/architecture.md` (three-stage
  pipeline, data flow, runtime directory layout, design principles),
  `docs/development.md` (contributor setup, tests, lint/type,
  commit hygiene, release process stub), and `docs/deployment.md`
  (systemd user unit, Docker pointer, GPU notes, backup guidance).
- API reference stubs `docs/api/{config,estimator,interfacer,io,visualize}.md`
  — each is a two-line file containing a `:::` mkdocstrings directive,
  so the API documentation is generated from the source docstrings
  at build time and cannot drift out of sync.
- `RESEARCH.md` at the repo root: a living R&D log for DTW
  methodology alternatives and MeTRAbs self-hosting / fine-tuning
  plans. Not user-facing documentation; not linked from the mkdocs
  nav.

#### Tests

- `tests/unit/` covering configuration (defaults, validation, YAML
  loading, env overrides, `ensure_dirs`), IO schema and helpers
  (roundtrip, atomic save, frozen-model guarantees, corruption
  tolerance), the estimator (construction, model-guard, process path
  with a fake MeTRAbs model, error paths), the visualize module
  (smoke tests + an anti-regression check for the audit §6 aliasing
  bug), the interfacer (construction, discovery, process-job happy
  and failure paths, stuck-job recovery, lock, run_once,
  interruptible sleep), the CLI (top-level options, config handling,
  each subcommand's error path), the analyzer DTW helpers, and the
  analyzer features helpers.
- `tests/conftest.py` with an autouse `_isolate_environment` fixture
  that redirects `$HOME` and `$XDG_DATA_HOME` to a per-test temp
  directory so no test can accidentally write to the developer's real
  machine, and clears any `NEUROPOSE_*` env vars. Adds a
  `synthetic_video` fixture (a cv2-generated 5-frame MJPG AVI sized
  for most unit tests) and a `fake_metrabs_model` fixture.
- `tests/integration/test_estimator_smoke.py` — end-to-end model
  loader + estimator smoke test against the real MeTRAbs tarball,
  marked `@pytest.mark.slow`, skipped by default, opt-in via
  `--runslow` (the gate is sketched below). Uses a session-scoped
  model cache so the download happens at most once per run.
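
The `--runslow` gate follows the stock pattern from the pytest
documentation. A sketch of what `tests/conftest.py` plausibly contains
(reconstructed from the description above, not copied from the repo):

```python
import pytest


def pytest_addoption(parser):
    parser.addoption("--runslow", action="store_true", default=False,
                     help="also run tests marked @pytest.mark.slow")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        return  # opt-in given: leave slow tests collected as-is
    skip_slow = pytest.mark.skip(reason="needs --runslow")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
```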

#### Operations

- `Dockerfile` — CPU image based on `python:3.11-slim-bookworm`.
  Installs the package with the `analysis` extra, runs as non-root
  user `neuropose` (UID 1000), exposes `/data` as a volume, sets
  `NEUROPOSE_DATA_DIR` and `NEUROPOSE_MODEL_CACHE_DIR` to point at
  the mounted volume, and uses `ENTRYPOINT ["neuropose"]` with
  `CMD ["watch"]` so the default is the daemon and overrides are
  ergonomic.
- `.dockerignore` that aggressively excludes developer tooling,
  caches, tests, documentation sources, research notes, and
  ancillary scripts from the build context.
- `scripts/download_model.py` — standalone pre-warm script that
  invokes `load_metrabs_model()` with an optional `--cache-dir`
  override. Useful for seeding a deployment's cache before cutting
  off network access.
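
The streamed download-and-verify loop described under
`neuropose._model` has roughly this shape (a minimal sketch under
assumed names; the real loader adds the retry and the atomic
extract-and-rename):

```python
import hashlib
import urllib.request


def download_with_sha256(url: str, dest: str, chunk_size: int = 1024 * 1024) -> str:
    """Stream the body in 1 MB chunks so peak memory stays flat; hash as we go."""
    digest = hashlib.sha256()
    with urllib.request.urlopen(url) as resp, open(dest, "wb") as out:
        while chunk := resp.read(chunk_size):
            out.write(chunk)
            digest.update(chunk)
    return digest.hexdigest()  # caller compares against the pinned checksum
```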

### Changed

- **Relicensed from AGPL-3.0 (used in the prior internal prototype)
  to MIT.** The prior license was copied from precedent rather than
  chosen deliberately; the MIT relicense better matches both the
  project's "research software others can build on" intent and the
  upstream MeTRAbs license.
- Reorganised from the prior `backend/` + runtime-data layout into
  a `src/neuropose/` Python package. Runtime data now lives outside
  the repository by default (under `$XDG_DATA_HOME/neuropose/`) so
  subject-identifying inputs cannot accidentally end up in a
  `git add`.
- Frame identifier convention changed from `frame_0000.png` (old,
  misleading — no PNG file exists) to `frame_000000` (six-digit
  zero-pad, no extension, pure identifier).
- Estimator API: `process_video()` now returns a typed
  `ProcessVideoResult` containing a validated `VideoPredictions`
  object, instead of a stringly-typed dict with `results_path` and
  `frame_count`. The estimator no longer owns filesystem
  destinations — the caller decides where to save.
- `VideoPredictions` schema now carries a `VideoMetadata` envelope
  (frame count, fps, width, height) alongside the per-frame
  predictions. Downstream analysis can convert frame indices to
  real time without needing access to the original video.
- Interfacer uses `datetime.now(UTC)` instead of the deprecated
  `datetime.utcnow()`, addresses the "no-videos"-vs-exception-path
  inconsistency (both now quarantine), and persists a structured
  `error` string on every failure for grep-friendly diagnostics.
- **TensorFlow pin tightened to `tensorflow>=2.16,<3.0`.** The 2.16
  floor is the first release with native `darwin/arm64` wheels under
  the `tensorflow` package name on PyPI, so a single dependency line
  works across Linux x86_64, Linux arm64, and Apple Silicon macOS
  without platform markers or a separate `tensorflow-macos` package.
  Empirical verification: the pinned MeTRAbs SavedModel
  (`metrabs_eff2l_y4_384px_800k_28ds`, serialized with TF 2.10)
  loads and runs `detect_poses` end-to-end on TF 2.21 + Keras 3 with
  no errors, and exposes only stock TensorFlow ops (zero MeTRAbs
  custom kernels). Full test matrix and op inventory in
  `RESEARCH.md`.
- Operating-system classifiers in `pyproject.toml` extended from
  Linux-only to `POSIX` + `POSIX :: Linux` + `MacOS`, reflecting the
  Apple Silicon support that the TF 2.16 floor makes real.

### Removed

- The previous `backend/analyzer.py` and `backend/validator.py`
  stubs, which were non-functional and had never been run
  successfully. `analyzer.py` is reintroduced as a pure-function
  subpackage (`neuropose.analyzer`) rewritten from the prior
  code's design intent. `validator.py` is reintroduced as a real
  pytest suite (`tests/unit/` and `tests/integration/`).
- The previous `reconstruct_from_frames` helper on the `Estimator`
  — dead code, broken (dereferenced `self.OUTPUT_PATH`, which did
  not exist), hardcoded 10 fps, never called. ffmpeg is a better
  tool for this and can be invoked directly.
- The previous `__main__` placeholder (`print("in main"); sys.exit()`)
  in `estimator.py`. The real CLI now lives in `neuropose.cli`.
- Every file under `docs/` in the previous prototype. All of the
  pydoc-generated HTML, Org-mode sources, and handwritten markdown
  described an older version of the API with methods
  (`bind_and_block`, `construct_paths`, `toggle_visualization`,
  `propagate_fatal_error`, etc.) that no longer exist. The docs are
  now auto-generated from source docstrings via mkdocstrings so
  drift is mechanically impossible.
- The previous Dockerfile, which referenced a non-existent
  `backend/requirements.txt`, attempted to `COPY ./model /app/model`
  (no such directory), and set `CMD ["uvicorn", "main:app"]` for a
  FastAPI app that never existed.
- The previous `install/install.sh`, `install/#install.sh#` (an
  Emacs autosave file), `install/install.sh~` (an Emacs backup file),
  and `install/environment.yml`. The conda + `git+https` install
  story is replaced by `uv` + a single `pyproject.toml`.
- The previous `bit.ly/metrabs_1` URL shortener for the model
  download, replaced by a pinned canonical URL on the upstream
  RWTH Aachen "omnomnom" host, with SHA-256 verification on
  download. See `RESEARCH.md` for the plan to mirror to
  self-hosted storage.

### Security

- Large-files pre-commit hook (`check-added-large-files` with a
  500 KB limit) blocks accidental commits of subject data or model
  weights.
- Gitleaks pre-commit hook scans every staged change for secret
  material.
- Dockerfile runs as a non-root user (UID 1000, `neuropose`) by
  default.
- Tarfile extraction uses the `filter="data"` option (sketched
  below) to block path traversal and other tar-bomb attacks during
  MeTRAbs model extraction.
- SHA-256 pinning of the MeTRAbs model artifact. A change to the
  upstream tarball contents fails the checksum verification and
  requires a human-reviewed diff before the new artifact is
  trusted.
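
A sketch of the extraction-hardening mechanism (the `filter` argument
is standard-library `tarfile`, available on the pinned Python 3.11
series from 3.11.4 onward; the paths here are illustrative):

```python
import tarfile

with tarfile.open("metrabs.tar.gz", "r:gz") as tar:
    # filter="data" rejects absolute paths, ".." traversal, links that
    # escape the destination, and device/special files.
    tar.extractall("/staging/dir", filter="data")
```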

### Known limitations

- Apple Silicon support is established by construction (TF 2.16+
  publishes native `darwin/arm64` wheels and the MeTRAbs SavedModel
  uses only stock ops verified portable on TF 2.21) but has not yet
  been exercised on real Apple Silicon hardware. A `macos-14` CI
  matrix entry covering the unit tests is the cheapest way to catch
  any regression and is planned as a follow-up.
- Classification wrappers on top of sktime are deliberately **not**
  included in `neuropose.analyzer` for this release. See `RESEARCH.md`
  for the reasoning and the plan.
- GPU support in Docker is not yet shipped (`Dockerfile.gpu` is
  planned). The existing `Dockerfile` runs CPU-only.
- `neuropose analyze` is a CLI stub that exits with a pending
  message. The analyzer subpackage is usable from Python directly;
  the CLI wrapper will follow once the analysis pipeline has a
  concrete shape worth wrapping.
- The data-handling policy referenced from `docs/deployment.md` and
  `docs/index.md` (`docs/data-policy.md`) is being authored
  separately and is not part of this changelog entry.

[Unreleased]: https://git.levineuwirth.org/neuwirth/neuropose/compare/initial...HEAD

Dockerfile
@@ -0,0 +1,89 @@
# syntax=docker/dockerfile:1.7
# ---------------------------------------------------------------------------
# NeuroPose — CPU Docker image.
#
# Builds a minimal CPU-only image running the NeuroPose job-processing
# daemon (`neuropose watch`). The GPU variant is planned but not
# shipped in v0.1 — it will live in a separate `Dockerfile.gpu` based
# on a CUDA-enabled TensorFlow base image.
#
# Build:
#   docker build -t neuropose:latest .
#
# Run (daemon mode, the default):
#   docker run -d \
#     -v /srv/neuropose/jobs:/data/jobs \
#     -v /srv/neuropose/models:/data/models \
#     --name neuropose \
#     neuropose:latest
#
# Run (single-video mode, overriding the default command):
#   docker run --rm \
#     -v /srv/neuropose/jobs:/data/jobs \
#     -v /srv/neuropose/models:/data/models \
#     -v $PWD/video.mp4:/input.mp4:ro \
#     neuropose:latest process /input.mp4 --output /data/jobs/result.json
#
# Notes:
# - The Python base image (`python:3.11-slim-bookworm`) is deliberately
#   not pinned to a specific patch version in this commit. A sha256
#   digest pin is an easy follow-up once we are ready to commit to a
#   reproducible build chain.
# - TensorFlow will be downloaded by pip during the build (~500 MB). The
#   final image is correspondingly large; optimisation is a future
#   concern.
# - The MeTRAbs model itself is NOT baked into the image. It downloads
#   on first daemon startup into /data/models, which must be mounted
#   from the host to avoid repeating the download on every container
#   start.
# ---------------------------------------------------------------------------

FROM python:3.11-slim-bookworm AS runtime

# System dependencies:
# - ca-certificates: HTTPS for the MeTRAbs model download.
# - ffmpeg: video I/O backend OpenCV calls into.
# - libgl1, libglib2.0-0: transitive requirements of
#   opencv-python-headless on slim images.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        ca-certificates \
        ffmpeg \
        libgl1 \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy package metadata first so that source-only edits do not bust
# the pip cache layer on rebuilds.
COPY pyproject.toml README.md LICENSE ./
COPY src/ ./src/

# Install the package system-wide. The `[analysis]` extra pulls in
# fastdtw / scipy / scikit-learn / sktime so `neuropose analyze`
# works out of the box inside the container.
RUN pip install --no-cache-dir ".[analysis]"

# Non-root runtime user. Using UID 1000 to line up with most host-side
# /data directories; override with `docker run --user` if needed.
RUN useradd --create-home --uid 1000 neuropose \
    && mkdir -p /data/jobs /data/models \
    && chown -R neuropose:neuropose /data

# Point NeuroPose's Settings at the mounted /data volume. These are
# read by pydantic-settings via the NEUROPOSE_ prefix and override the
# XDG defaults that would otherwise resolve inside the container's
# ephemeral filesystem.
ENV NEUROPOSE_DATA_DIR=/data/jobs \
    NEUROPOSE_MODEL_CACHE_DIR=/data/models \
    PYTHONUNBUFFERED=1

USER neuropose
VOLUME ["/data"]

# The entrypoint is the `neuropose` CLI; the default command is
# `watch` (daemon mode). Override at `docker run` time to invoke
# `process` or `analyze` instead.
ENTRYPOINT ["neuropose"]
CMD ["watch"]

README.md
@@ -1,71 +1,119 @@
# NeuroPose (rewrite)

-Ground-up rewrite of the prior NeuroPose internal prototype.
+Ground-up rewrite of the prior NeuroPose internal prototype. The repository
+is private while the IRB data-handling policy is being authored; this README
+is aimed at contributors working on the rewrite, not external users.

-## Target layout
+## Layout

```text
neuropose/
-├── .github/workflows/        # CI on the GitHub mirror (ruff + pyright + pytest)
+├── .github/workflows/        # CI: ruff + pyright + pytest (ci.yml), mkdocs (docs.yml)
├── src/neuropose/
│   ├── __init__.py           # version only
│   ├── config.py             # pydantic-settings Settings class
│   ├── io.py                 # prediction schema, load/save helpers
-│   ├── estimator.py          # per-video MeTRAbs worker (ported, cleaned)
-│   ├── interfacer.py         # filesystem-polling daemon (ported, cleaned)
-│   ├── cli.py                # typer app (`neuropose run|watch|analyze`)
-│   ├── _model.py             # MeTRAbs model load + local cache
-│   └── analyzer/             # rewrite of the prior analyzer.py
+│   ├── estimator.py          # per-video MeTRAbs worker
+│   ├── interfacer.py         # filesystem-polling daemon
+│   ├── visualize.py          # per-frame 2D + 3D overlay rendering
+│   ├── cli.py                # typer app (watch | process | analyze)
+│   ├── _model.py             # MeTRAbs download + SHA-256 verify + load
+│   └── analyzer/             # post-processing subpackage
│       ├── dtw.py            # FastDTW helpers
-│       ├── features.py       # normalization, padding, joint angles
-│       └── classification.py # sktime classifier wrappers
+│       └── features.py       # normalization, padding, joint angles, stats
├── tests/
+│   ├── conftest.py           # env isolation, synthetic video, fake model
│   ├── unit/                 # fast, no model download
-│   ├── integration/          # marked @slow, downloads the model
-│   └── fixtures/             # synthetic video + reference predictions
-├── docs/                     # mkdocs-material site
-├── notebooks/                # getting_started.ipynb, tested in CI via nbval
-├── config/default.yaml       # example runtime config
-├── scripts/download_model.py
-├── pyproject.toml            # hatchling build, typer + pydantic + TF stack
-├── Dockerfile                # CPU, pinned deps
+│   └── integration/          # marked @slow, downloads the real MeTRAbs model
+├── docs/                     # mkdocs-material site (mkdocs.yml at repo root)
+├── scripts/download_model.py # pre-warm the model cache
+├── pyproject.toml            # hatchling build, dev group (PEP 735)
+├── Dockerfile                # CPU image, non-root, /data volume
+├── CHANGELOG.md              # Keep a Changelog format
+├── RESEARCH.md               # DTW methodology + MeTRAbs self-hosting R&D log
├── AUTHORS.md
├── CITATION.cff
└── LICENSE                   # MIT
```

## Architecture

Three stages, one module each:

-- **`estimator`** — per-video worker. Extracts frames from an input video,
-  runs MeTRAbs on each frame, and writes per-frame predictions (`boxes`,
-  `poses3d`, `poses2d`) to JSON. No daemon logic; usable directly from Python.
+- **`estimator`** — per-video worker. Streams frames from an input video via
+  OpenCV, runs MeTRAbs on each frame, and returns a validated
+  `VideoPredictions` (per-frame `boxes`, `poses3d`, `poses2d` plus a
+  `VideoMetadata` envelope with frame count, fps, and resolution). Pure
+  library — no filesystem semantics.
- **`interfacer`** — filesystem-polling daemon. Watches the configured input
  directory for new job subdirectories, dispatches each to an `Estimator`,
-  and persists job state (`status.json`) across crashes and restarts. Owns
-  the input → output → failed directory lifecycle.
+  and persists job state (`status.json`) across crashes and restarts. Single
+  instance enforced via `fcntl.flock`. Owns the input → output → failed
+  directory lifecycle.
- **`analyzer`** — post-processing subpackage. FastDTW-based motion
-  comparison, joint-angle feature extraction, and sktime classifier wrappers.
-  Operates on the JSON output of the estimator.
+  comparison (`dtw_all`, `dtw_per_joint`, `dtw_relation`) and joint-angle /
+  feature-statistics helpers. Pure functions operating on `VideoPredictions`.
+  Heavy dependencies (fastdtw, scipy) are lazy-imported so
+  `import neuropose.analyzer` works without the `analysis` extra.

Configuration is centralized in `src/neuropose/config.py` (a
pydantic-settings `Settings` class). The runtime data directory defaults to
`$XDG_DATA_HOME/neuropose/jobs/` and never lives inside the repository.

-## Commit plan
+## Development setup

-| # | Scope | State |
-|---|---|---|
-| 1 | Scaffolding: package layout, MIT license, authors, citation, policy-aware `.gitignore` | in review |
-| 2 | Dev tooling: pre-commit, ruff, pyright, gitleaks | planned |
-| 3 | CI workflow on the GitHub mirror | planned |
-| 4 | `config.py`, `io.py`, unit tests | planned |
-| 5 | Port `estimator.py` with typing and audit §6 fixes | planned |
-| 6 | Port `interfacer.py` with audit §7 fixes | planned |
-| 7 | Typer CLI (`neuropose run|watch|analyze`) | planned |
-| 8 | mkdocs-material docs site | planned |
-| 9 | Data-handling policy (gates going public) | blocked on IRB prose |
-| 10 | `analyzer/` subpackage rewrite | planned |
-| 11 | MeTRAbs model loader + integration smoke test | blocked on upstream URL + TF pin |
-| 12 | Dockerfile | blocked on 11 |
-| 13 | Comprehensive `CHANGELOG.md` retroactive entry | blocked on 12 |
+Requires Python 3.11 and [`uv`](https://github.com/astral-sh/uv).
+
+```bash
+git clone https://git.levineuwirth.org/neuwirth/neuropose.git
+cd neuropose
+uv sync --group dev
+```
+
+`uv sync --group dev` creates `.venv/` automatically and installs the
+runtime stack (pydantic, typer, OpenCV, TensorFlow, matplotlib) plus the
+full dev toolchain (pytest, ruff, pyright, pre-commit, mkdocs-material,
+fastdtw, scipy). The first sync downloads ~600 MB of TensorFlow; subsequent
+runs hit the uv cache.
+
+Install the pre-commit hooks:
+
+```bash
+uv run pre-commit install
+```
+
+### Running tests
+
+```bash
+uv run pytest               # unit tests only (default)
+uv run pytest --runslow     # unit + integration; downloads ~2 GB MeTRAbs model
+uv run pytest -m "not slow" # explicitly exclude slow tests
+```
+
+Integration tests live under `tests/integration/` and are gated behind
+`@pytest.mark.slow` plus a custom `--runslow` flag implemented in
+`tests/conftest.py`. Without the flag, slow tests are skipped at collection
+time. The first `--runslow` run downloads the pinned MeTRAbs tarball
+(~2 GB) into a session-scoped temp cache; subsequent tests in the same run
+reuse it.
+
+### Lint and type-check
+
+```bash
+uv run ruff check .
+uv run ruff format .
+uv run pyright
+```
+
+CI runs lint, typecheck, and test as three parallel jobs on every push and
+PR to `main` — see `.github/workflows/ci.yml`.
+
+### Docs
+
+```bash
+uv run mkdocs serve          # live-reload preview at http://127.0.0.1:8000
+uv run mkdocs build --strict # same check CI runs
+```
+
+The API reference pages under `docs/api/` are auto-generated from source
+docstrings via mkdocstrings, so they cannot drift out of sync.

RESEARCH.md
@@ -0,0 +1,762 @@
# NeuroPose Research and Ideation Notes

A living R&D log for open design questions, speculative directions, and
planned experiments that are larger in scope than individual commits.
This is **not** user-facing documentation — items in here are
*candidates* for future work, and inclusion does not imply commitment.

## How to use this document

- Add a section when you start thinking about a new area of investigation.
- Each section should end with an **Open questions** or **Next steps**
  block so it's obvious to a future you (or a new contributor) what the
  active threads are.
- When something in here is decided and implemented, move it to the
  relevant place in `docs/` or in the code itself and leave a short
  pointer behind ("*See `docs/architecture.md` for the resolved design.*").
- Consider the audience: yourself, Dr. Shu, David, Praneeth, and future
  contributors. Assume they know pose estimation at a grad-student level
  but may not have followed every prior conversation.

## Contents

- [DTW methodology](#dtw-methodology)
- [TensorFlow version compatibility](#tensorflow-version-compatibility)
- [MeTRAbs hosting and extensibility](#metrabs-hosting-and-extensibility)

---

## DTW methodology

### Current implementation (v0.1, commit 10)

`neuropose.analyzer.dtw` ships three entry points, all built on top of
[`fastdtw`](https://github.com/slaypni/fastdtw) with
`scipy.spatial.distance.euclidean` as the point-distance function:

- **`dtw_all(a, b)`** — single DTW on flattened `(frames, joints × 3)`
  vectors. One scalar distance for the whole sequence.
- **`dtw_per_joint(a, b)`** — one DTW call per joint, returning a list
  of per-joint distances and warping paths. Preserves per-joint
  temporal alignment at J× the cost.
- **`dtw_relation(a, b, joint_i, joint_j)`** — DTW on the per-frame
  displacement vector between two specific joints. The intent here is
  to capture "how does the relationship between these two joints change
  over time?", which is translation-invariant and so immune to raw
  camera-frame changes.
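
For orientation, the flattened variant has this general shape over
`fastdtw` (a minimal sketch, not the shipped implementation — the real
helpers return a frozen `DTWResult`):

```python
import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean


def dtw_all_sketch(a: np.ndarray, b: np.ndarray) -> float:
    """a, b: (frames, joints, 3) sequences; one DTW over flattened frames."""
    distance, _path = fastdtw(
        a.reshape(len(a), -1),  # (frames, joints * 3)
        b.reshape(len(b), -1),
        dist=euclidean,
    )
    return float(distance)
```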

These three correspond directly to the three helpers that existed
(broken) in the previous prototype's `analyzer.py`, ported forward with
bug fixes, types, and tests. **The port was mechanical — not a
methodological choice.** We inherited the FastDTW + Euclidean defaults
without validating them against the clinical research use cases, and
that validation is overdue.

### Known limitations of the v0.1 approach

#### FastDTW is an approximation, not exact DTW

[FastDTW](https://cs.fit.edu/~pkc/papers/tdm04.pdf) is a multi-scale
approximation that runs in linear time by recursively refining a coarse
alignment. For the radius-based implementation in
`slaypni/fastdtw`, the distance is not guaranteed to match exact DTW,
and in pathological cases the error can be significant. For a research
codebase where the DTW distance is going to show up in a figure, that
matters.

**Candidate exact alternatives** (all pip-installable):

- [`dtaidistance`](https://github.com/wannesm/dtaidistance) — C-based,
  supports both exact DTW and a `fast=True` approximation; also
  supports shape-DTW and various constraint bands. Actively maintained,
  and the underlying algorithms match the textbook.
- [`tslearn`](https://tslearn.readthedocs.io/) — ML-flavored toolkit
  with exact DTW, soft-DTW (differentiable), Sakoe-Chiba banding, and
  kernel-DTW. Good fit if we ever want to feed DTW distances into an
  sklearn/PyTorch pipeline.
- [`cdtw`](https://github.com/statefb/dtw-python) / `dtw-python` —
  Python port of the R `dtw` package; exhaustive options for windowing,
  step patterns, and open-ended alignment. Less friendly API but the
  most rigorously documented.

#### Euclidean is a choice, not a default

Treating `(x, y, z)` joint positions as points in R³ and taking
Euclidean distances implicitly assumes the three axes are commensurable
in the same units, which is fine for MeTRAbs (mm) but throws away prior
knowledge about human motion. Alternatives worth considering:

- **Angular distance on joint angles.** Compute joint angles per frame
  (`extract_joint_angles` already exists) and run DTW on the angle
  sequences rather than raw coordinates. Translation- and
  scale-invariant by construction; well matched to clinical metrics
  like knee flexion angle.
- **Geodesic distance on SO(3)** for local joint rotations. Requires a
  skeleton-rooted rotation parameterization; more work to set up but
  the right metric for "how different are these two poses?" in a
  biomechanics sense.
- **Mahalanobis distance** against a learned pose prior. This is the
  "machine learning" answer — fit a covariance to a reference corpus
  (normal gait from a healthy cohort), then measure distances in the
  whitened space. Requires enough data to fit the prior without
  overfitting, but makes "is this gait abnormal?" a calibrated question.

#### Preprocessing: what invariance do we want?

The v0.1 implementation is not invariant to anything. Two videos of the
same subject with a different camera position will give a different
DTW distance, which is almost certainly not what a clinician wants.
Candidate preprocessing steps:

- **Translation invariance**: subtract the root joint (pelvis or torso
  centroid) from every joint per frame, so all poses are expressed in a
  body-relative coordinate frame. Cheap and almost always desired.
- **Scale invariance**: divide by a reference length (e.g., torso
  length, or total skeleton span) so tall and short subjects produce
  comparable distances. Important for comparing across subjects.
- **Rotation invariance**: align to a canonical frame (e.g., hip-to-hip
  vector = x-axis, hip-to-shoulder = z-axis) per frame. Required if the
  subject's orientation relative to the camera varies between trials.
- **Procrustes alignment per frame**: fit the best rigid transform
  (rotation + translation) between pose A's frame and pose B's frame
  before computing distance. The closed-form
  [Kabsch algorithm](https://en.wikipedia.org/wiki/Kabsch_algorithm) is
  fast and exact (a sketch follows this list). This is likely the
  *right* thing for most comparison use cases but has never been
  wired up.
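
A sketch of the per-frame Kabsch step in plain numpy (a hypothetical
helper — nothing like this is wired up yet):

```python
import numpy as np


def kabsch_align(reference: np.ndarray, moving: np.ndarray) -> np.ndarray:
    """Rigidly align `moving` onto `reference`; both are (joints, 3)."""
    ref_c = reference - reference.mean(axis=0)
    mov_c = moving - moving.mean(axis=0)
    # Optimal rotation minimizing ||mov_c @ R - ref_c|| via SVD of the
    # cross-covariance matrix.
    u, _s, vt = np.linalg.svd(mov_c.T @ ref_c)
    d = np.sign(np.linalg.det(u @ vt))  # guard against reflections
    rot = u @ np.diag([1.0, 1.0, d]) @ vt
    return mov_c @ rot + reference.mean(axis=0)
```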

The `dtw_relation` helper is translation- and (for unit-vector
displacements) scale-invariant by construction, which is why it ends up
being the most useful of the three existing entry points in practice.

#### Representation: coordinates, angles, velocities, or dual?

The v0.1 DTW operates on **3D joint coordinates** (translation-dependent)
or **joint-pair displacements** (`dtw_relation`). Other representations
worth comparing:

- **Joint angles.** Using `extract_joint_angles` output as the DTW
  input gives a rotation-and-translation-invariant comparison that's
  also directly interpretable in clinical terms.
- **Joint velocities.** Temporal derivatives of position. Emphasizes
  *how the pose changes* rather than *what it is* — good for
  discriminating smooth from jerky motion in gait.
- **Dual (position + angle).** Concatenate normalized position and
  angle features into a single per-frame vector. More expressive but
  requires tuning the relative weights.
- **Learned embeddings.** Feed each frame through a pretrained
  pose-representation network (there are a few) and DTW on the
  embedding space. Expensive and opaque but may capture
  higher-order structure.

#### Multi-scale approaches

FastDTW is already multi-scale internally. Other ideas:

- **Coarse-to-fine DTW.** Downsample aggressively, run exact DTW on
  the coarse version to get a sub-quadratic alignment, then refine
  locally. This is essentially what FastDTW does, but with an explicit
  signal-processing hat on.
- **Wavelet-decomposed DTW.** Decompose each joint's trajectory into
  wavelet coefficients and run DTW on the low-frequency coefficients.
  Unclear whether this actually helps; interesting because it separates
  posture (low-frequency) from tremor / micro-motion (high-frequency).

#### Clinical gait: cycle-aware DTW

Gait is approximately periodic, and "the 4th heel-strike of trial A"
is the clinically meaningful comparison point for "the 4th heel-strike
of trial B", not "frame 120 of A vs frame 120 of B". A natural two-stage
approach:

1. **Cycle detection.** Find heel-strikes (or other gait events) via
   peak detection on a joint's vertical coordinate (sketched below),
   and segment each trial into individual cycles.
2. **Per-cycle DTW.** Time-warp within each cycle independently to
   normalize cycle duration. The distance between trials is then the
   sum / mean of per-cycle distances.

This is standard in the biomechanics literature
([Sadeghi et al. 2000](https://doi.org/10.1016/S0966-6362(00)00074-3)
and descendants) and is almost certainly a better fit for clinical
comparison than the naive full-trial DTW we ship today.
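
Step 1 is a few lines with `scipy.signal.find_peaks`; a sketch, under
the assumption that heel-strikes coincide with minima of the heel's
vertical coordinate:

```python
import numpy as np
from scipy.signal import find_peaks


def segment_gait_cycles(heel_y: np.ndarray, fps: float) -> list[slice]:
    """heel_y: (frames,) vertical heel coordinate. Returns one slice per cycle."""
    # Minima of heel height ≈ heel-strikes; require at least half a
    # second between events to suppress jitter peaks.
    strikes, _ = find_peaks(-heel_y, distance=max(1, int(fps / 2)))
    return [slice(s, e) for s, e in zip(strikes[:-1], strikes[1:])]
```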

#### Soft-DTW for learning applications

[Soft-DTW](https://arxiv.org/abs/1703.01541) is a differentiable
relaxation of DTW, which means gradients can flow through it. This
matters if we ever want to train a network to *learn* a distance
metric or an embedding under a DTW objective — for example, a pose
encoder whose output space is calibrated to gait similarity. Worth
keeping on the radar even if we're not training anything today.
`tslearn` implements it.

### Evaluation strategy

Validating a DTW implementation is harder than validating most things.
Some ideas for how to know we got it right:

- **Synthetic perturbations.** Take a reference sequence, apply
  known perturbations (time stretch, added noise, spatial offset), and
  verify that distance scales monotonically with perturbation magnitude
  and that invariance properties are honored (a sketch of one such
  check follows this list).
- **Reference implementation parity.** For a small set of hand-picked
  pairs, compute DTW distance using `dtaidistance` exact DTW and
  our implementation, and verify the approximation error is below a
  documented threshold.
- **Inter-rater clinical benchmark.** When we have labeled clinical
  data, measure how well DTW distance correlates with clinician
  ratings of gait similarity. This is the real test but is gated on
  having data we can use.
- **Pathology discrimination.** Can DTW distance separate healthy
  from impaired gait in a held-out set? This is the usefulness test.
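
The added-noise arm of the synthetic check could look like this (a
sketch using `fastdtw` directly, since the exact return shape of the
shipped helpers is not spelled out here):

```python
import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean


def test_distance_grows_with_noise():
    # Monotonicity: more added noise should mean a larger DTW distance
    # from the clean reference sequence.
    rng = np.random.default_rng(0)
    ref = rng.normal(size=(100, 43 * 3))  # fake flattened pose sequence
    dists = [
        fastdtw(ref, ref + rng.normal(scale=s, size=ref.shape), dist=euclidean)[0]
        for s in (0.01, 0.1, 1.0)
    ]
    assert dists == sorted(dists)
```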

### Open questions

1. Is FastDTW good enough, or should we move to `dtaidistance` exact
   DTW as the default? (First concrete experiment: pick 20 pairs from
   whatever reference data we can source, compute distance both ways,
   see if the approximation error is acceptable.)
2. What's the right representation for clinical gait DTW — raw
   coordinates, joint angles, or per-pair displacements?
3. Should we implement Procrustes alignment as a preprocessing step
   before any DTW call? (If yes, it belongs in `neuropose.analyzer.features`.)
4. Should the clinical pipeline use cycle-segmented DTW instead of
   full-trial DTW? This is a methodological choice with real
   downstream implications.
5. Is soft-DTW useful to us, or is it a solution looking for a
   problem we don't have?
6. What reference corpus do we use to develop and validate any of this?

### Reading list

- Sakoe, H. & Chiba, S. (1978). "Dynamic programming algorithm
  optimization for spoken word recognition." The original DTW paper.
- Salvador, S. & Chan, P. (2007). "Toward accurate dynamic time
  warping in linear time and space."
  [PDF](https://cs.fit.edu/~pkc/papers/tdm04.pdf). The FastDTW paper.
- Cuturi, M. & Blondel, M. (2017). "Soft-DTW: a Differentiable Loss
  Function for Time-Series." [arXiv 1703.01541](https://arxiv.org/abs/1703.01541).
- Sadeghi, H. et al. (2000). "Symmetry and limb dominance in able-bodied
  gait: a review." Biomechanics reference for cycle-aware analysis.
- `dtaidistance` documentation —
  <https://dtaidistance.readthedocs.io/>. Worth reading even if we
  don't switch, for the overview of DTW variants and constraints.

### Next steps

- [ ] Pick 10–20 reference pose-sequence pairs and run both FastDTW and
      exact DTW on them to quantify the approximation error.
- [ ] Prototype a Procrustes-aligned preprocessing wrapper and
      re-run the same pairs.
- [ ] Sketch a cycle-aware DTW pipeline against a gait dataset we can
      actually use (identity- and IRB-safe).
- [ ] Decide whether to keep FastDTW as the default or replace it.
- [ ] If we replace it: migrate `neuropose.analyzer.dtw` to the new
      backend in a single commit with no API change.

---

## TensorFlow version compatibility

### The question

The pinned MeTRAbs model artifact
(`metrabs_eff2l_y4_384px_800k_28ds.tar.gz`) is a TensorFlow SavedModel.
SavedModels embed a producer TF version and depend on a set of TF op
kernels. Picking a TF version pin that is too low risks Apple Silicon
install pain (pre-2.16 has no native `darwin/arm64` wheel under the
`tensorflow` package name); picking one that is too high risks loading
or runtime failures if MeTRAbs uses ops that have been renamed,
deprecated, or removed. The goal of this investigation was to find the
**minimum** pin that works on Linux x86_64, Linux arm64, and macOS arm64
without forcing platform-conditional dependencies or shipping
`tensorflow-metal` as a default.

### Method

Phase 0 of the procedure laid out earlier in this document was to
inspect the SavedModel directly and run `detect_poses` end-to-end on a
synthetic input. The probe script (`test.py` at the repo root, kept
during the investigation and removed in the same commit that landed the
pin) did three things (steps 1–2 are reconstructed in the sketch after
this list):

1. Parsed `saved_model.pb` with `saved_model_pb2.SavedModel` and read
   the `tensorflow_version` and `tensorflow_git_version` fields out of
   each `meta_info_def` to establish the **producer** version.
2. Walked every `node.op` and `library.function[*].node_def[*].op` in
   the graph to enumerate the **complete set of ops** the model relies
   on. This is the binary-compatibility surface — anything in this set
   that gets removed in a future TF release breaks the model.
3. Called `tf.saved_model.load(MODEL_DIR)`, accessed
   `per_skeleton_joint_names["berkeley_mhad_43"]`, and invoked
   `model.detect_poses(image, intrinsic_matrix=..., skeleton="berkeley_mhad_43")`
   on a 288×384 black frame to confirm the consumer TF version actually
   *runs* the model (not just loads it — these are different failure
   modes).
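
A reconstruction of steps 1–2 (the original `test.py` is gone, so this
is a sketch from the description above; the model path is assumed):

```python
from tensorflow.core.protobuf import saved_model_pb2

sm = saved_model_pb2.SavedModel()
with open("metrabs_eff2l_y4_384px_800k_28ds/saved_model.pb", "rb") as f:
    sm.ParseFromString(f.read())

# Step 1: producer version from each meta graph.
for mg in sm.meta_graphs:
    info = mg.meta_info_def
    print(info.tensorflow_version, info.tensorflow_git_version)

# Step 2: full op inventory across the graph and its function library.
ops: set[str] = set()
for mg in sm.meta_graphs:
    ops.update(node.op for node in mg.graph_def.node)
    for fn in mg.graph_def.library.function:
        ops.update(nd.op for nd in fn.node_def)
print(sorted(ops))  # the inventory recorded under "Result" below
```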
|
||||
|
||||
The probe ran on Linux x86_64 against whatever `uv sync --group dev`
|
||||
resolved at the time, which was **TensorFlow 2.21.0** with **Keras
|
||||
3.14.0** — i.e. the most recent TF release as of 2026-04 and a version
|
||||
that crosses the Keras-3 cutover at TF 2.16.
|
||||
|
||||
### Result
|
||||
|
||||
- **Producer version:** `tf version: 2.10.0`,
|
||||
`producer: v2.10.0-0-g359c3cdfc5f`. The model was serialized in
|
||||
September 2022, consistent with the file mtimes in the extracted
|
||||
tarball.
|
||||
- **Custom ops:** **zero**. `tf.raw_ops.__dict__` filtered for
|
||||
`"metrabs"` returned `[]`. Every op in the SavedModel is a stock
|
||||
TensorFlow kernel that has been stable since at least TF 2.4.
|
||||
- **Op inventory** (recorded for posterity so a future contributor can
|
||||
diff against a newer MeTRAbs release without re-running the probe):
|
||||
|
||||
```
|
||||
Abs, Add, AddV2, All, Any, Assert, AssignVariableOp, AvgPool,
|
||||
BatchMatMulV2, BiasAdd, Bitcast, BroadcastArgs, BroadcastTo, Cast,
|
||||
Ceil, Cholesky, CombinedNonMaxSuppression, ConcatV2, Const, Conv2D,
|
||||
Cos, Cross, Cumsum, DepthwiseConv2dNative, Einsum, EnsureShape, Equal,
|
||||
Exp, ExpandDims, Fill, Floor, FloorDiv, FloorMod, FusedBatchNormV3,
|
||||
GatherV2, Greater, GreaterEqual, Identity, IdentityN, If,
|
||||
ImageProjectiveTransformV3, LeakyRelu, Less, LessEqual, Log,
|
||||
LogicalAnd, LogicalNot, LogicalOr, LookupTableExportV2,
|
||||
LookupTableFindV2, LookupTableImportV2, MatMul, MatrixDiagV3,
|
||||
MatrixInverse, MatrixSolveLs, MatrixTriangularSolve, Max, MaxPool,
|
||||
Maximum, Mean, MergeV2Checkpoints, Min, Minimum, Mul,
|
||||
MutableDenseHashTableV2, Neg, NoOp, NonMaxSuppressionWithOverlaps,
|
||||
NotEqual, Pack, Pad, PadV2, PartitionedCall, Placeholder, Pow, Prod,
|
||||
RaggedRange, RaggedTensorFromVariant, RaggedTensorToTensor,
|
||||
RaggedTensorToVariant, Range, Rank, ReadVariableOp, RealDiv, Relu,
|
||||
Reshape, ResizeArea, ResizeBilinear, RestoreV2, ReverseV2,
|
||||
RngReadAndSkip, SaveV2, Select, SelectV2, Shape, ShardedFilename,
|
||||
Sigmoid, Sin, Size, Slice, Softplus, Split, SplitV, Sqrt, Square,
|
||||
Squeeze, StatefulPartitionedCall, StatelessIf,
|
||||
StatelessRandomUniformV2, StatelessWhile, StaticRegexFullMatch,
|
||||
StridedSlice, StringJoin, Sub, Sum, Tan, Tanh, TensorListConcatV2,
|
||||
TensorListFromTensor, TensorListGetItem, TensorListReserve,
|
||||
TensorListSetItem, TensorListStack, Tile, TopKV2, Transpose, Unpack,
|
||||
VarHandleOp, Where, While, ZerosLike
|
||||
```
|
||||
|
||||
- **Load:** `tf.saved_model.load` returned a `_UserObject` with
|
||||
`detect_poses` exposed. No warnings about deprecated kernels, no
|
||||
errors. The 11-minor-version forward jump from producer 2.10 to
|
||||
consumer 2.21 was a non-event, including the Keras 3 cutover at 2.16.
|
||||
- **Skeleton check:** `per_skeleton_joint_names["berkeley_mhad_43"]` had
|
||||
shape `(43,)` and `per_skeleton_joint_edges["berkeley_mhad_43"]` had
|
||||
shape `(42, 2)`, exactly matching what
|
||||
`tests/integration/test_estimator_smoke.py` asserts.
- **End-to-end inference:** `model.detect_poses` on a black 288×384
  frame returned `{'poses3d': (0, 43, 3), 'boxes': (0, 5),
  'poses2d': (0, 43, 2)}`, all `float32`. Zero detections is the
  correct output for a black frame — the important signal is that the
  shapes, dtypes, and key names exactly match what `FramePrediction` in
  `neuropose.io` is built to ingest, so the entire estimator pipeline
  is wire-compatible with this TF version.

### Decision

Pin `tensorflow>=2.16,<3.0`. Reasoning:

1. **2.16 is the Apple Silicon floor that matters.** TF 2.16 is the
   first release with native `darwin/arm64` wheels published on PyPI
   under the `tensorflow` package name. Below 2.16, Mac users would
   need `tensorflow-macos` (a separate Apple-maintained package), which
   forces ugly platform markers in `pyproject.toml` and means Linux and
   Mac users run subtly different builds. From 2.16 on, the same
   single dependency line installs cleanly on every supported platform.
2. **MeTRAbs imposes no upper bound below 3.0.** Producer 2.10 → consumer
   2.21 (an 11-minor-version jump across the Keras 3 boundary) loaded
   and ran without a single complaint. The op inventory is 100% stock,
   so future TF 2.x releases would only break this if they removed
   stable kernels — which would itself be a TF 2.x SemVer violation.
3. **`tensorflow-metal` is an opt-in extra, not a default.**
   `tensorflow-metal` is a PluggableDevice that Apple ships separately
   to add a Metal-backed `/GPU:0`. It has its own version-compatibility
   table (Apple maintains it at
   `developer.apple.com/metal/tensorflow-plugin/`), has a documented
   history of producing silently-wrong numerics on specific TF ops,
   and breaks intermittently on Keras 3. For a clinical-research
   pipeline where reproducibility matters more than inference latency,
   CPU inference on Mac is the right default. We do ship a
   `[project.optional-dependencies].metal` extra that pulls
   `tensorflow-metal>=1.2,<2` under darwin/arm64 platform markers, so
   users who want the speedup can opt in via
   `pip install 'neuropose[metal]'` — but the Metal path is not
   exercised in CI, is documented as experimental in
   `docs/getting-started.md`, and users are expected to spot-check
   `poses3d` output against the CPU path before trusting it for any
   clinical measurement.

### What is **not** yet verified

- The probe ran on Linux x86_64 only. macOS arm64 has not been exercised
  on real hardware. The argument that it should work is by construction
  — `tensorflow>=2.16` ships native arm64 macOS wheels, the SavedModel
  uses zero custom ops, and there is no MeTRAbs-side platform code — but
  empirical confirmation is still pending.
- Linux arm64 has likewise not been exercised. Same by-construction
  argument applies.
- A `macos-14` GitHub Actions matrix entry (which would run the unit
  tests on Apple Silicon hardware) is the cheapest way to catch any
  regression and is the intended follow-up.
- Inference-output numerics have not been compared across platforms.
  This is the next layer of rigor below "does it run" — we expect
  fp32 results to match within ~1e-3 mm on `poses3d`, but a real
  cross-platform diff against a reference set has not been done.
- The `[metal]` optional-dependencies extra exists in `pyproject.toml`
  but the Metal code path has never been exercised against the
  pinned MeTRAbs SavedModel. Enabling it is purely opt-in and comes
  with a documented "verify your own numerics" caveat in
  `docs/getting-started.md`. Whether it actually produces a speedup
  on EfficientNetV2-L-based inference on real clinical videos —
  and whether that speedup is worth the numerical-divergence risk
  — is unknown.

### Open questions

1. Does the same `detect_poses` call produce numerically equivalent
   `poses3d` on macOS arm64 as on Linux x86_64 against a real (non-black)
   reference image? Within what tolerance?
2. If a future MeTRAbs release introduces a custom op (e.g. for a new
   detector head), how do we want the loader to fail? Currently the
   `_REQUIRED_MODEL_ATTRS` interface check would still pass; the failure
   would surface at first `detect_poses` call, which is late. (A
   fail-fast probe is sketched after this list.)
3. Does it make sense to upper-bound the pin more tightly than `<3.0`
   (e.g. `<2.22` to bound to tested versions), or is the SemVer guard
   sufficient given the all-stock-ops result?
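
One way to surface that failure at load time instead: run a throwaway
`detect_poses` call on a dummy frame right after the interface check. A
minimal sketch, assuming (as the smoke test above suggests) that
`detect_poses` accepts a uint8 HxWx3 tensor; the helper name and its
placement in the loader are ours, not existing code:

```python
import tensorflow as tf


def _probe_detect_poses(model) -> None:
    """Fail fast if the loaded model cannot actually execute inference.

    A black 288x384 frame is cheap, and zero detections is the expected
    result; we only care that the call runs and returns the keys the
    estimator consumes.
    """
    result = model.detect_poses(tf.zeros((288, 384, 3), dtype=tf.uint8))
    missing = {"poses3d", "poses2d", "boxes"} - set(result)
    if missing:
        raise RuntimeError(f"detect_poses output is missing keys: {missing}")
```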

### Next steps

- [ ] Run the same probe on real macOS arm64 hardware and log the
      result (load success, detect_poses success, output numerics
      diff against the Linux baseline).
- [ ] Add a `macos-14` matrix entry to `.github/workflows/ci.yml` for
      the unit tests. Slow tests stay Linux-only to avoid doubling the
      MeTRAbs download cost in CI.
- [ ] Re-run the probe whenever MeTRAbs upstream publishes a new model
      tarball, and diff the op inventory above. Any new op that is not
      in the list above is a flag worth investigating before raising
      the pin.
- [ ] Benchmark `[metal]` vs CPU on a real Apple Silicon Mac against
      a short reference clip: measure (a) per-frame latency, (b) peak
      memory, and (c) `poses3d` divergence from the CPU baseline. If
      the speedup is meaningful and the numerics are within
      ~1e-2 mm, move the `metal` extra from "experimental" to
      "supported" in the docs. If not, document the failure mode
      here and keep the extra where it is. (A measurement sketch
      follows this list.)
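
A rough shape for the divergence half of that benchmark. This is a
hypothetical helper, untested on Metal hardware; whether `tf.device`
pinning fully controls op placement for a SavedModel call under the
Metal plugin is itself one of the things the benchmark should verify:

```python
import numpy as np
import tensorflow as tf


def max_poses3d_divergence_mm(model, frames: list[np.ndarray]) -> float:
    """Worst-case |poses3d_gpu - poses3d_cpu| over a clip, in millimetres."""
    worst = 0.0
    for frame in frames:
        image = tf.constant(frame, dtype=tf.uint8)
        with tf.device("/CPU:0"):
            cpu = model.detect_poses(image)["poses3d"].numpy()
        with tf.device("/GPU:0"):  # the Metal-backed device under tensorflow-metal
            gpu = model.detect_poses(image)["poses3d"].numpy()
        if cpu.shape == gpu.shape and cpu.size:
            worst = max(worst, float(np.abs(cpu - gpu).max()))
    return worst
```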

---

## MeTRAbs hosting and extensibility

### Current state (v0.1, commit 11)

The model loader in `neuropose._model.load_metrabs_model` will pin the
canonical upstream URL:

```
https://omnomnom.vision.rwth-aachen.de/data/metrabs/metrabs_eff2l_y4_384px_800k_28ds.tar.gz
```

This is the RWTH Aachen "omnomnom" host — a raw HTTP file server run
by the MeTRAbs authors' lab. As of commit 11 there is no HuggingFace
mirror of the relevant MeTRAbs variant.

The URL encodes the model configuration:
`metrabs_eff2l_y4_384px_800k_28ds` means the EfficientNetV2-L backbone,
YOLOv4 detector head, 384-pixel input, 800k training steps, trained on
28 datasets. This name pattern is worth preserving when we host the
model ourselves so future variants stay self-describing.

### Supply-chain concerns

Pinning a single upstream URL to a third-party academic host is a
real supply-chain risk, and the audit of the previous prototype called
it out explicitly (the old code used `bit.ly/metrabs_1`, which was
even worse). Concrete failure modes:

- The RWTH Aachen host goes down or is decommissioned.
- The URL changes when Sárándi releases a new MeTRAbs version.
- The tarball contents change under the same URL without a version bump.

**Minimum mitigation** (should land in or immediately after commit 11):

- **Pin a SHA-256 checksum** alongside the URL, and verify on download
  before unpacking. If the checksum doesn't match, fail hard with a
  clear error.
- **Cache aggressively.** Once downloaded and verified, never hit the
  network again for the same configuration. `model_cache_dir` is
  already in `Settings`.
- **Document the exact filename and checksum** in `RESEARCH.md` (or
  migrate to a `MODEL_ARTIFACTS.md` file) so operators can download
  the model out-of-band if the primary URL is dead.

### Self-hosting options

We want to host the model ourselves, both for reliability and because
it opens the door to future fine-tuning and redistribution of our own
variants. Candidate hosting approaches:

#### Forgejo LFS

Pros:
- Lives next to the code.
- Version-controlled artifacts.
- Access control mirrors repo access.

Cons:
- LFS is designed for git-tracked binary assets, not for large
  infrequently-updated model weights — you pay LFS overhead on every
  clone unless you configure `lfs.fetchexclude`.
- The model is ~2.2 GB; Forgejo LFS performance at that size is untested
  for our instance.
- Pinning is by LFS pointer, which means the model is coupled to a
  particular repo revision. Messy if we want multiple code revisions
  to share the same model.

**Verdict:** Workable but not the best fit.

#### Forgejo generic package registry

Forgejo supports a [generic package
registry](https://forgejo.org/docs/latest/user/packages/generic/) that
can host arbitrary binary artifacts with versioned URLs. This is
closer to what we want:

```
https://git.levineuwirth.org/api/packages/neuwirth/generic/metrabs/eff2l_y4_384px_800k_28ds/metrabs.tar.gz
```

Pros:
- Versioned URLs decoupled from repo revisions.
- Upload once, download many times, no clone coupling.
- Integrated auth if we want to gate access.
- Can be made public even if the repo is private.

Cons:
- Requires uploading the file manually or via an API call (sketched
  below).
- Forgejo registry size / bandwidth limits depend on the instance.
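
For scale, the upload is a single authenticated `PUT` against the
versioned URL above. A sketch assuming the standard Forgejo/Gitea
generic-package endpoint and a token with package-write permission; the
helper name and streaming details are ours:

```python
import base64
import urllib.request
from pathlib import Path


def upload_generic_package(tarball: Path, url: str, user: str, token: str) -> None:
    """Stream a local file to a Forgejo generic-package URL via PUT."""
    credentials = base64.b64encode(f"{user}:{token}".encode()).decode()
    request = urllib.request.Request(url, method="PUT")
    request.add_header("Authorization", f"Basic {credentials}")
    # An explicit Content-Length lets http.client stream the file object
    # in blocks instead of trying (and failing) to len() it.
    request.add_header("Content-Length", str(tarball.stat().st_size))
    with tarball.open("rb") as body:
        request.data = body
        with urllib.request.urlopen(request) as response:  # raises on HTTP errors
            print(f"uploaded: HTTP {response.status}")
```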

**Verdict:** Probably the best fit for "we want it hosted alongside
the project."

#### Plain HTTP server on a VPS subdomain

A dedicated subdomain like `models.levineuwirth.org` backed by a
simple HTTP file server (nginx `autoindex`, or Caddy with a tidy
directory layout). Example URL:

```
https://models.levineuwirth.org/metrabs/metrabs_eff2l_y4_384px_800k_28ds.tar.gz
```

Pros:
- Simplest possible story. No API, no auth machinery.
- Easy to mirror from — anyone can curl the URL.
- Decoupled from the git forge, so we can share models publicly even
  when the repo itself is private.
- Easy to put a CDN in front (Cloudflare) if bandwidth ever matters.

Cons:
- Manual upload via scp/rsync.
- No access control unless we add it.
- No versioning beyond filename convention.

**Verdict:** Strong candidate. This is probably the right choice for
v0.1 of self-hosted models.

#### S3-compatible object storage (MinIO self-hosted)

Run MinIO on the VPS, get an S3-compatible API for free, and serve models
via pre-signed URLs or a public bucket.

Pros:
- Proper object storage with ETags, range requests, multipart uploads.
- Integration story is straightforward if we ever move to cloud-hosted
  storage.
- Industry-standard API.

Cons:
- More operational complexity than a plain HTTP server for what might
  be a handful of files.

**Verdict:** Overkill for v0.1 but worth revisiting if model storage
becomes a real operational concern.

### Integrity: SHA-256 pinning

Regardless of which hosting approach we pick, **the model loader should
always verify a SHA-256 checksum** before trusting the downloaded
artifact. This is the one piece of supply-chain hygiene that has to be
in place before we ship commit 11 to any user outside the Shu lab.

Implementation sketch for `neuropose/_model.py`:

```python
def load_metrabs_model(cache_dir: Path | None = None) -> Any:
    cache_dir = cache_dir or _default_model_cache_dir()
    cache_dir.mkdir(parents=True, exist_ok=True)
    tarball = cache_dir / _MODEL_FILENAME
    if not tarball.exists():
        _download(_MODEL_URL, tarball)
    _verify_sha256(tarball, _MODEL_SHA256)  # verify even on cache hits
    extracted = _extract_if_needed(tarball, cache_dir)
    return tfhub.load(str(extracted))  # or tf.saved_model.load
```

The `_MODEL_SHA256` constant is the source of truth; if it ever has
to change, the constant change is visible in the git diff and a human
reviews it.

### Fine-tuning

The next research direction after we have inference working is
fine-tuning MeTRAbs on clinical-specific data. Open questions:

- **What data?** Any clinical data is IRB-gated. Even de-identified
  pose data may carry subject information if the recording conditions
  (lighting, room layout) are distinctive enough. Any training plan
  has to run through the data-handling policy that lives (will live)
  in `docs/data-policy.md`.
- **Transfer learning strategy.**
  - *Head-only fine-tuning*: freeze the EfficientNetV2-L backbone and
    re-train the pose regression head on clinical data. Fast, low
    compute, unlikely to overfit, but also unlikely to capture
    clinical-pose idiosyncrasies. (A sketch follows this list.)
  - *Low-LR full fine-tune*: unfreeze everything, use a learning rate
    1/100th of the original, train for a few epochs. Better
    adaptation, higher risk of catastrophic forgetting.
  - *Adapter layers*: insert small trainable adapters into the frozen
    backbone. Parameter-efficient, well-studied in NLP, less common
    for pose but should work.
- **Compute requirements.** EfficientNetV2-L is roughly 120M parameters;
  fine-tuning on a single modern GPU (24 GB VRAM) is feasible at
  reduced batch size. A multi-GPU node is friendlier but not strictly
  required.
- **Evaluation.** We need held-out clinical data with trusted ground
  truth. MoCap-derived poses are the gold standard; marker-based MoCap
  systems provide sub-millimeter accuracy at the cost of subject
  instrumentation. The Shu lab's access to MoCap is the gating factor.
- **Sharing fine-tuned weights.** If we fine-tune on clinical data, the
  resulting weights may encode subject information in ways that are
  non-obvious and potentially IRB-relevant. Sharing fine-tuned weights
  externally has to be cleared through the same channels as sharing the
  training data.
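
To make the head-only option concrete, an illustrative sketch only: it
assumes a hypothetical Keras model with named `backbone` and `pose_head`
sublayers and uses a placeholder loss. The published MeTRAbs SavedModel
is an inference artifact, so real fine-tuning would start from Sárándi's
training code, not from this:

```python
import tensorflow as tf


def freeze_for_head_only_finetune(model: tf.keras.Model) -> tf.keras.Model:
    """Freeze the backbone so only the pose head's weights move."""
    model.get_layer("backbone").trainable = False  # EfficientNetV2-L stays fixed
    model.get_layer("pose_head").trainable = True
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),  # LR is a guess, not tuned
        loss="mae",  # placeholder, not the MeTRAbs heatmap loss
    )
    return model
```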

### Training our own pose estimator

The long-range version of the research direction: train a pose
estimator from scratch that extends MeTRAbs's methodology. MeTRAbs is
a good starting point because the method is well-documented:

- Sárándi, I., et al. (2020). "MeTRAbs: Metric-Scale Truncation-Robust
  Heatmaps for Absolute 3D Human Pose Estimation."
  [arXiv 2007.07227](https://arxiv.org/abs/2007.07227),
  IEEE Transactions on Biometrics, Behavior, and Identity Science.

Core contributions (worth knowing if you modify any of this):

- **Truncation-robust heatmaps.** Instead of predicting a 2D heatmap
  bounded by the image, MeTRAbs predicts a heatmap that extends
  *outside* the image boundary, so a joint can be localized even when
  it falls outside the crop. Critical for crops where the subject
  is partially out of frame.
- **Metric scale regression.** MeTRAbs predicts the absolute 3D
  positions of joints in millimetres by combining a 2D heatmap with a
  per-joint depth regressor (ordinary back-projection, sketched after
  this list). Most 3D pose methods produce only
  relative coordinates, which are useless for clinical measurement.
- **Multi-dataset training with a common skeleton.** The 28-dataset
  training set unifies disparate skeleton topologies into a common
  43-joint Berkeley MHAD skeleton, which we carry forward in
  NeuroPose.
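
For orientation on the metric-scale point: a 2D heatmap peak plus a
regressed metric depth becomes an absolute 3D joint through ordinary
pinhole back-projection. This is background geometry, not a claim about
MeTRAbs internals; the names and shapes below are ours:

```python
import numpy as np


def back_project(u: float, v: float, depth_mm: float, K: np.ndarray) -> np.ndarray:
    """3D point in the camera frame (mm) for pixel (u, v) at depth ``depth_mm``.

    Computes X = Z * K^-1 @ [u, v, 1] for a 3x3 intrinsics matrix K.
    """
    ray = np.linalg.solve(K, np.array([u, v, 1.0]))
    return depth_mm * ray
```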

**Natural extensions worth considering:**

- **Temporal smoothing head.** MeTRAbs is a per-frame model. Clinical
  gait analysis wants temporally smooth trajectories. Adding a
  lightweight temporal head (1D CNN or small transformer over frame
  sequences) could produce smoother outputs without touching the
  backbone.
- **Clinical-specific heatmap supervision.** If we have MoCap data for
  clinical poses, we can use it as ground-truth heatmap supervision to
  improve accuracy in the pose ranges the model sees least often in
  the 28-dataset training corpus (e.g., pathological gaits, walker-
  assisted ambulation).
- **Multi-person identity tracking.** MeTRAbs produces detections per
  frame without continuity across frames. Adding a Hungarian-matched
  tracker (or a learned tracker) would solve the multi-person
  identity problem that `predictions_to_numpy` currently dodges with
  a `person_index` parameter.
- **Alternative backbones.** EfficientNetV2-L is a 2021-era choice.
  Newer backbones (ConvNeXt, DINOv2-initialized ViTs) may give
  meaningful gains, especially for clinical poses that are
  under-represented in the original training set.
- **Uncertainty estimation.** Clinical users want to know when the
  model is unsure. A Gaussian output head (mean + variance per joint)
  or an ensemble-based approach would let us propagate uncertainty
  into downstream analysis.

**Compute requirements:** training MeTRAbs from scratch was reported
as "a few weeks" on 8x V100 in the original paper. A from-scratch
re-training is a substantial undertaking. Fine-tuning is much more
accessible.

### Collaboration opportunities

- **István Sárándi** (now at University of Tübingen, formerly RWTH
  Aachen) is the author of MeTRAbs. The code is MIT-licensed and he
  has historically been responsive to collaboration requests. If we
  end up publishing work that significantly extends MeTRAbs, at the
  very least we should reach out about co-authorship or
  acknowledgment; at best we might find an active collaborator.
- **The Shu Lab's existing collaborators** on clinical gait research
  at Brown and partner institutions may have MoCap-validated datasets
  we can use for fine-tuning and evaluation. Worth asking Dr. Shu.

### Open questions

1. Does Forgejo's generic package registry actually handle a 2.2 GB
   upload cleanly, or do we need the plain HTTP server route?
2. What's the right SHA-256 pin to commit alongside the URL? (Need to
   download the tarball first and run `sha256sum`; a pure-Python
   equivalent is sketched after this list.)
3. Do we have access to MoCap-validated clinical gait data for
   fine-tuning evaluation? This gates every training-related
   experiment.
4. Is fine-tuning even worth pursuing before we have inference results
   that are clearly *not* good enough on clinical data? (I.e.,
   motivate the work with concrete failure cases rather than assuming
   a delta we haven't measured.)
5. Does it make sense to reach out to Sárándi now, or wait until we
   have something concrete to collaborate on?
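
For question 2, the pin is just `sha256sum <tarball>` run on the
downloaded file. A streaming Python equivalent for environments without
coreutils, mirroring what `_verify_sha256` in `neuropose/_model.py`
does:

```python
import hashlib
from pathlib import Path


def sha256_hex(path: Path, chunk_bytes: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, streamed in 1 MB chunks (2 GB-safe)."""
    hasher = hashlib.sha256()
    with path.open("rb") as handle:
        for block in iter(lambda: handle.read(chunk_bytes), b""):
            hasher.update(block)
    return hasher.hexdigest()
```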

### Reading list

- Sárándi, I. et al. (2020). "MeTRAbs: Metric-Scale Truncation-Robust
  Heatmaps for Absolute 3D Human Pose Estimation."
  [arXiv 2007.07227](https://arxiv.org/abs/2007.07227). **Essential
  reading** for anyone planning to extend the method.
- Sárándi's personal site and the MeTRAbs GitHub repo
  (<https://github.com/isarandi/metrabs>) — the code, model zoo, and
  training scripts live here.
- Zheng, C. et al. (2023). "Deep Learning-Based Human Pose Estimation: A
  Survey." Good survey paper for orienting on the state of the art.
- The original 28-dataset training composition referenced in the
  MeTRAbs paper — worth tracing through to understand what poses are
  in- and out-of-distribution for the pretrained model.

### Next steps

- [ ] Download the pinned tarball and compute its SHA-256 for the
      commit-11 model loader.
- [ ] Decide between Forgejo generic registry and plain HTTP subdomain
      for self-hosting. Prototype whichever one wins.
- [ ] Mirror the pinned tarball to the chosen self-hosted location so
      we can fail over to it the moment the RWTH URL changes or goes
      down.
- [ ] Write a one-page "MODEL_ARTIFACTS.md" that documents every model
      version we use, its checksum, and its canonical source URL.
- [ ] Have the data-access conversation with Dr. Shu about clinical
      training data. Everything else is blocked on this.
- [ ] (Much later) Reach out to Sárándi about potential collaboration
      once we have something concrete to talk about.
|
|
@ -2,25 +2,49 @@
|
|||
This page walks through installing NeuroPose, running your first pose
estimation, and understanding the output. It targets researchers who are
comfortable on a Linux command line but may not have used the package
comfortable on a Unix command line but may not have used the package
before.

!!! info "Model loader status"
    The MeTRAbs model loader is pending the commit-11 rewrite, during
    which the upstream model URL and TensorFlow version will be pinned.
    Until it lands, the `neuropose watch` and `neuropose process`
    commands will exit with a clear "pending commit 11" message. The
    Python API still works if you inject a model manually — see the
    *Python API* section below for the current workaround.

## Prerequisites

- Linux (Ubuntu 22.04+ or equivalent)
- Linux (Ubuntu 22.04+ or equivalent) **or** macOS on Apple Silicon
  (M1 / M2 / M3 / M4). Both are first-class targets — the same `uv`
  install command works on either.
- Python 3.11
- [`uv`](https://github.com/astral-sh/uv) for dependency management
- CUDA-capable GPU (optional, recommended for long videos)
- Internet access on first run (for the model download, once the loader
  lands)
- CUDA-capable GPU (optional, recommended for long videos on Linux)
- Internet access on first run (for the ~2 GB MeTRAbs model download)

!!! note "Apple Silicon"
    NeuroPose pins `tensorflow>=2.16`, which is the first TensorFlow
    release with native `darwin/arm64` wheels on PyPI. Mac users get a
    working CPU install from the same command Linux users run — no
    `tensorflow-macos`, no platform markers, no extra configuration.

    Metal GPU acceleration is available as an **opt-in extra** for
    users who need it:

    ```bash
    uv sync --group dev --extra metal
    # or, for non-editable installs:
    pip install 'neuropose[metal]'
    ```

    This installs `tensorflow-metal`, Apple's PluggableDevice for TF,
    which registers a Metal-backed `/GPU:0` device. It is **not
    enabled by default** for two reasons:

    1. **Untested in this codebase.** The default install (CPU) is
       the path we verify in CI. The Metal path has not been
       exercised against the MeTRAbs SavedModel; users are on their
       own for validation.
    2. **Numerical caveats.** `tensorflow-metal` has a documented
       history of producing silently-divergent fp32 results on some
       TF ops, especially under Keras 3. For clinical research where
       reproducibility matters more than inference latency, CPU
       inference is the safer default. If you enable Metal, spot-check
       a few `poses3d` outputs against the CPU equivalent before
       trusting results for any downstream measurement.

## Installation

|
|
@ -170,20 +194,15 @@ back into a validated `VideoPredictions` object.
|
|||
## Python API

For scripting, debugging, or integrating NeuroPose into a larger
pipeline, you can use the `Estimator` class directly. This is also the
current workaround for the pending model loader:
pipeline, you can use the `Estimator` class directly:

```python
from neuropose._model import load_metrabs_model
from neuropose.estimator import Estimator
from neuropose.io import save_video_predictions
from pathlib import Path

# Load the MeTRAbs model however you like — e.g. via tensorflow_hub once
# you know the canonical URL. Until commit 11 pins it, you'll need to
# load it yourself here.
import tensorflow_hub as tfhub
model = tfhub.load("...")  # TODO: pin upstream URL

model = load_metrabs_model()  # uses the XDG cache dir; downloads on first call
estimator = Estimator(model=model, device="/GPU:0")
result = estimator.process_video(Path("trial_01.mp4"))

|
|
@ -228,8 +247,8 @@ function, so importing `neuropose.visualize` has no global side effects.
|
|||

| Problem | Resolution |
|---|---|
| `error: pending commit 11` from `neuropose watch` or `process` | The model loader is not yet implemented. Use the Python API with a manually-loaded model. |
| `AlreadyRunningError` from the daemon | Another NeuroPose daemon already holds the lock file. Check `data_dir/.neuropose.lock` for the PID. |
| `VideoDecodeError` on valid-looking video | The file may be corrupted or in a codec OpenCV was built without. Try re-encoding with `ffmpeg -i in.mov -c:v libx264 out.mp4`. |
| Jobs stuck in `processing` state on startup | The daemon now recovers these automatically — they'll be marked failed and quarantined to `data_dir/failed/` on the next run. |
| Daemon not detecting a new job | Check that the job is inside a **subdirectory** of `data_dir/in/`, not directly in `data_dir/in/`. Empty subdirectories are silently skipped (the daemon assumes you are still copying files). |
| `SHA-256 mismatch` from the model loader | The MeTRAbs tarball download was truncated or the upstream artifact has changed. The loader retries once automatically; if it still fails, delete `model_cache_dir/metrabs_eff2l_y4_384px_800k_28ds.tar.gz` and let it re-download, or check `RESEARCH.md` for the canonical pin. |
|
|
|
|||
|
|
@ -29,7 +29,9 @@ classifiers = [
|
|||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: POSIX",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: MacOS",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Topic :: Scientific/Engineering :: Image Recognition",
|
||||
|
|
@ -37,10 +39,14 @@ classifiers = [
|
|||
"Typing :: Typed",
|
||||
]
|
||||
|
||||
# Runtime dependencies. Versions intentionally unpinned at this stage: the
|
||||
# TensorFlow / MeTRAbs stack is pending a sanity check against the updated
|
||||
# upstream MeTRAbs release. Tightened pins will land alongside the MeTRAbs
|
||||
# model loader and smoke test.
|
||||
# Runtime dependencies. The TensorFlow floor is set to 2.16 because that is
|
||||
# the first release with native ``darwin/arm64`` wheels published under the
|
||||
# ``tensorflow`` package name on PyPI. Older versions would force a marker
|
||||
# split to ``tensorflow-macos`` on Apple Silicon. The MeTRAbs SavedModel was
|
||||
# serialized with TF 2.10 and uses only stock ops (no custom kernels), so
|
||||
# any TF >= 2.10 loads it; 2.16 is chosen for cross-platform install
|
||||
# uniformity, not for any MeTRAbs-side requirement. End-to-end verification
|
||||
# against TF 2.21 is logged in RESEARCH.md.
|
||||
dependencies = [
|
||||
"typer>=0.12",
|
||||
"pydantic>=2.6",
|
||||
|
|
@ -49,7 +55,7 @@ dependencies = [
|
|||
"numpy>=1.26",
|
||||
"opencv-python-headless>=4.9",
|
||||
"matplotlib>=3.8",
|
||||
"tensorflow>=2.15,<3.0",
|
||||
"tensorflow>=2.16,<3.0",
|
||||
"tensorflow-hub>=0.16",
|
||||
]
|
||||
|
||||
|
|
@ -60,11 +66,23 @@ analysis = [
|
|||
"scikit-learn>=1.4",
|
||||
"sktime>=0.28",
|
||||
]
|
||||
# Optional Apple Silicon GPU acceleration via Metal Performance Shaders.
|
||||
# This is a PluggableDevice maintained by Apple, NOT by the TensorFlow team,
|
||||
# with its own version-compatibility table (see
|
||||
# https://developer.apple.com/metal/tensorflow-plugin/) and a documented
|
||||
# history of producing numerically-divergent results on some ops under
|
||||
# Keras 3. It is deliberately opt-in. Install with
|
||||
# `pip install 'neuropose[metal]'` or `uv sync --extra metal`. The platform
|
||||
# markers ensure it is silently ignored on every non-Apple-Silicon install.
|
||||
metal = [
|
||||
"tensorflow-metal>=1.2,<2; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://git.levineuwirth.org/neuwirth/neuropose"
|
||||
Repository = "https://git.levineuwirth.org/neuwirth/neuropose"
|
||||
Issues = "https://git.levineuwirth.org/neuwirth/neuropose/issues"
|
||||
Changelog = "https://git.levineuwirth.org/neuwirth/neuropose/src/branch/main/CHANGELOG.md"
|
||||
|
||||
[project.scripts]
|
||||
neuropose = "neuropose.cli:run"
|
||||
|
|
@ -82,6 +100,13 @@ dev = [
|
|||
"pre-commit>=4.0",
|
||||
"mkdocs-material>=9.5",
|
||||
"mkdocstrings[python]>=0.26",
|
||||
# Analyzer subpackage runtime deps. Duplicated from the `analysis`
|
||||
# optional-dependencies extra so dev contributors can run the analyzer
|
||||
# tests without also having to install the extra. sktime is NOT
|
||||
# included here because we do not (yet) wrap it — users who want
|
||||
# classification install `pip install neuropose[analysis]` themselves.
|
||||
"fastdtw>=0.3.4",
|
||||
"scipy>=1.12",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -175,6 +200,7 @@ include = [
|
|||
"src/neuropose",
|
||||
"README.md",
|
||||
"LICENSE",
|
||||
"CHANGELOG.md",
|
||||
"AUTHORS.md",
|
||||
"CITATION.cff",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python
"""Pre-download the pinned MeTRAbs model into the NeuroPose cache.

Run this on a machine with network access before going offline, or to
pre-warm a deployment target's cache so the first ``neuropose watch``
or ``neuropose process`` invocation does not stall on a ~2 GB
download.

Usage::

    uv run python scripts/download_model.py [--cache-dir PATH]

If ``--cache-dir`` is omitted, the script uses
``Settings().model_cache_dir`` (``$XDG_DATA_HOME/neuropose/models`` by
default), which is the same location the daemon and the direct
``Estimator`` entry points read from.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path


def _parse_args(argv: list[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--cache-dir",
        type=Path,
        default=None,
        help="Destination directory for the cached model. Defaults to "
        "the value of Settings().model_cache_dir.",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(sys.argv[1:] if argv is None else argv)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Deferred imports so --help runs cheaply without importing TF.
    from neuropose._model import load_metrabs_model
    from neuropose.config import Settings

    if args.cache_dir is None:
        settings = Settings()
        cache_dir = settings.model_cache_dir
    else:
        cache_dir = args.cache_dir

    print(f"Fetching MeTRAbs model into {cache_dir}", file=sys.stderr)
    try:
        load_metrabs_model(cache_dir=cache_dir)
    except Exception as exc:
        print(f"error: model download failed: {exc}", file=sys.stderr)
        return 1

    print("Download complete", file=sys.stderr)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|
@ -1,51 +1,355 @@
|
|||
"""MeTRAbs model loading — stub pending commit 11.
"""MeTRAbs model loading with download, checksum verification, and caching.

This module exists so that :mod:`neuropose.estimator` can import a single,
well-typed loader function even before the upstream MeTRAbs URL is pinned
and the TensorFlow version is settled.
The :func:`load_metrabs_model` function is the single entry point through
which :class:`neuropose.estimator.Estimator` acquires its TensorFlow
model. It handles:

Commit 11 will replace :func:`load_metrabs_model` with an implementation
that:
1. First-call download from the pinned upstream URL.
2. SHA-256 verification of the downloaded tarball against a known-good
   checksum. A mismatch triggers exactly one automatic retry (in case
   the download was truncated), after which the error surfaces.
3. Atomic extraction to a staging directory and a single rename into
   the final cache location, so a crash mid-extraction cannot leave
   the cache in a half-extracted state.
4. SavedModel load via ``tf.saved_model.load``.
5. A post-load interface sanity check that verifies the loaded model
   exposes the ``detect_poses``, ``per_skeleton_joint_names``, and
   ``per_skeleton_joint_edges`` attributes the estimator needs.

1. Pins the canonical MeTRAbs tfhub / Kaggle Models handle (replacing the
   ``bit.ly/metrabs_1`` shortener from the previous prototype).
2. Caches the downloaded model under ``Settings.model_cache_dir`` so the
   first run downloads it and subsequent runs are offline.
3. Returns a typed handle that the estimator can invoke without hitting the
   network on each instantiation.
Model artifact
--------------
The pinned model is MeTRAbs's EfficientNetV2-L variant
(``metrabs_eff2l_y4_384px_800k_28ds``):

- **URL**: hosted on the RWTH Aachen "omnomnom" server, which is the
  canonical distribution point for the MeTRAbs authors' lab.
- **SHA-256**: pinned below. Any change to the URL or the upstream
  tarball will surface as a verification failure, forcing a human
  review of the new checksum before downstream code trusts the new
  artifact.

See ``RESEARCH.md`` at the repo root for the ongoing discussion of
self-hosting the model on our own infrastructure instead of relying on
a single third-party URL.
"""

from __future__ import annotations

import hashlib
import logging
import os
import shutil
import tarfile
import urllib.request
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

def load_metrabs_model(cache_dir: Path | None = None) -> Any:  # noqa: ARG001
# ---------------------------------------------------------------------------
# Model artifact: pinned URL and checksum.
# ---------------------------------------------------------------------------
#
# If the URL or checksum below changes, the diff should be reviewed by a
# human. These are supply-chain constants.

_MODEL_URL = (
    "https://omnomnom.vision.rwth-aachen.de/data/metrabs/"
    "metrabs_eff2l_y4_384px_800k_28ds.tar.gz"
)
_MODEL_SHA256 = "fa31b5b043f227588c3d224e56db89307d021bfbbb52e36028919f90e1f96c89"
_MODEL_ARCHIVE_NAME = "metrabs_eff2l_y4_384px_800k_28ds.tar.gz"
_MODEL_DIR_NAME = "metrabs_eff2l_y4_384px_800k_28ds"

_DOWNLOAD_CHUNK_BYTES = 1024 * 1024  # 1 MB
_DOWNLOAD_SOCKET_TIMEOUT = 120.0  # seconds between bytes, not total
_REQUIRED_MODEL_ATTRS = (
    "detect_poses",
    "per_skeleton_joint_names",
    "per_skeleton_joint_edges",
)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def load_metrabs_model(cache_dir: Path | None = None) -> Any:
    """Load the MeTRAbs model, downloading and caching on first use.

    Parameters
    ----------
    cache_dir
        Directory where the model should be cached. Typically
        ``Settings.model_cache_dir``. If ``None``, the implementation picks
        a default location.
        Directory where the model tarball and extracted SavedModel are
        cached. If ``None``, defaults to
        ``$XDG_DATA_HOME/neuropose/models`` (matching
        :attr:`neuropose.config.Settings.model_cache_dir`).

    Returns
    -------
    object
        An opaque model handle that exposes ``detect_poses`` and the
        per-skeleton joint metadata attributes used by
        :class:`neuropose.estimator.Estimator`.
        A TensorFlow SavedModel handle exposing ``detect_poses`` and
        the ``per_skeleton_joint_names`` / ``per_skeleton_joint_edges``
        attributes used by :class:`neuropose.estimator.Estimator`.

    Raises
    ------
    NotImplementedError
        Always, at this commit. Commit 11 provides the real implementation
        once the upstream MeTRAbs URL is pinned.
    RuntimeError
        If the download fails, the SHA-256 does not match (after one
        automatic retry), extraction fails, TensorFlow is not
        installed, or the loaded model does not expose the expected
        interface.
    """
    raise NotImplementedError(
        "load_metrabs_model is stubbed pending commit 11. "
        "Inject a model directly via Estimator(model=...) for now, "
        "or wait for the upstream MeTRAbs URL and TensorFlow version pin."
    resolved_cache = Path(cache_dir) if cache_dir is not None else _default_cache_dir()
    resolved_cache.mkdir(parents=True, exist_ok=True)

    model_dir = resolved_cache / _MODEL_DIR_NAME

    if model_dir.exists():
        try:
            saved_model_dir = _find_saved_model(model_dir)
        except RuntimeError:
            logger.warning(
                "Cached model at %s appears incomplete; removing and re-downloading.",
                model_dir,
            )
            shutil.rmtree(model_dir, ignore_errors=True)
        else:
            return _tf_load(saved_model_dir)

    tarball = resolved_cache / _MODEL_ARCHIVE_NAME

    if not tarball.exists():
        _download_with_progress(_MODEL_URL, tarball)

    try:
        _verify_sha256(tarball, _MODEL_SHA256)
    except RuntimeError as first_exc:
        logger.warning(
            "SHA-256 mismatch on cached tarball; re-downloading once: %s",
            first_exc,
        )
        tarball.unlink(missing_ok=True)
        _download_with_progress(_MODEL_URL, tarball)
        _verify_sha256(tarball, _MODEL_SHA256)

    _extract_tarball(tarball, model_dir)
    saved_model_dir = _find_saved_model(model_dir)
    return _tf_load(saved_model_dir)


# ---------------------------------------------------------------------------
# Cache directory resolution
# ---------------------------------------------------------------------------


def _default_cache_dir() -> Path:
    """Return the default model cache directory under ``$XDG_DATA_HOME``.

    Duplicates :func:`neuropose.config._default_model_cache_dir` rather
    than importing it, to keep this module free of a dependency on the
    config layer. The two must agree; a regression test in
    :mod:`tests.unit.test_config` verifies the Settings default, and any
    future change here should be cross-checked there.
    """
    xdg = os.environ.get("XDG_DATA_HOME")
    base = Path(xdg) if xdg else Path.home() / ".local" / "share"
    return base / "neuropose" / "models"


# ---------------------------------------------------------------------------
# Download
# ---------------------------------------------------------------------------


def _download_with_progress(url: str, dest: Path) -> None:
    """Download ``url`` to ``dest`` with progress reporting via the logger.

    Streams the response in 1 MB chunks so memory usage stays flat
    regardless of the file size. Progress is logged at 10 % increments.
    On any exception, the partial file at ``dest`` is removed so the
    caller does not see a truncated file.
    """
    logger.info("Downloading %s → %s", url, dest)
    dest.parent.mkdir(parents=True, exist_ok=True)

    request = urllib.request.Request(  # noqa: S310
        url,
        headers={"User-Agent": "neuropose/0.1"},
    )

    try:
        with urllib.request.urlopen(  # noqa: S310
            request,
            timeout=_DOWNLOAD_SOCKET_TIMEOUT,
        ) as response:
            total_bytes_header = response.headers.get("Content-Length")
            total_bytes = int(total_bytes_header) if total_bytes_header else 0

            downloaded = 0
            next_progress_log = 0.10  # log at 10 %, 20 %, ...

            with dest.open("wb") as out_file:
                while True:
                    chunk = response.read(_DOWNLOAD_CHUNK_BYTES)
                    if not chunk:
                        break
                    out_file.write(chunk)
                    downloaded += len(chunk)
                    if total_bytes > 0:
                        fraction = downloaded / total_bytes
                        if fraction >= next_progress_log:
                            logger.info(
                                "  %d / %d MB (%.0f%%)",
                                downloaded // (1024 * 1024),
                                total_bytes // (1024 * 1024),
                                fraction * 100,
                            )
                            next_progress_log += 0.10
    except Exception as exc:
        # Clean up partial file so the next call re-downloads cleanly.
        dest.unlink(missing_ok=True)
        raise RuntimeError(
            f"Failed to download MeTRAbs model from {url}: {exc}"
        ) from exc

    if total_bytes > 0 and downloaded != total_bytes:
        dest.unlink(missing_ok=True)
        raise RuntimeError(
            f"Download from {url} was truncated: "
            f"got {downloaded} bytes, expected {total_bytes}."
        )
    logger.info("Download complete: %d bytes", downloaded)


# ---------------------------------------------------------------------------
# Checksum verification
# ---------------------------------------------------------------------------


def _verify_sha256(path: Path, expected_hex: str) -> None:
    """Verify that ``path``'s SHA-256 digest matches ``expected_hex``.

    Streams the file through ``hashlib.sha256`` in 1 MB chunks so we do
    not pay for loading a 2 GB tarball into memory.
    """
    logger.info("Verifying SHA-256 of %s", path)
    hasher = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(_DOWNLOAD_CHUNK_BYTES), b""):
            hasher.update(chunk)
    actual_hex = hasher.hexdigest()
    if actual_hex != expected_hex:
        raise RuntimeError(
            f"SHA-256 mismatch for {path}: "
            f"expected {expected_hex}, got {actual_hex}. "
            "The downloaded file is corrupt, truncated, or has been tampered with."
        )
    logger.info("SHA-256 verified")


# ---------------------------------------------------------------------------
# Extraction
# ---------------------------------------------------------------------------


def _extract_tarball(tarball: Path, dest_dir: Path) -> None:
    """Extract ``tarball`` atomically to ``dest_dir``.

    Extracts first to a sibling ``<dest_dir>.staging`` directory, then
    replaces ``dest_dir`` with a single ``rename`` once extraction
    completes. A crash mid-extraction therefore cannot leave behind a
    half-populated ``dest_dir``.

    Uses tarfile's ``data`` filter to block path traversal and other
    tar-bomb patterns.
    """
    logger.info("Extracting %s → %s", tarball, dest_dir)
    staging = dest_dir.parent / (dest_dir.name + ".staging")
    if staging.exists():
        shutil.rmtree(staging)
    staging.mkdir(parents=True, exist_ok=True)

    try:
        with tarfile.open(tarball, "r:gz") as tf_archive:
            # ``filter="data"`` guards against path traversal and other
            # malicious tar contents. Available in Python 3.11.4+ and
            # the default in 3.14+.
            tf_archive.extractall(staging, filter="data")
    except Exception as exc:
        shutil.rmtree(staging, ignore_errors=True)
        raise RuntimeError(f"Failed to extract {tarball}: {exc}") from exc

    if dest_dir.exists():
        shutil.rmtree(dest_dir)
    staging.rename(dest_dir)
    logger.info("Extracted to %s", dest_dir)


# ---------------------------------------------------------------------------
# SavedModel discovery and TF load
# ---------------------------------------------------------------------------


def _find_saved_model(root: Path) -> Path:
    """Return the directory containing ``saved_model.pb`` under ``root``.

    The MeTRAbs tarball extracts to a directory containing a SavedModel
    directory, which itself contains ``saved_model.pb``. The exact
    layout of intermediate directories is not contractually stable, so
    we search rather than hardcoding a path.

    Raises
    ------
    RuntimeError
        If no ``saved_model.pb`` is found, or if multiple candidates
        are found (which would make the choice ambiguous).
    """
    candidates = list(root.rglob("saved_model.pb"))
    if not candidates:
        raise RuntimeError(
            f"no saved_model.pb found under {root}; tarball layout unexpected"
        )
    if len(candidates) > 1:
        raise RuntimeError(
            f"multiple saved_model.pb files found under {root}: "
            f"{[str(p) for p in candidates]}. "
            "Cannot determine which one to load."
        )
    return candidates[0].parent


def _tf_load(saved_model_dir: Path) -> Any:
    """Load a SavedModel from ``saved_model_dir`` and sanity-check it.

    TensorFlow is imported lazily here so that importing
    :mod:`neuropose._model` does not require TF for test or docs
    code paths that never reach the loader.
    """
    try:
        import tensorflow as tf  # noqa: PLC0415
    except ImportError as exc:
        raise RuntimeError(
            "TensorFlow is required to load the MeTRAbs model but is not installed. "
            "Install the NeuroPose runtime dependencies with: "
            "pip install neuropose (or uv sync from a dev checkout)."
        ) from exc

    logger.info("Loading SavedModel from %s", saved_model_dir)
    try:
        model = tf.saved_model.load(str(saved_model_dir))
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load SavedModel from {saved_model_dir}: {exc}"
        ) from exc

    missing = [attr for attr in _REQUIRED_MODEL_ATTRS if not hasattr(model, attr)]
    if missing:
        raise RuntimeError(
            f"Loaded SavedModel at {saved_model_dir} is missing expected "
            f"attributes {missing}. The tarball may not be a MeTRAbs model."
        )

    logger.info("MeTRAbs model loaded successfully from %s", saved_model_dir)
    return model
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
"""Post-processing and analysis utilities for NeuroPose predictions.
|
||||
|
||||
This subpackage operates on :class:`~neuropose.io.VideoPredictions`
|
||||
objects (and the numpy arrays derived from them) rather than on raw
|
||||
dicts or JSON files. The intent is a set of composable pure functions —
|
||||
feature extraction, normalization, joint-angle computation, and Dynamic
|
||||
Time Warping — that researchers can assemble into their own pipelines.
|
||||
|
||||
.. note::
|
||||
The analyzer's heavy dependencies (:mod:`fastdtw`, :mod:`scipy`) are
|
||||
declared under the ``analysis`` optional-dependencies extra in
|
||||
:file:`pyproject.toml`. Install them with::
|
||||
|
||||
pip install neuropose[analysis]
|
||||
|
||||
The imports inside :mod:`neuropose.analyzer.dtw` and the peak-finding
|
||||
helper in :mod:`neuropose.analyzer.features` are lazy, so importing
|
||||
this subpackage does not require those dependencies. You will only
|
||||
hit a clear :class:`ImportError` at call time if they are missing.
|
||||
|
||||
Public API
|
||||
----------
|
||||
See :mod:`neuropose.analyzer.dtw` and :mod:`neuropose.analyzer.features`
|
||||
for per-module details; the most commonly used names are re-exported
|
||||
here for ergonomic access.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from neuropose.analyzer.dtw import (
|
||||
DTWResult,
|
||||
dtw_all,
|
||||
dtw_per_joint,
|
||||
dtw_relation,
|
||||
)
|
||||
from neuropose.analyzer.features import (
|
||||
FeatureStatistics,
|
||||
extract_feature_statistics,
|
||||
extract_joint_angles,
|
||||
find_peaks,
|
||||
normalize_pose_sequence,
|
||||
pad_sequences,
|
||||
predictions_to_numpy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DTWResult",
|
||||
"FeatureStatistics",
|
||||
"dtw_all",
|
||||
"dtw_per_joint",
|
||||
"dtw_relation",
|
||||
"extract_feature_statistics",
|
||||
"extract_joint_angles",
|
||||
"find_peaks",
|
||||
"normalize_pose_sequence",
|
||||
"pad_sequences",
|
||||
"predictions_to_numpy",
|
||||
]
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
"""Dynamic Time Warping helpers for pose sequence comparison.
|
||||
|
||||
Three entry points, ordered by increasing precision (and increasing cost):
|
||||
|
||||
- :func:`dtw_all` — DTW on the flattened per-frame joint vector. Fast but
|
||||
coarse; collapses every joint axis into a single per-frame vector.
|
||||
- :func:`dtw_per_joint` — DTW on each joint independently. Preserves
|
||||
per-joint temporal alignment at the cost of one DTW call per joint.
|
||||
- :func:`dtw_relation` — DTW on the displacement vector between two
|
||||
specific joints. This is the right tool when the research question is
|
||||
about the *relative* motion of a specific pair of joints (e.g. the
|
||||
hand-to-hip vector during a reach-and-grasp trial).
|
||||
|
||||
All three return a :class:`DTWResult` dataclass with the DTW distance
|
||||
and the warping path. Inputs are expected to be ``(frames, joints, 3)``
|
||||
numpy arrays — the shape :func:`~neuropose.analyzer.features.predictions_to_numpy`
|
||||
produces.
|
||||
|
||||
Dependency note
|
||||
---------------
|
||||
This module requires :mod:`fastdtw` and :mod:`scipy`, which are part of
|
||||
the ``analysis`` optional extra. Imports are performed lazily inside
|
||||
:func:`_require_fastdtw` so that ``import neuropose.analyzer.dtw``
|
||||
succeeds even when the extra is not installed; the error surfaces with
|
||||
a clear installation hint the first time a DTW function is actually
|
||||
called.
|
||||
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class DTWResult:
    """Result of a single DTW computation.

    Attributes
    ----------
    distance
        Scalar DTW distance between the two input sequences.
    path
        Warping path as a list of ``(i, j)`` index pairs, where ``i`` is
        an index into the first sequence and ``j`` is an index into the
        second.
    """

    distance: float
    path: list[tuple[int, int]]


def _require_fastdtw() -> tuple[Callable, Callable]:
    """Lazily import fastdtw and scipy.spatial.distance.euclidean.

    Returns
    -------
    tuple
        ``(fastdtw_callable, euclidean_callable)``.

    Raises
    ------
    ImportError
        If either ``fastdtw`` or ``scipy`` is unavailable. The message
        points the user at the ``analysis`` optional-dependencies extra.
    """
    try:
        from fastdtw import fastdtw
        from scipy.spatial.distance import euclidean
    except ImportError as exc:
        raise ImportError(
            "neuropose.analyzer.dtw requires fastdtw and scipy. "
            "Install them with: pip install neuropose[analysis]"
        ) from exc
    return fastdtw, euclidean


def dtw_all(a: np.ndarray, b: np.ndarray) -> DTWResult:
    """DTW on the flattened per-frame joint vector.

    Each frame's joints are collapsed into a single vector before DTW
    is applied. This is fast — one DTW call regardless of the joint
    count — but loses per-joint temporal structure, so a small
    timing mismatch on one joint can dominate the distance metric.

    Parameters
    ----------
    a, b
        Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
        sequences do not need to have the same number of frames, but
        they must have the same number of joints.

    Returns
    -------
    DTWResult
        The DTW distance and warping path between the flattened
        sequences.

    Raises
    ------
    ValueError
        If ``a`` and ``b`` do not have the same joint count.
    """
    _validate_same_joint_count(a, b)
    fastdtw, euclidean = _require_fastdtw()
    a_flat = a.reshape(a.shape[0], -1)
    b_flat = b.reshape(b.shape[0], -1)
    distance, path = fastdtw(a_flat, b_flat, dist=euclidean)
    return DTWResult(distance=float(distance), path=[tuple(p) for p in path])


def dtw_per_joint(a: np.ndarray, b: np.ndarray) -> list[DTWResult]:
    """DTW on each joint independently.

    Performs one DTW computation per joint, yielding a list of
    :class:`DTWResult` objects in joint-index order. More precise than
    :func:`dtw_all` because each joint's temporal alignment is optimised
    separately, at the cost of J DTW calls for J joints instead of one.

    Parameters
    ----------
    a, b
        Pose sequences as ``(frames, joints, 3)`` numpy arrays. The two
        sequences do not need to have the same number of frames but
        must have the same number of joints.

    Returns
    -------
    list[DTWResult]
        One DTW result per joint, in index order.

    Raises
    ------
    ValueError
        If ``a`` and ``b`` do not have the same joint count.
    """
    _validate_same_joint_count(a, b)
    fastdtw, euclidean = _require_fastdtw()
    results: list[DTWResult] = []
    for joint_idx in range(a.shape[1]):
        a_joint = a[:, joint_idx, :]
        b_joint = b[:, joint_idx, :]
        distance, path = fastdtw(a_joint, b_joint, dist=euclidean)
        results.append(
            DTWResult(distance=float(distance), path=[tuple(p) for p in path])
        )
    return results


def dtw_relation(
    a: np.ndarray,
    b: np.ndarray,
    joint_i: int,
    joint_j: int,
) -> DTWResult:
    """DTW on the displacement vector between two specific joints.

    For each frame, the input is reduced to the vector from ``joint_i``
    to ``joint_j``. DTW is then applied to the two sequences of
    displacement vectors. This is the right tool when the question is
    "how does the relationship between joint A and joint B change over
    time?" — for example, "does the subject's hand track a consistent
    distance from the hip during the reach trial?"

    Parameters
    ----------
    a, b
        Pose sequences as ``(frames, joints, 3)`` numpy arrays.
    joint_i, joint_j
        Indices of the two joints whose relative position should be
        compared. Must be valid indices into ``a`` and ``b``'s joint
        axis.

    Returns
    -------
    DTWResult
        DTW distance and path between the two displacement sequences.

    Raises
    ------
    ValueError
        If the sequences have different joint counts or if either joint
        index is out of range.
    """
    _validate_same_joint_count(a, b)
    num_joints = a.shape[1]
    if not (0 <= joint_i < num_joints) or not (0 <= joint_j < num_joints):
        raise ValueError(
            f"joint indices must be in [0, {num_joints}); "
            f"got joint_i={joint_i}, joint_j={joint_j}"
        )
    fastdtw, euclidean = _require_fastdtw()
    disp_a = a[:, joint_j, :] - a[:, joint_i, :]
    disp_b = b[:, joint_j, :] - b[:, joint_i, :]
    distance, path = fastdtw(disp_a, disp_b, dist=euclidean)
    return DTWResult(distance=float(distance), path=[tuple(p) for p in path])


def _validate_same_joint_count(a: np.ndarray, b: np.ndarray) -> None:
    """Raise :class:`ValueError` if ``a`` and ``b`` disagree on joint count."""
    if a.ndim != 3 or b.ndim != 3:
        raise ValueError(
            f"expected 3D arrays of shape (frames, joints, 3); "
            f"got a.ndim={a.ndim}, b.ndim={b.ndim}"
        )
    if a.shape[1] != b.shape[1]:
        raise ValueError(
            "input arrays disagree on joint count: "
            f"a has {a.shape[1]} joints, b has {b.shape[1]} joints"
        )
|
|
@ -0,0 +1,386 @@
|
|||
"""Feature extraction helpers for pose sequences.
|
||||
|
||||
All functions in this module operate on numpy arrays of shape
|
||||
``(frames, joints, 3)`` — the output of
|
||||
:func:`predictions_to_numpy`. They are pure functions: none of them
|
||||
mutate their inputs, and none of them touch the filesystem or the
|
||||
model.
|
||||
|
||||
The following helpers are provided:
|
||||
|
||||
- :func:`predictions_to_numpy` — convert a validated
|
||||
:class:`~neuropose.io.VideoPredictions` into a numpy pose sequence.
|
||||
- :func:`normalize_pose_sequence` — scale a sequence so joint positions
|
||||
fit in the unit cube (either per-axis or uniform).
|
||||
- :func:`pad_sequences` — edge-pad a batch of sequences to a common
|
||||
length, suitable for downstream tensor-based analysis.
|
||||
- :func:`extract_joint_angles` — compute joint angles at specified
|
||||
triplet positions across a pose sequence.
|
||||
- :func:`extract_feature_statistics` — summary statistics
|
||||
(mean / std / min / max / range) for a 1D feature series.
|
||||
- :func:`find_peaks` — thin :mod:`scipy.signal` wrapper returning only
|
||||
the peak indices.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from neuropose.io import VideoPredictions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VideoPredictions → numpy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def predictions_to_numpy(
|
||||
predictions: VideoPredictions,
|
||||
*,
|
||||
person_index: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""Convert a :class:`VideoPredictions` to a 3D pose sequence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predictions
|
||||
The predictions to convert.
|
||||
    person_index
        Which detected person to extract per frame. Defaults to ``0``
        (the first detected person), which matches the single-subject
        clinical case. Frames with fewer than ``person_index + 1``
        detections cause a :class:`ValueError`.

    Returns
    -------
    numpy.ndarray
        A ``(frames, joints, 3)`` array in the same physical units as
        the underlying predictions (millimetres for MeTRAbs output).

    Raises
    ------
    ValueError
        If any frame lacks sufficient detections for ``person_index``,
        or if the predictions contain no frames.
    """
    if len(predictions) == 0:
        raise ValueError("predictions contains zero frames")
    frames: list[list[list[float]]] = []
    for frame_name in predictions.frame_names():
        per_person = predictions[frame_name].poses3d
        if person_index >= len(per_person):
            raise ValueError(
                f"frame {frame_name} has {len(per_person)} detections; "
                f"person_index={person_index} is out of range"
            )
        frames.append(per_person[person_index])
    return np.asarray(frames, dtype=float)
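

# A minimal usage sketch, illustrative only: round-tripping saved
# predictions into an array. The file name is hypothetical;
# ``model_validate_json`` is the standard pydantic-v2 entry point (the
# test suite uses its sibling ``model_validate``).
def _example_predictions_to_numpy() -> None:
    from pathlib import Path

    raw = Path("session_predictions.json").read_text()
    predictions = VideoPredictions.model_validate_json(raw)
    poses = predictions_to_numpy(predictions)  # (frames, joints, 3), mm
    print(poses.shape)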


# ---------------------------------------------------------------------------
# Normalization
# ---------------------------------------------------------------------------


def normalize_pose_sequence(
    sequence: np.ndarray,
    *,
    axis_wise: bool = False,
) -> np.ndarray:
    """Translate and scale a pose sequence so joints fit in the unit cube.

    The minimum coordinate along each spatial axis is subtracted, and
    the result is divided by a single scalar (the range of the largest
    axis, when ``axis_wise=False``) or by the per-axis range (when
    ``axis_wise=True``).

    Parameters
    ----------
    sequence
        Array of shape ``(frames, joints, 3)``.
    axis_wise
        If ``False`` (default), preserve the geometric aspect ratio by
        using a single scalar denominator (the maximum axis extent).
        If ``True``, scale each axis independently to ``[0, 1]``, which
        distorts the geometry but guarantees full-range normalization
        on every axis.

    Returns
    -------
    numpy.ndarray
        A new array of the same shape as ``sequence``, with joint
        positions translated to start at the origin and scaled as
        described above. The input is not modified.

    Raises
    ------
    ValueError
        If ``sequence`` is not a 3D array with a final axis of size 3,
        or if the sequence is degenerate (zero extent on every axis).
    """
    if sequence.ndim != 3 or sequence.shape[-1] != 3:
        raise ValueError(
            f"expected (frames, joints, 3); got shape {sequence.shape}"
        )
    result = sequence.astype(float, copy=True)
    mins = result.reshape(-1, 3).min(axis=0)
    maxs = result.reshape(-1, 3).max(axis=0)
    ranges = maxs - mins

    if np.all(ranges == 0):
        raise ValueError("cannot normalize a degenerate (zero-extent) sequence")

    result -= mins  # broadcasts over (frames, joints, 3)

    if axis_wise:
        # Replace zero ranges with 1 to avoid division-by-zero on axes
        # where all joints share a coordinate; those axes will remain 0.
        safe_ranges = np.where(ranges == 0, 1.0, ranges)
        result = result / safe_ranges
    else:
        scale = float(ranges.max())
        result = result / scale

    return result
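

# A minimal sketch, illustrative only: contrasting the two normalization
# modes on the same sequence, mirroring the behaviour documented above.
def _example_normalize_modes() -> None:
    seq = np.array([[[0.0, 0.0, 0.0]], [[3.0, 6.0, 9.0]]])
    uniform = normalize_pose_sequence(seq)  # everything divided by 9
    per_axis = normalize_pose_sequence(seq, axis_wise=True)
    print(uniform[1, 0])   # [1/3, 2/3, 1.0]: aspect ratio preserved
    print(per_axis[1, 0])  # [1.0, 1.0, 1.0]: each axis stretched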


# ---------------------------------------------------------------------------
# Padding
# ---------------------------------------------------------------------------


def pad_sequences(
    sequences: Sequence[np.ndarray],
    *,
    target_length: int | None = None,
) -> list[np.ndarray]:
    """Edge-pad a list of pose sequences to a common length.

    Each input sequence is extended by repeating its last frame until
    it reaches ``target_length``. Sequences that are already longer
    than ``target_length`` are **truncated** to that length. The input
    list itself and the input arrays are never mutated.

    Parameters
    ----------
    sequences
        List of ``(frames_i, joints, 3)`` arrays. ``frames_i`` may
        differ per sequence, but all sequences must share the same
        joint count and spatial dimensionality.
    target_length
        Desired number of frames. If ``None``, the maximum
        ``frames_i`` in the input is used.

    Returns
    -------
    list[numpy.ndarray]
        A new list of arrays, each with ``target_length`` frames.

    Raises
    ------
    ValueError
        If ``sequences`` is empty and ``target_length`` is ``None``, or
        if any sequence has mismatched trailing dimensions.
    """
    if not sequences:
        if target_length is None:
            raise ValueError(
                "cannot infer target_length from an empty sequence list"
            )
        return []

    first = sequences[0]
    trailing_shape = first.shape[1:]
    for idx, seq in enumerate(sequences):
        if seq.shape[1:] != trailing_shape:
            raise ValueError(
                f"sequence {idx} has trailing shape {seq.shape[1:]}; "
                f"expected {trailing_shape}"
            )

    length = target_length if target_length is not None else max(
        s.shape[0] for s in sequences
    )

    padded: list[np.ndarray] = []
    for seq in sequences:
        if seq.shape[0] == length:
            padded.append(seq.copy())
        elif seq.shape[0] > length:
            padded.append(seq[:length].copy())
        else:
            pad_amount = length - seq.shape[0]
            padding = [(0, pad_amount)] + [(0, 0)] * (seq.ndim - 1)
            padded.append(np.pad(seq, padding, mode="edge"))
    return padded
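

# A minimal sketch, illustrative only: padding a ragged batch so it can
# be stacked into a single tensor, which is the intended downstream use.
def _example_pad_and_stack() -> None:
    short = np.zeros((30, 43, 3))
    long = np.zeros((48, 43, 3))
    batch = np.stack(pad_sequences([short, long]))  # (2, 48, 43, 3)
    print(batch.shape)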


# ---------------------------------------------------------------------------
# Joint angles
# ---------------------------------------------------------------------------


def extract_joint_angles(
    sequence: np.ndarray,
    triplets: Sequence[tuple[int, int, int]],
) -> np.ndarray:
    """Compute angles at specified joints across a pose sequence.

    For each triplet ``(a, b, c)``, the angle at ``b`` is defined as
    the angle between the vectors ``(a - b)`` and ``(c - b)``, in
    radians. Frames where either vector has zero length (e.g. joint
    degeneracy) produce ``NaN`` in the output rather than raising.

    Parameters
    ----------
    sequence
        Array of shape ``(frames, joints, 3)``.
    triplets
        Iterable of ``(a, b, c)`` joint index tuples. Each index must
        be in ``[0, joints)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(frames, len(triplets))``, where each column
        is the time-series of angles at the corresponding triplet's
        centre joint. Angles are in radians, in ``[0, pi]``.

    Raises
    ------
    ValueError
        If ``sequence`` is not of shape ``(frames, joints, 3)``, or if
        any joint index in ``triplets`` is out of range.
    """
    if sequence.ndim != 3 or sequence.shape[-1] != 3:
        raise ValueError(
            f"expected (frames, joints, 3); got shape {sequence.shape}"
        )
    num_joints = sequence.shape[1]
    columns: list[np.ndarray] = []
    for a_idx, b_idx, c_idx in triplets:
        for idx in (a_idx, b_idx, c_idx):
            if not (0 <= idx < num_joints):
                raise ValueError(
                    f"joint index {idx} out of range [0, {num_joints})"
                )
        v1 = sequence[:, a_idx, :] - sequence[:, b_idx, :]
        v2 = sequence[:, c_idx, :] - sequence[:, b_idx, :]
        n1 = np.linalg.norm(v1, axis=1)
        n2 = np.linalg.norm(v2, axis=1)
        with np.errstate(invalid="ignore", divide="ignore"):
            cosine = np.sum(v1 * v2, axis=1) / (n1 * n2)
        cosine = np.clip(cosine, -1.0, 1.0)
        angle = np.arccos(cosine)
        columns.append(angle)
    return np.stack(columns, axis=1)
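

# A minimal sketch, illustrative only: extracting one elbow-like angle
# series and summarizing it. The (shoulder, elbow, wrist) indices are
# hypothetical; real values depend on the skeleton in use.
def _example_elbow_angle_series() -> None:
    rng = np.random.default_rng(0)
    poses = rng.standard_normal((120, 43, 3))  # stand-in pose sequence
    angles = extract_joint_angles(poses, triplets=[(12, 13, 14)])
    elbow = angles[:, 0]  # radians, one value per frame
    print(np.degrees(np.nanmean(elbow)))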


# ---------------------------------------------------------------------------
# Summary statistics
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class FeatureStatistics:
    """Summary statistics for a 1D feature series.

    Attributes
    ----------
    mean, std, min, max
        Standard summary statistics of the input values.
    range
        ``max - min``. Precomputed for convenience.
    """

    mean: float
    std: float
    min: float
    max: float
    range: float


def extract_feature_statistics(values: np.ndarray) -> FeatureStatistics:
    """Compute summary statistics for a 1D feature series.

    Parameters
    ----------
    values
        A 1D numpy array. Higher-dimensional inputs are rejected to
        keep the semantics unambiguous — callers that want per-column
        statistics should reduce along their axis of interest first.

    Returns
    -------
    FeatureStatistics
        Mean, standard deviation, minimum, maximum, and range of
        ``values``.

    Raises
    ------
    ValueError
        If ``values`` is not 1D or is empty.
    """
    if values.ndim != 1:
        raise ValueError(f"expected 1D array; got shape {values.shape}")
    if values.size == 0:
        raise ValueError("cannot compute statistics of an empty array")
    mn = float(values.min())
    mx = float(values.max())
    return FeatureStatistics(
        mean=float(values.mean()),
        std=float(values.std()),
        min=mn,
        max=mx,
        range=mx - mn,
    )
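

# A minimal sketch, illustrative only: summarizing an angle series such
# as the output of ``extract_joint_angles``. Note that NaN values (e.g.
# from degenerate frames) propagate into the statistics, so filter them
# out first if needed.
def _example_feature_summary() -> None:
    angles = np.array([0.8, 1.1, 1.4, 1.2, 0.9])  # radians
    stats = extract_feature_statistics(angles)
    print(f"mean={stats.mean:.2f} range={stats.range:.2f}")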


# ---------------------------------------------------------------------------
# Peak finding
# ---------------------------------------------------------------------------


def find_peaks(values: np.ndarray, **kwargs: object) -> np.ndarray:
    """Return indices of local maxima in a 1D series.

    Thin wrapper around :func:`scipy.signal.find_peaks` that returns
    just the peak-index array (scipy's function also returns a
    properties dict, which callers rarely need).

    Parameters
    ----------
    values
        1D numpy array of feature values (e.g. a joint's Y-coordinate
        across frames).
    **kwargs
        Forwarded to :func:`scipy.signal.find_peaks`. Common options
        include ``height``, ``threshold``, ``distance``, and
        ``prominence``.

    Returns
    -------
    numpy.ndarray
        1D integer array of peak indices, in ascending order.

    Raises
    ------
    ImportError
        If scipy is not installed. The error message points at the
        ``analysis`` optional extra.
    ValueError
        If ``values`` is not 1D.
    """
    if values.ndim != 1:
        raise ValueError(f"expected 1D array; got shape {values.shape}")
    try:
        from scipy.signal import find_peaks as _sp_find_peaks
    except ImportError as exc:
        raise ImportError(
            "neuropose.analyzer.features.find_peaks requires scipy. "
            "Install it with: pip install neuropose[analysis]"
        ) from exc
    indices, _properties = _sp_find_peaks(values, **kwargs)
    return indices
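

# A minimal sketch, illustrative only: counting repetitions in a
# vertical trajectory by peak-finding with a prominence threshold. The
# signal and threshold below are hypothetical stand-ins.
def _example_count_repetitions() -> None:
    t = np.linspace(0.0, 4.0 * np.pi, 200)
    height_series = np.sin(t)  # stand-in for a joint's Y-coordinate
    peaks = find_peaks(height_series, prominence=0.5)
    print(f"{len(peaks)} repetitions detected")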

@@ -12,6 +12,42 @@ import numpy as np
import pytest


# ---------------------------------------------------------------------------
# Slow test opt-in
# ---------------------------------------------------------------------------


def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the ``--runslow`` command-line flag.

    Tests marked ``@pytest.mark.slow`` (typically the integration tests
    under ``tests/integration/`` that download the MeTRAbs model) are
    skipped by default and run only when ``--runslow`` is passed. This
    keeps the default ``pytest`` invocation fast and offline-safe, and
    keeps CI's default test job from burning minutes on a 2 GB download
    on every push.
    """
    parser.addoption(
        "--runslow",
        action="store_true",
        default=False,
        help="run tests marked @pytest.mark.slow (model download required)",
    )


def pytest_collection_modifyitems(
    config: pytest.Config,
    items: list[pytest.Item],
) -> None:
    """Skip ``@slow`` tests unless ``--runslow`` was given on the command line."""
    if config.getoption("--runslow"):
        return
    skip_slow = pytest.mark.skip(reason="need --runslow to run slow tests")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
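

# A minimal sketch, illustrative only: how a test opts in to the slow
# lane. The test below is hypothetical; any test that needs the
# downloaded model should carry the marker, and it will then run only
# under ``pytest --runslow``.
#
#     @pytest.mark.slow
#     def test_model_download_roundtrip(tmp_path):
#         from neuropose._model import load_metrabs_model
#         model = load_metrabs_model(cache_dir=tmp_path)
#         assert model is not None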


# ---------------------------------------------------------------------------
# Environment isolation
# ---------------------------------------------------------------------------

@@ -0,0 +1,152 @@
"""End-to-end smoke test for the MeTRAbs model loader and estimator.

This module lives under ``tests/integration/`` and every test in it is
marked ``@pytest.mark.slow``. That means:

- ``pytest`` with no flags **skips** these tests. The conftest hook
  in ``tests/conftest.py`` requires the ``--runslow`` flag to run
  anything marked ``slow``.
- ``pytest --runslow`` runs them. On a cold cache this triggers a
  ~2 GB download of the MeTRAbs model tarball from its upstream URL;
  on a warm cache (subsequent runs with the same ``cache_dir``) it
  completes in seconds.

The intent of this file is **plumbing verification**, not accuracy
benchmarking:

- Does the loader download, verify, extract, and load the tarball?
- Does the loader's second call hit the cache without re-downloading?
- Does the estimator run end-to-end against a real MeTRAbs model on a
  synthetic video and produce a valid :class:`VideoPredictions` object?

Accuracy against reference pose data is out of scope here and will
live in a separate benchmark harness once the project has data it is
cleared to test against.
"""

from __future__ import annotations

from pathlib import Path

import cv2
import numpy as np
import pytest

from neuropose._model import load_metrabs_model
from neuropose.estimator import Estimator
from neuropose.io import FramePrediction, VideoPredictions

pytestmark = pytest.mark.slow


@pytest.fixture
def integration_video(tmp_path: Path) -> Path:
    """Generate a 384×288 synthetic video sized for MeTRAbs input.

    The default ``synthetic_video`` fixture in ``tests/conftest.py``
    produces 32×32 frames, which is too small for MeTRAbs's 384 px
    input and may cause the detector pipeline to short-circuit
    unpredictably. This fixture produces a modestly-sized video so
    the smoke test's plumbing assertions are meaningful.
    """
    path = tmp_path / "integration.avi"
    fourcc = cv2.VideoWriter_fourcc(*"MJPG")
    writer = cv2.VideoWriter(str(path), fourcc, 30.0, (384, 288))
    assert writer.isOpened(), "cv2.VideoWriter failed to open; MJPG codec missing?"
    for i in range(5):
        # Flat gray with a shifting offset per frame. There are no
        # humans in the frame; MeTRAbs should produce zero detections
        # per frame but the pipeline must still return valid structures.
        frame = np.full((288, 384, 3), 100 + i * 10, dtype=np.uint8)
        writer.write(frame)
    writer.release()
    assert path.exists() and path.stat().st_size > 0
    return path


@pytest.fixture(scope="session")
def shared_model_cache_dir(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Session-scoped cache dir so the model is downloaded at most once per run.

    Without this, each test in the file would trigger a fresh download
    because the default ``tmp_path`` is function-scoped. The session
    scope means the first test pays for the download and subsequent
    tests load from the cache.
    """
    return tmp_path_factory.mktemp("neuropose_model_cache")


class TestMetrabsLoader:
    """Exercises the loader's download → verify → extract → load path."""

    def test_download_and_load(self, shared_model_cache_dir: Path) -> None:
        model = load_metrabs_model(cache_dir=shared_model_cache_dir)
        assert model is not None
        for attr in ("detect_poses", "per_skeleton_joint_names", "per_skeleton_joint_edges"):
            assert hasattr(model, attr), f"loaded model is missing {attr}"

    def test_second_call_uses_cache(self, shared_model_cache_dir: Path) -> None:
        """Idempotent: second call should return the cached model cheaply."""
        model_a = load_metrabs_model(cache_dir=shared_model_cache_dir)
        model_b = load_metrabs_model(cache_dir=shared_model_cache_dir)
        # tf.saved_model.load returns a new Python object each call, so
        # identity comparison doesn't work — but both should still
        # expose the MeTRAbs interface.
        assert hasattr(model_a, "detect_poses")
        assert hasattr(model_b, "detect_poses")

    def test_berkeley_mhad_skeleton_is_present(
        self, shared_model_cache_dir: Path
    ) -> None:
        """The estimator pins skeleton='berkeley_mhad_43'; verify it exists."""
        model = load_metrabs_model(cache_dir=shared_model_cache_dir)
        joint_names = model.per_skeleton_joint_names["berkeley_mhad_43"]
        joint_edges = model.per_skeleton_joint_edges["berkeley_mhad_43"]
        # MeTRAbs exposes these as tf.Tensor objects; just verify we
        # can pull a shape out.
        assert joint_names.shape[0] == 43
        assert joint_edges.shape[0] > 0


class TestEndToEndInference:
    """Runs the estimator against a real model on a synthetic video."""

    def test_estimator_produces_valid_predictions(
        self,
        integration_video: Path,
        shared_model_cache_dir: Path,
    ) -> None:
        model = load_metrabs_model(cache_dir=shared_model_cache_dir)
        estimator = Estimator(model=model)

        result = estimator.process_video(integration_video)

        assert isinstance(result.predictions, VideoPredictions)
        assert result.frame_count == 5
        assert result.predictions.metadata.width == 384
        assert result.predictions.metadata.height == 288

        # Each frame's predictions must validate as a FramePrediction,
        # regardless of whether MeTRAbs detects any people in it.
        for frame_name in result.predictions.frame_names():
            frame = result.predictions[frame_name]
            assert isinstance(frame, FramePrediction)
            assert isinstance(frame.boxes, list)
            assert isinstance(frame.poses3d, list)
            assert isinstance(frame.poses2d, list)

    def test_progress_callback_invoked_per_frame(
        self,
        integration_video: Path,
        shared_model_cache_dir: Path,
    ) -> None:
        model = load_metrabs_model(cache_dir=shared_model_cache_dir)
        estimator = Estimator(model=model)

        processed_counts: list[int] = []
        estimator.process_video(
            integration_video,
            progress=lambda processed, _total: processed_counts.append(processed),
        )

        assert processed_counts == [1, 2, 3, 4, 5]

@@ -0,0 +1,142 @@
"""Tests for :mod:`neuropose.analyzer.dtw`."""

from __future__ import annotations

import numpy as np
import pytest

from neuropose.analyzer.dtw import (
    DTWResult,
    dtw_all,
    dtw_per_joint,
    dtw_relation,
)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def simple_sequence() -> np.ndarray:
    """A random 5-frame, 3-joint sequence (seeded for reproducibility)."""
    rng = np.random.default_rng(seed=42)
    return rng.standard_normal((5, 3, 3))


# ---------------------------------------------------------------------------
# dtw_all
# ---------------------------------------------------------------------------


class TestDtwAll:
    def test_identical_sequences_distance_zero(
        self, simple_sequence: np.ndarray
    ) -> None:
        result = dtw_all(simple_sequence, simple_sequence)
        assert isinstance(result, DTWResult)
        assert result.distance == pytest.approx(0.0, abs=1e-9)
        # Identical sequences produce a diagonal warping path.
        assert all(i == j for i, j in result.path)

    def test_shifted_sequences_distance_zero(
        self, simple_sequence: np.ndarray
    ) -> None:
        """DTW should absorb a pure time shift without penalty."""
        # Duplicate the first frame to create a one-frame shift.
        shifted = np.concatenate([simple_sequence[:1], simple_sequence], axis=0)
        result = dtw_all(simple_sequence, shifted)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_different_sequences_positive_distance(self) -> None:
        a = np.zeros((5, 3, 3))
        b = np.ones((5, 3, 3))
        result = dtw_all(a, b)
        assert result.distance > 0.0

    def test_mismatched_joint_count_rejected(self) -> None:
        a = np.zeros((5, 3, 3))
        b = np.zeros((5, 4, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_all(a, b)

    def test_non_3d_input_rejected(self) -> None:
        a = np.zeros((5, 3))  # missing trailing axis
        b = np.zeros((5, 3))
        with pytest.raises(ValueError, match="expected 3D"):
            dtw_all(a, b)


# ---------------------------------------------------------------------------
# dtw_per_joint
# ---------------------------------------------------------------------------


class TestDtwPerJoint:
    def test_returns_one_result_per_joint(
        self, simple_sequence: np.ndarray
    ) -> None:
        results = dtw_per_joint(simple_sequence, simple_sequence)
        assert len(results) == simple_sequence.shape[1]
        for result in results:
            assert isinstance(result, DTWResult)
            assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_independent_joint_distances(self) -> None:
        # Construct two sequences where joint 0 matches exactly but
        # joint 1 is offset by a constant. Per-joint DTW should give
        # distance 0 for joint 0 and distance > 0 for joint 1.
        a = np.array(
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]],
            ]
        )
        b = a.copy()
        b[:, 1, :] += 10.0
        results = dtw_per_joint(a, b)
        assert results[0].distance == pytest.approx(0.0, abs=1e-9)
        assert results[1].distance > 0.0

    def test_mismatched_joint_count_rejected(self) -> None:
        a = np.zeros((5, 3, 3))
        b = np.zeros((5, 2, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_per_joint(a, b)


# ---------------------------------------------------------------------------
# dtw_relation
# ---------------------------------------------------------------------------


class TestDtwRelation:
    def test_identical_sequences_distance_zero(
        self, simple_sequence: np.ndarray
    ) -> None:
        result = dtw_relation(simple_sequence, simple_sequence, joint_i=0, joint_j=1)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_same_relative_position_is_zero_even_under_translation(self) -> None:
        """Translating the whole body does not change the
        joint-to-joint displacement, so dtw_relation should be 0."""
        a = np.zeros((4, 3, 3))
        a[:, 0, :] = [0.0, 0.0, 0.0]
        a[:, 1, :] = [1.0, 0.0, 0.0]
        a[:, 2, :] = [0.0, 1.0, 0.0]
        b = a + 50.0  # translate the whole body
        result = dtw_relation(a, b, joint_i=0, joint_j=1)
        assert result.distance == pytest.approx(0.0, abs=1e-9)

    def test_joint_index_out_of_range_rejected(self) -> None:
        a = np.zeros((3, 2, 3))
        b = np.zeros((3, 2, 3))
        with pytest.raises(ValueError, match="joint indices"):
            dtw_relation(a, b, joint_i=0, joint_j=5)

    def test_mismatched_joint_count_rejected(self) -> None:
        a = np.zeros((3, 3, 3))
        b = np.zeros((3, 2, 3))
        with pytest.raises(ValueError, match="joint count"):
            dtw_relation(a, b, joint_i=0, joint_j=1)

@@ -0,0 +1,303 @@
"""Tests for :mod:`neuropose.analyzer.features`."""

from __future__ import annotations

import math

import numpy as np
import pytest

from neuropose.analyzer.features import (
    FeatureStatistics,
    extract_feature_statistics,
    extract_joint_angles,
    find_peaks,
    normalize_pose_sequence,
    pad_sequences,
    predictions_to_numpy,
)
from neuropose.io import VideoPredictions


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _make_predictions(num_frames: int, num_persons: int = 1) -> VideoPredictions:
    """Build a minimal VideoPredictions object for tests."""
    frames = {}
    for i in range(num_frames):
        frames[f"frame_{i:06d}"] = {
            "boxes": [[0.0, 0.0, 1.0, 1.0, 0.9]] * num_persons,
            "poses3d": [
                [[float(i), float(i) * 2, float(i) * 3], [0.0, 0.0, 0.0]]
            ]
            * num_persons,
            "poses2d": [[[0.0, 0.0], [1.0, 1.0]]] * num_persons,
        }
    return VideoPredictions.model_validate(
        {
            "metadata": {
                "frame_count": num_frames,
                "fps": 30.0,
                "width": 640,
                "height": 480,
            },
            "frames": frames,
        }
    )


# ---------------------------------------------------------------------------
# predictions_to_numpy
# ---------------------------------------------------------------------------


class TestPredictionsToNumpy:
    def test_single_person_shape(self) -> None:
        predictions = _make_predictions(num_frames=4)
        arr = predictions_to_numpy(predictions)
        assert arr.shape == (4, 2, 3)
        assert arr.dtype == np.float64

    def test_values_preserved(self) -> None:
        predictions = _make_predictions(num_frames=3)
        arr = predictions_to_numpy(predictions)
        # Frame i has joint 0 at (i, 2i, 3i) per _make_predictions.
        for i in range(3):
            np.testing.assert_allclose(arr[i, 0], [i, 2 * i, 3 * i])
            np.testing.assert_allclose(arr[i, 1], [0, 0, 0])

    def test_person_index_out_of_range(self) -> None:
        predictions = _make_predictions(num_frames=2, num_persons=1)
        with pytest.raises(ValueError, match="out of range"):
            predictions_to_numpy(predictions, person_index=1)

    def test_multi_person_with_explicit_index(self) -> None:
        predictions = _make_predictions(num_frames=2, num_persons=2)
        arr = predictions_to_numpy(predictions, person_index=1)
        assert arr.shape == (2, 2, 3)

    def test_empty_predictions_raises(self) -> None:
        predictions = _make_predictions(num_frames=0)
        with pytest.raises(ValueError, match="zero frames"):
            predictions_to_numpy(predictions)


# ---------------------------------------------------------------------------
# normalize_pose_sequence
# ---------------------------------------------------------------------------


class TestNormalize:
    def test_uniform_preserves_ratio(self) -> None:
        # (frames, joints, 3) — one joint per frame, two frames.
        seq = np.array(
            [
                [[0.0, 0.0, 0.0]],
                [[3.0, 6.0, 9.0]],
            ]
        )
        # Ranges: x=3, y=6, z=9. Uniform scale = 9. All values / 9.
        result = normalize_pose_sequence(seq, axis_wise=False)
        np.testing.assert_allclose(result, seq / 9.0)

    def test_axis_wise_each_axis_to_unit_range(self) -> None:
        seq = np.array(
            [
                [[0.0, 0.0, 0.0]],
                [[3.0, 6.0, 9.0]],
            ]
        )
        result = normalize_pose_sequence(seq, axis_wise=True)
        # Per-axis normalization → each axis's max becomes 1.
        np.testing.assert_allclose(result[0, 0], [0.0, 0.0, 0.0])
        np.testing.assert_allclose(result[1, 0], [1.0, 1.0, 1.0])

    def test_does_not_mutate_input(self) -> None:
        seq = np.array([[[0.0, 0.0, 0.0]], [[1.0, 2.0, 3.0]]])
        before = seq.copy()
        normalize_pose_sequence(seq)
        np.testing.assert_array_equal(seq, before)

    def test_degenerate_sequence_rejected(self) -> None:
        seq = np.zeros((3, 2, 3))
        with pytest.raises(ValueError, match="degenerate"):
            normalize_pose_sequence(seq)

    def test_bad_shape_rejected(self) -> None:
        seq = np.zeros((3, 2))  # Missing the xyz axis.
        with pytest.raises(ValueError, match="expected"):
            normalize_pose_sequence(seq)

    def test_axis_wise_with_zero_axis_keeps_it_zero(self) -> None:
        # Sequence where the Z axis never moves — axis_wise should not
        # divide by zero; the Z column should remain at 0.
        seq = np.array(
            [
                [[0.0, 0.0, 5.0]],
                [[4.0, 8.0, 5.0]],
            ]
        )
        result = normalize_pose_sequence(seq, axis_wise=True)
        np.testing.assert_allclose(result[:, 0, 2], [0.0, 0.0])


# ---------------------------------------------------------------------------
# pad_sequences
# ---------------------------------------------------------------------------


class TestPadSequences:
    def test_pads_to_max_when_target_length_none(self) -> None:
        a = np.zeros((3, 2, 3))
        b = np.zeros((5, 2, 3))
        padded = pad_sequences([a, b])
        assert all(seq.shape[0] == 5 for seq in padded)

    def test_pads_to_explicit_target_length(self) -> None:
        a = np.zeros((3, 2, 3))
        padded = pad_sequences([a], target_length=10)
        assert padded[0].shape == (10, 2, 3)

    def test_edge_padding_repeats_last_frame(self) -> None:
        a = np.array([[[1.0, 2.0, 3.0]]])  # shape (1, 1, 3)
        padded = pad_sequences([a], target_length=4)
        # All 4 frames should equal the original single frame.
        for i in range(4):
            np.testing.assert_allclose(padded[0][i, 0], [1.0, 2.0, 3.0])

    def test_truncates_longer_than_target(self) -> None:
        a = np.zeros((10, 2, 3))
        padded = pad_sequences([a], target_length=4)
        assert padded[0].shape == (4, 2, 3)

    def test_does_not_mutate_input(self) -> None:
        a = np.zeros((3, 2, 3))
        pad_sequences([a], target_length=5)
        assert a.shape == (3, 2, 3)

    def test_mismatched_trailing_shape_rejected(self) -> None:
        a = np.zeros((3, 2, 3))
        b = np.zeros((3, 4, 3))  # Different joint count.
        with pytest.raises(ValueError, match="trailing shape"):
            pad_sequences([a, b])

    def test_empty_input_with_target(self) -> None:
        assert pad_sequences([], target_length=5) == []

    def test_empty_input_without_target_raises(self) -> None:
        with pytest.raises(ValueError, match="empty"):
            pad_sequences([])


# ---------------------------------------------------------------------------
# extract_joint_angles
# ---------------------------------------------------------------------------


class TestExtractJointAngles:
    def test_right_angle(self) -> None:
        # Three joints forming a right angle at joint 1.
        # joint 0 at (1, 0, 0), joint 1 at origin, joint 2 at (0, 1, 0).
        sequence = np.array(
            [
                [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
            ]
        )
        angles = extract_joint_angles(sequence, triplets=[(0, 1, 2)])
        assert angles.shape == (1, 1)
        assert angles[0, 0] == pytest.approx(math.pi / 2)

    def test_collinear_gives_pi(self) -> None:
        sequence = np.array(
            [
                [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]],
            ]
        )
        angles = extract_joint_angles(sequence, triplets=[(0, 1, 2)])
        assert angles[0, 0] == pytest.approx(math.pi)

    def test_multiple_triplets(self) -> None:
        sequence = np.array(
            [
                [
                    [1.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0],
                ],
            ]
        )
        # Right angle at 1 (first triplet) and right angle at 1 again
        # using joint 3 as the other arm — still 90°.
        angles = extract_joint_angles(sequence, triplets=[(0, 1, 2), (0, 1, 3)])
        assert angles.shape == (1, 2)
        assert angles[0, 0] == pytest.approx(math.pi / 2)
        assert angles[0, 1] == pytest.approx(math.pi / 2)

    def test_zero_length_vector_yields_nan(self) -> None:
        # Joints 0 and 1 coincide → v1 is the zero vector → NaN angle.
        sequence = np.array(
            [
                [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
            ]
        )
        angles = extract_joint_angles(sequence, triplets=[(0, 1, 2)])
        assert math.isnan(angles[0, 0])

    def test_out_of_range_index_rejected(self) -> None:
        sequence = np.zeros((1, 3, 3))
        with pytest.raises(ValueError, match="out of range"):
            extract_joint_angles(sequence, triplets=[(0, 1, 10)])


# ---------------------------------------------------------------------------
# extract_feature_statistics
# ---------------------------------------------------------------------------


class TestExtractFeatureStatistics:
    def test_basic_stats(self) -> None:
        values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
        stats = extract_feature_statistics(values)
        assert isinstance(stats, FeatureStatistics)
        assert stats.mean == pytest.approx(3.0)
        assert stats.min == pytest.approx(1.0)
        assert stats.max == pytest.approx(5.0)
        assert stats.range == pytest.approx(4.0)
        assert stats.std == pytest.approx(np.std(values))

    def test_rejects_2d(self) -> None:
        values = np.zeros((3, 3))
        with pytest.raises(ValueError, match="1D"):
            extract_feature_statistics(values)

    def test_rejects_empty(self) -> None:
        with pytest.raises(ValueError, match="empty"):
            extract_feature_statistics(np.array([]))


# ---------------------------------------------------------------------------
# find_peaks
# ---------------------------------------------------------------------------


class TestFindPeaks:
    def test_sine_wave_peaks(self) -> None:
        # sin(t) over two full cycles (t in [0, 4*pi]) has exactly two
        # local maxima, at t = pi/2 and t = 5*pi/2.
        t = np.linspace(0, 4 * np.pi, 401)
        values = np.sin(t)
        indices = find_peaks(values)
        assert indices.ndim == 1
        assert len(indices) == 2

    def test_flat_signal_has_no_peaks(self) -> None:
        indices = find_peaks(np.zeros(100))
        assert indices.size == 0

    def test_rejects_2d_input(self) -> None:
        with pytest.raises(ValueError, match="1D"):
            find_peaks(np.zeros((5, 5)))