JAX-native active inference agents wired into cadcad-jax's lax.scan-wrapped
simulate runner. Companion to the numpy blockference package.
What's here:
- aif.py: build_jax_agent factory (handles pymdp 1.x batch-dim ceremony)
+ make_aif_step (closure-style, returns step_fn for cadcad-jax)
- grid.py: JAX-friendly grid step (jnp.where, no Python control flow)
+ build_B for deterministic transition tensor
- examples/jax_gridworld.py: single-agent end-to-end via simulate
- examples/vmap_sweep_preference.py: preference-strength sweep scaffold
- tests/: 6 passing
End-to-end run works: on a 4x4 grid the agent reaches the target corner at t=9
in a seeded cadcad-jax simulate run. Grad-optimizing over preference C is the next payoff.
MIT licensed. Requires pymdp 1.x (the JAX rewrite), which does NOT mix with the
numpy blockference's pymdp 0.0.7.x in the same venv.
blockference-jax
JAX-native active inference agents wired into cadcad-jax.
Companion to the numpy blockference package. Uses
inferactively-pymdp >= 1.0 (JAX-native rewrite) and cadcad_jax.simulate's
lax.scan-wrapped runner.
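For readers new to the pattern, a "lax.scan-wrapped runner" threads a state through a pure step_fn(state, params, key) once per timestep and stacks the per-step outputs into a trajectory. The sketch below is hypothetical: run_scan is not cadcad_jax.simulate's actual API, and the toy step function (advance one cell with probability p_forward) only stands in for the real AIF step. It does show the closure idea that make_aif_step relies on: static config is captured by the closure, and only the state and the PRNG key flow through the scan.

```python
import jax
import jax.numpy as jnp

def run_scan(step_fn, init_state, params, key, n_steps):
    # HYPOTHETICAL runner illustrating the lax.scan pattern; not cadcad_jax.simulate's API.
    keys = jax.random.split(key, n_steps)          # one PRNG key per step
    def body(state, k):
        new_state = step_fn(state, params, k)
        return new_state, new_state                # (carry, per-step output)
    return jax.lax.scan(body, init_state, keys)    # -> (final_state, trajectory)

def make_step(width, height):
    # Static config is captured by closure, as make_aif_step does with the pymdp Agent.
    n = width * height
    def step_fn(state, params, key):
        # Toy dynamics (hypothetical): advance one cell with probability params["p_forward"].
        move = jax.random.bernoulli(key, params["p_forward"]).astype(jnp.int32)
        return {"pos": jnp.minimum(state["pos"] + move, n - 1), "t": state["t"] + 1}
    return step_fn

init = {"pos": jnp.int32(0), "t": jnp.int32(0)}
final, traj = run_scan(make_step(4, 4), init, {"p_forward": 1.0}, jax.random.PRNGKey(0), 20)
```

The carry must keep the same pytree structure and dtypes on every step, which is why both fields stay int32 here.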
Why this exists separately
pymdp 1.x is a JAX-native rewrite built on Distribution objects, JAX arrays,
batched ops, and equinox modules. Blockference targets the numpy 0.0.x API
that it replaced. Mixing the two in one venv is painful (different array types,
batch conventions, RNG semantics), so the JAX path lives in a parallel package
that shares the same intellectual lineage.
What you get
- build_jax_agent(width, height, target_idx, ...): constructs a pymdp 1.x Agent with the implicit batch dim (=1) baked in correctly. Handles the (1, n_obs, n_states), (1, n_states, n_states, n_actions) shape ceremony.
- make_aif_step(agent, width, height): returns a pure step_fn(state, params, key) → new_state that drops directly into cadcad_jax.simulate. The pymdp Agent is captured by closure; the carry holds position + prior + last action.
- jax_grid_step(idx, action, w, h): jit-friendly grid transition built with jnp.where (no Python control flow).
- build_B(w, h): deterministic transition tensor for the grid env.
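A minimal sketch of what the two grid helpers could look like, written independently of the package's grid.py. Two assumptions are not confirmed by the package: the action encoding (0=up, 1=down, 2=left, 3=right) and the B[s_next, s, a] axis order. The row-major indexing idx = y * w + x is consistent with the "pos=4 (y=1, x=0)" output in the quick start.

```python
import jax
import jax.numpy as jnp

# ASSUMED action encoding (not confirmed by the package): 0=up, 1=down, 2=left, 3=right.
N_ACTIONS = 4

def jax_grid_step(idx, action, w, h):
    # Row-major index idx = y * w + x.
    y, x = idx // w, idx % w
    # Nested jnp.where instead of if/else, so the function traces under jit/vmap/scan.
    dy = jnp.where(action == 0, -1, jnp.where(action == 1, 1, 0))
    dx = jnp.where(action == 2, -1, jnp.where(action == 3, 1, 0))
    # Clipping means a move into a wall leaves the agent where it is.
    ny = jnp.clip(y + dy, 0, h - 1)
    nx = jnp.clip(x + dx, 0, w - 1)
    return ny * w + nx

def build_B(w, h):
    n = w * h
    states, actions = jnp.arange(n), jnp.arange(N_ACTIONS)
    # nxt[s, a] = successor of state s under action a
    nxt = jax.vmap(lambda s: jax.vmap(lambda a: jax_grid_step(s, a, w, h))(actions))(states)
    # one_hot gives [s, a, s_next]; transpose to B[s_next, s, a] so each B[:, s, a] is a distribution
    return jnp.transpose(jax.nn.one_hot(nxt, n), (2, 0, 1))

# w and h are Python ints, so mark them static when jitting.
step = jax.jit(jax_grid_step, static_argnums=(2, 3))
print(int(step(4, 1, 4, 4)))   # down from (y=1, x=0) -> 8
print(build_B(4, 4).shape)     # (16, 16, 4)
```

Building B by vmapping the same step function keeps the tensor and the simulated dynamics in agreement by construction.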
Install
pip install -e .
# pulls inferactively-pymdp 1.x + cadcad-jax + jax/jaxlib
Quick start
python examples/jax_gridworld.py
grid: 4x4, target_idx: 15
trajectory (position → action):
t= 0 pos= 4 (y=1, x=0) action=1
...
t= 9 pos=15 (y=3, x=3) action=3
reached target at t=9
Why JAX matters here
cadCAD physically can't differentiate; cadcad-jax can. With pymdp 1.x's
JAX-native inference, the entire AIF planning loop is differentiable — meaning
you can grad_optimize over preference C, transition priors, or
hyperparameters. This package is the bridge.
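To make "differentiable over preference C" concrete, here is a toy sketch using only jax.grad and lax.scan. It is NOT pymdp's inference or expected free energy, and it does not use cadcad-jax's grad_optimize: a hypothetical 5-cell corridor, a soft policy that scores each action by the preference C of its successor state, and a scan that propagates the state distribution for 6 steps. Because the whole rollout is built from JAX ops, the gradient of the log-probability of ending at the goal flows back into C.

```python
import jax
import jax.numpy as jnp

N, T = 5, 6  # hypothetical 1-D corridor with 5 cells and a 6-step horizon
idx = jnp.arange(N)
# B[s_next, s, a]: action 0 steps left, action 1 steps right, clipped at the ends
B = jnp.stack([jax.nn.one_hot(jnp.clip(idx - 1, 0, N - 1), N).T,
               jax.nn.one_hot(jnp.clip(idx + 1, 0, N - 1), N).T], axis=-1)

def p_goal(C, gamma=4.0):
    # Toy soft policy: score each (s, a) by the preference C of its successor state.
    pi = jax.nn.softmax(gamma * jnp.einsum("nsa,n->sa", B, C), axis=-1)
    def step(p, _):
        # Propagate the state distribution one step under B and pi.
        return jnp.einsum("nsa,sa,s->n", B, pi, p), None
    p_T, _ = jax.lax.scan(step, jax.nn.one_hot(0, N), None, length=T)
    return p_T[N - 1]  # probability of ending in the last cell

loss = lambda C: -jnp.log(p_goal(C) + 1e-8)
grad_C = jax.grad(loss)(jnp.zeros(N))
```

From a flat C, the gradient is negative at the goal cell (raising its preference lowers the loss) and its entries sum to roughly zero, since adding a constant to C leaves the softmax policy unchanged.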
examples/vmap_sweep_preference.py shows the seeded sweep pattern; converting
it to a true vmap_sweep requires hoisting the Agent's static config out of
the closure (planned).
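The hoisting idea can be sketched in isolation: keep structural config (grid size, horizon) as plain Python constants, and pass only the swept quantity through an explicit params pytree so jax.vmap can map over it. The toy below is hypothetical and unrelated to the package's vmap_sweep_preference.py. It uses a 5-cell corridor, a one-hot preference on the last cell, and a preference strength gamma. It does not show how the pymdp Agent's own config would be split.

```python
import jax
import jax.numpy as jnp

N, T = 5, 6  # static config: plain Python ints, never traced, so vmap never sees them

def final_mass_at_goal(params):
    # Only the leaves of params are traced and swept.
    gamma = params["gamma"]
    idx = jnp.arange(N)
    # B[s_next, s, a]: action 0 steps left, action 1 steps right, clipped at the ends
    B = jnp.stack([jax.nn.one_hot(jnp.clip(idx - 1, 0, N - 1), N).T,
                   jax.nn.one_hot(jnp.clip(idx + 1, 0, N - 1), N).T], axis=-1)
    C = jax.nn.one_hot(N - 1, N)  # prefer the last cell
    pi = jax.nn.softmax(gamma * jnp.einsum("nsa,n->sa", B, C), axis=-1)
    def step(p, _):
        return jnp.einsum("nsa,sa,s->n", B, pi, p), None
    p_T, _ = jax.lax.scan(step, jax.nn.one_hot(0, N), None, length=T)
    return p_T[N - 1]

gammas = jnp.array([0.0, 1.0, 4.0, 16.0])
sweep = jax.vmap(final_mass_at_goal)({"gamma": gammas})  # one result per preference strength
```

Because gamma arrives as an argument instead of through a closure, the four runs execute as one batched computation instead of a Python loop.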
Status
- 6/6 tests passing
- Single-agent gridworld via cadcad-jax simulate: works
- vmap_sweep over preference strength: scaffolded (loops sequentially today)
- grad_optimize over goal embedding: roadmap
Layout
src/blockference_jax/
aif.py # build_jax_agent + make_aif_step (closure for cadcad-jax)
grid.py # JAX-friendly grid step + B builder
examples/
jax_gridworld.py
vmap_sweep_preference.py
tests/ # 6 passing
Known gotchas
- pymdp 1.x emits "UserWarning: A JAX array is being set as static!" from inside
  equinox when the Agent is constructed. Cosmetic; safe to ignore.
- pymdp 1.x requires every input (A, B, C, D, observations, rng_key) to carry
  the implicit batch dimension. Forgetting it produces opaque vmap errors about
  inconsistent axis sizes.
- agent.infer_states returns qs[factor] (one entry per hidden-state factor) with
  shape (batch, time, n_states). The "current belief" is qs[0][:, -1, :].
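The shape bookkeeping behind the last two gotchas can be checked without pymdp at all. The sketch below uses dummy arrays only. The unbatched shapes, the uniform B, and the qs list are illustrative stand-ins, and no pymdp function is called. It prepends the size-1 batch axis with [None] and then slices the current belief out of a (batch, time, n_states) array.

```python
import jax.numpy as jnp

n_states, n_actions, T = 16, 4, 3
A = jnp.eye(n_states)                                     # dummy unbatched (n_obs, n_states)
B = jnp.ones((n_states, n_states, n_actions)) / n_states  # dummy unbatched (n_states, n_states, n_actions)

# Prepend the implicit batch dim of size 1 expected on every input.
A_b, B_b = A[None], B[None]

# DUMMY stand-in for infer_states output: one array per hidden-state factor,
# each shaped (batch, time, n_states).
qs = [jnp.full((1, T, n_states), 1.0 / n_states)]
current_belief = qs[0][:, -1, :]  # last time step; the batch axis is kept
```

Slicing with [:, -1, :] rather than [0, -1] keeps the batch axis, so the belief can be fed back in as a batched prior.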
License
MIT.