JAX-native active inference agents wired into cadcad-jax (pymdp 1.x companion to blockference)
Jeff Emmett abb4b4f591 feat: blockference-jax v0.1 — pymdp 1.x + cadcad-jax integration
JAX-native active inference agents wired into cadcad-jax's lax.scan-wrapped
simulate runner. Companion to the numpy blockference package.

What's here:
  - aif.py:   build_jax_agent factory (handles pymdp 1.x batch-dim ceremony)
              + make_aif_step (closure-style, returns step_fn for cadcad-jax)
  - grid.py:  JAX-friendly grid step (jnp.where, no Python control flow)
              + build_B for deterministic transition tensor
  - examples/jax_gridworld.py:           single-agent end-to-end via simulate
  - examples/vmap_sweep_preference.py:   preference-strength sweep scaffold
  - tests/:                              6 passing

End-to-end working: 4x4 grid, agent reaches corner at t=9 via cadcad-jax
simulate seeded run. Grad-optimize over preference C is the next payoff.

MIT licensed. Requires pymdp 1.x (JAX rewrite) — does NOT mix with the
numpy blockference's pymdp 0.0.7.x in the same venv.
2026-05-09 01:44:50 -04:00

blockference-jax

JAX-native active inference agents wired into cadcad-jax.

Companion to the numpy blockference package. Uses inferactively-pymdp >= 1.0 (JAX-native rewrite) and cadcad_jax.simulate's lax.scan-wrapped runner.

Why this exists separately

pymdp 1.x is a JAX-native rewrite built on Distribution objects, JAX arrays, batched ops, and equinox modules. The numpy blockference package targets the 0.0.x API that 1.x replaced. Mixing the two in one venv is painful (different array types, batch conventions, and RNG semantics), so the JAX path lives in a parallel package that shares the same intellectual lineage.

What you get

  • build_jax_agent(width, height, target_idx, ...) — constructs a pymdp 1.x Agent with the implicit batch dimension (size 1) already in place, handling the shape ceremony for A, shape (1, n_obs, n_states), and B, shape (1, n_states, n_states, n_actions).
  • make_aif_step(agent, width, height) — returns a pure step_fn(state, params, key) → new_state that drops directly into cadcad_jax.simulate. The pymdp Agent is captured by the closure, and the carry holds the position, the prior, and the last action.
  • jax_grid_step(idx, action, w, h) — jit-friendly grid transition built with jnp.where (no Python control flow).
  • build_B(w, h) — deterministic transition tensor for the grid env.
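A minimal sketch of what a jnp.where-based grid step and its deterministic B tensor could look like. The row-major indexing idx = y * w + x matches the Quick start output (pos=4 is y=1, x=0 on the 4x4 grid), but the action encoding (0=up, 1=down, 2=left, 3=right) and clamping at the walls are assumptions for illustration, not necessarily what grid.py does. B is laid out as B[next, prev, action], one reading of the (n_states, n_states, n_actions) shape listed above.

```python
import jax
import jax.numpy as jnp


def jax_grid_step(idx, action, w, h):
    """Move one cell on a w x h grid, clamping at the walls (sketch).

    Assumes row-major indexing idx = y * w + x and a hypothetical action
    encoding 0=up, 1=down, 2=left, 3=right.
    """
    y, x = idx // w, idx % w
    # jnp.where instead of Python if/else, so traced idx/action work under jit/vmap.
    y = jnp.where(action == 0, jnp.maximum(y - 1, 0), y)
    y = jnp.where(action == 1, jnp.minimum(y + 1, h - 1), y)
    x = jnp.where(action == 2, jnp.maximum(x - 1, 0), x)
    x = jnp.where(action == 3, jnp.minimum(x + 1, w - 1), x)
    return y * w + x


def build_B(w, h, n_actions=4):
    """Deterministic transition tensor laid out as B[next, prev, action] (sketch)."""
    n = w * h
    prev = jnp.arange(n)[:, None]                  # (n, 1)
    actions = jnp.arange(n_actions)[None, :]       # (1, n_actions)
    next_idx = jax_grid_step(prev, actions, w, h)  # broadcasts to (n, n_actions)
    B = jax.nn.one_hot(next_idx, n)                # (prev, action, next)
    return jnp.transpose(B, (2, 0, 1))             # (next, prev, action)


step = jax.jit(jax_grid_step, static_argnums=(2, 3))
print(int(step(4, 3, 4, 4)))  # 5: pos 4 (y=1, x=0) moved right
B = build_B(4, 4)
print(B.shape)                # (16, 16, 4)
```

Because every operation in jax_grid_step is elementwise, build_B can evaluate all (state, action) pairs at once by broadcasting instead of looping, and each column B[:, s, a] ends up one-hot.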

Install

pip install -e .
# pulls inferactively-pymdp 1.x + cadcad-jax + jax/jaxlib

Quick start

python examples/jax_gridworld.py
grid: 4x4, target_idx: 15
trajectory (position → action):
  t= 0 pos= 4 (y=1, x=0) action=1
  ...
  t= 9 pos=15 (y=3, x=3) action=3

reached target at t=9

Why JAX matters here

cadCAD can't be differentiated through; cadcad-jax can. With pymdp 1.x's JAX-native inference on top, the entire AIF planning loop is differentiable, which means you can in principle grad_optimize over preference C, transition priors, or hyperparameters (not yet implemented; see Status). This package is the bridge.

examples/vmap_sweep_preference.py shows the seeded sweep pattern; converting it to a true vmap_sweep requires hoisting the Agent's static config out of the closure (planned).
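To make the differentiable-loop and sweep claims concrete without depending on pymdp, here is a self-contained toy: a 5-state corridor, a one-step planner that softmax-scores actions by expected preference, and a lax.scan rollout. It is not active inference (there is no epistemic term or expected free energy), it is not cadcad_jax.simulate, and make_step and success_prob are hypothetical names; only the step_fn(state, params, key) shape mirrors make_aif_step. Static config stays in the closure while the swept preference strength c travels through params, so the same function works under both jax.grad and jax.vmap.

```python
import jax
import jax.numpy as jnp


def make_step(n_states, target):
    """Hypothetical stand-in for make_aif_step: static config is closed over."""
    s = jnp.arange(n_states)
    moves = jnp.stack([jnp.clip(s - 1, 0, n_states - 1), s,
                       jnp.clip(s + 1, 0, n_states - 1)], axis=1)  # (prev, action): left, stay, right
    B = jnp.transpose(jax.nn.one_hot(moves, n_states), (2, 0, 1))  # (next, prev, action)

    def step_fn(q, params, key):
        del key  # deterministic toy; a sampling agent would consume the key
        C = -params["c"] * jnp.abs(s - target)  # log-preference peaked at the target
        q_next = jnp.einsum("nsa,s->an", B, q)  # predicted belief for each action
        policy = jax.nn.softmax(q_next @ C)     # soft, differentiable action choice
        return policy @ q_next                  # belief averaged over the policy

    return step_fn


def success_prob(c, n_states=5, target=4, horizon=6, key=jax.random.PRNGKey(0)):
    """Probability mass on the target after `horizon` steps, as a function of c."""
    step_fn = make_step(n_states, target)
    q0 = jnp.zeros(n_states).at[0].set(1.0)  # start at the left end

    def body(q, k):
        return step_fn(q, {"c": c}, k), None

    qT, _ = jax.lax.scan(body, q0, jax.random.split(key, horizon))
    return qT[target]


dp_dc = jax.grad(success_prob)(1.0)                         # d P(target) / d c
sweep = jax.vmap(success_prob)(jnp.linspace(0.0, 4.0, 5))  # one batched call over c
print(dp_dc, sweep)
```

The belief is propagated as a policy-weighted mixture rather than by sampling a discrete action, which is what keeps the rollout differentiable; a sampled-action version would need a different gradient estimator.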

Status

  • 6/6 tests passing
  • Single-agent gridworld via cadcad-jax simulate: works
  • vmap_sweep over preference strength: scaffolded (loops sequentially today)
  • grad_optimize over goal embedding: roadmap

Layout

src/blockference_jax/
  aif.py               # build_jax_agent + make_aif_step (closure for cadcad-jax)
  grid.py              # JAX-friendly grid step + B builder
examples/
  jax_gridworld.py
  vmap_sweep_preference.py
tests/                 # 6 passing

Known gotchas

  • pymdp 1.x emits "UserWarning: A JAX array is being set as static!" from inside equinox when the Agent is constructed. It is cosmetic and safe to ignore.
  • pymdp 1.x requires every input (A, B, C, D, observations, rng_key) to carry the implicit batch dimension. Forgetting it produces opaque vmap errors about inconsistent axis sizes.
  • agent.infer_states returns qs, a list indexed by hidden-state factor, where each entry has shape (batch, time, n_states). The current belief for the first factor is qs[0][:, -1, :].
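A small sketch of the batch-dimension bookkeeping using plain JAX arrays and no pymdp calls. The model arrays (an identity A, an all-zero B, flat C and D for a 16-state grid) are placeholders, and holding them as lists of arrays per modality or factor is an assumption about how you store them. The sketch shows three things: prepending the batch axis to every array in one tree_map, splitting the RNG key into a batch of one, and slicing the latest belief out of a (batch, time, n_states) array. The warnings filter uses Python's standard warnings module to hide the cosmetic equinox message.

```python
import warnings

import jax
import jax.numpy as jnp

# Silence the cosmetic equinox warning; `message` is a regex matched against
# the start of the warning text.
warnings.filterwarnings(
    "ignore", message="A JAX array is being set as static", category=UserWarning
)

n_states, n_obs, n_actions = 16, 16, 4  # 4x4 grid with fully observed position

# Placeholder unbatched model: one observation modality, one hidden-state factor.
A = [jnp.eye(n_obs, n_states)]
B = [jnp.zeros((n_states, n_states, n_actions))]
C = [jnp.zeros(n_obs)]
D = [jnp.full(n_states, 1.0 / n_states)]

# Prepend the implicit batch axis (size 1) to every array in one pass.
A, B, C, D = jax.tree_util.tree_map(lambda x: x[None, ...], (A, B, C, D))

# One way to give the RNG key a leading batch axis: split it into a batch of one.
rng_key = jax.random.split(jax.random.PRNGKey(0), 1)

# Stand-in for infer_states output: qs[factor] has shape (batch, time, n_states).
qs = [jnp.full((1, 3, n_states), 1.0 / n_states)]
current_belief = qs[0][:, -1, :]  # latest time step, batch axis kept

print(A[0].shape, B[0].shape, C[0].shape, D[0].shape, rng_key.shape, current_belief.shape)
```

Doing the expansion with a single tree_map over every input makes it harder to miss one array, which is the usual source of the inconsistent-axis-size vmap errors.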

License

MIT.