Mock System¶
This document covers theseus/mock.py — a lightweight debugging tool that lets you run a Flax module's forward pass logic (including plot() implementations) without materializing real parameters or touching a GPU.
The Problem¶
Flax modules are frozen dataclasses. Their weights live in an external params pytree. Writing analysis or visualization code that calls sub-modules requires:
- A full set of initialized parameters — which may be gigabytes.
- A live JAX device — which may not be available on a dev machine.
The mock system sidesteps both by replacing every sub-module with a shape-inference stub that returns random arrays of the correct shape.
Mocker — Drop-in Self Replacement¶
Mocker is designed to be used as a drop-in replacement for self inside a Flax module method. Assign any nn.Module to it and the module is automatically wrapped into a mock:
from theseus.mock import Mocker

class MyModel(Module):
    @staticmethod
    def plot(intermediates):
        self = Mocker()  # ← replace self entirely
        self.attn = SelfAttention()
        self.mlp = MLP()
        ...
Mocker.__setattr__ intercepts assignments of nn.Module instances and wraps each one in a shape-inference stub. Plain Python values (strings, ints, etc.) pass through unchanged. Each wrapped module gets its own RNG key split from the Mocker's internal key so successive calls produce different values.
Usage¶
Calling a mocked module¶
Call a mocked module like you would the real thing. It runs shape inference through the real module code via jax.eval_shape — no parameters are allocated, no GPU is needed — then materializes random arrays of the correct shape and dtype:
self = Mocker()
self.attn = SelfAttention()
self.mlp = MLP()
dummy = jnp.zeros((1, 16, 512))
attn_out = self.attn(dummy) # random array, shape (1, 16, 512)
mlp_out = self.mlp(attn_out) # random array, correct output shape
The internal RNG key advances on each call so successive calls to the same stub return different values, which matters for code that checks for variation.
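The advance-on-each-call pattern is the standard JAX split-and-consume idiom; a minimal sketch (with a made-up `CountingStub` class, not the real stub):

```python
import jax


class CountingStub:
    """Toy stub: stores a key and splits off a fresh subkey per call."""

    def __init__(self, seed=0):
        self.key = jax.random.PRNGKey(seed)

    def draw(self, shape):
        # Replace the stored key and consume the subkey, so the next
        # call draws from a different random stream.
        self.key, sub = jax.random.split(self.key)
        return jax.random.normal(sub, shape)


stub = CountingStub()
a = stub.draw((4,))
b = stub.draw((4,))
# a and b differ because each call consumed a fresh subkey
```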
Random values are generated by dtype:
- float → jax.random.normal
- complex → independent normal for the real and imaginary parts
- signed int → jax.random.randint, clamped to [-100, 100]
- unsigned int → jax.random.randint, clamped to [0, 100]
- bool → jax.random.bernoulli
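That dtype dispatch might look roughly like the sketch below. This is an illustrative reconstruction (the function name `random_like` is made up), assuming the shape/dtype spec is a jax.ShapeDtypeStruct as returned by jax.eval_shape:

```python
import jax
import jax.numpy as jnp


def random_like(key, spec):
    """Return a random array matching spec.shape / spec.dtype."""
    dtype = jnp.dtype(spec.dtype)
    if jnp.issubdtype(dtype, jnp.complexfloating):
        # Independent normals for the real and imaginary parts.
        kr, ki = jax.random.split(key)
        return (jax.random.normal(kr, spec.shape)
                + 1j * jax.random.normal(ki, spec.shape)).astype(dtype)
    if jnp.issubdtype(dtype, jnp.floating):
        return jax.random.normal(key, spec.shape, dtype)
    if dtype == jnp.bool_:
        return jax.random.bernoulli(key, shape=spec.shape)
    # Check unsigned before signed: unsignedinteger is a subtype of integer.
    if jnp.issubdtype(dtype, jnp.unsignedinteger):
        return jax.random.randint(key, spec.shape, 0, 100).astype(dtype)
    if jnp.issubdtype(dtype, jnp.integer):
        return jax.random.randint(key, spec.shape, -100, 100).astype(dtype)
    raise TypeError(f"unsupported dtype {dtype}")
```

Clamping integers to a small range keeps mocked index-like values plausible (e.g. usable as token ids) instead of overflowing downstream lookups.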
nn.Partitioned wrappers (from nn.with_partitioning) are automatically unwrapped before shape inspection.
.intermediates() — capturing sowed values¶
Call .intermediates(*args, **kwargs) to get the shapes of everything the module sows into the "plots" and "intermediates" Flax collections, rather than the forward-pass output. This mirrors what the trainer captures during the val step.
self = Mocker()
self.block = ThoughtBlock()
fake_meta = self.block.intermediates(jnp.zeros((1, 16, 512)))
# fake_meta["plots"]["new_cumulative_scores"] → random array, correct shape
The primary use case is developing and testing plot() implementations on a laptop without a checkpoint. The same fake_meta dict can be passed directly to the model's plot() method:
self = Mocker()
self.block = ThoughtBlock()
fake_intermediates = self.block.intermediates(jnp.zeros((1, 16, 512)))
figs = Thoughtbubbles.plot(fake_intermediates)
figs["analysis/cum_scores"].savefig("debug.pdf")
Why eval_shape¶
jax.eval_shape traces the function with abstract values only — no arrays are allocated and no computation is dispatched — so it runs in milliseconds, works on CPU, needs no GPU, and composes with JAX transformations (jit, vmap, pmap, etc.). Because the shapes are traced through the real module code, the mocked outputs are guaranteed to have correct shapes even for modules whose output shapes depend nontrivially on their input shapes.
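A standalone illustration of what jax.eval_shape returns (not theseus-specific code):

```python
import jax
import jax.numpy as jnp


def f(x, w):
    return jnp.tanh(x @ w)


# ShapeDtypeStructs describe arrays abstractly: shape + dtype, no data.
x = jax.ShapeDtypeStruct((1, 16, 512), jnp.float32)
w = jax.ShapeDtypeStruct((512, 128), jnp.float32)

out = jax.eval_shape(f, x, w)
# out is itself a jax.ShapeDtypeStruct: shape (1, 16, 128), dtype float32
```

The mock system only has to materialize random arrays matching these structs — the expensive part (parameters, device compute) never happens.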