Mock System

This document covers theseus/mock.py — a lightweight debugging tool that lets you run a Flax module's forward pass logic (including plot() implementations) without materializing real parameters or touching a GPU.


The Problem

Flax modules are frozen dataclasses. Their weights live in an external params pytree. Writing analysis or visualization code that calls sub-modules requires:

  1. A full set of initialized parameters — which may be gigabytes.
  2. A live JAX device — which may not be available on a dev machine.

The mock system sidesteps both by replacing every sub-module with a shape-inference stub that returns random arrays of the correct shape.


Mocker — Drop-in Self Replacement

Mocker is designed to be used as a drop-in replacement for self inside a Flax module method. Assign any nn.Module to it and the module is automatically wrapped into a mock:

from theseus.mock import Mocker

class MyModel(Module):
    @staticmethod
    def plot(intermediates):
        self = Mocker()           # ← replace self entirely
        self.attn  = SelfAttention()
        self.mlp   = MLP()
        ...

Mocker.__setattr__ intercepts assignments of nn.Module instances and wraps each one in a shape-inference stub. Plain Python values (strings, ints, etc.) pass through unchanged. Each wrapped module gets its own RNG key split from the Mocker's internal key so successive calls produce different values.
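The interception pattern can be sketched in a few lines. This is a hypothetical simplification (class and attribute names are illustrative, and the real stub runs shape inference rather than returning a placeholder):

```python
# Minimal sketch of __setattr__ interception. The real Mocker wraps each
# module in a shape-inference stub and gives it a split RNG key; here the
# "stub" is a trivial placeholder so the routing logic stands alone.
class MockerSketch:
    def __init__(self):
        object.__setattr__(self, "_wrapped", {})

    def __setattr__(self, name, value):
        # Duck-typed stand-in for `isinstance(value, nn.Module)`.
        if hasattr(value, "apply"):
            # Real Mocker: build a stub that eval_shapes the module and
            # materializes random arrays of the inferred output shape.
            self._wrapped[name] = value
            object.__setattr__(self, name, lambda *a, **kw: "<stub output>")
        else:
            # Plain Python values pass through unchanged.
            object.__setattr__(self, name, value)
```

The routing decision happens once, at assignment time, so calling `self.attn(...)` later costs nothing beyond the stub call itself.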

Usage

Calling a mocked module

Call a mocked module like you would the real thing. It runs shape inference through the real module code via jax.eval_shape — no parameters are allocated, no GPU is needed — then materializes random arrays of the correct shape and dtype:

self = Mocker()
self.attn = SelfAttention()
self.mlp  = MLP()

dummy = jnp.zeros((1, 16, 512))
attn_out = self.attn(dummy)    # random array, shape (1, 16, 512)
mlp_out  = self.mlp(attn_out)  # random array, correct output shape

The internal RNG key advances on each call so successive calls to the same stub return different values, which matters for code that checks for variation.
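The key-advancing behavior is the standard JAX splitting pattern. A minimal illustration in plain JAX, independent of the mock system:

```python
import jax

key = jax.random.PRNGKey(0)

# Split off a fresh subkey for each draw, keeping the advanced key.
key, sub = jax.random.split(key)
a = jax.random.normal(sub, (4,))

key, sub = jax.random.split(key)
b = jax.random.normal(sub, (4,))

# a and b differ because each draw consumed a fresh subkey.
```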

Random values are generated by dtype:

  • float → jax.random.normal
  • complex → independent normal for real and imaginary parts
  • signed int → jax.random.randint clamped to [-100, 100]
  • unsigned int → jax.random.randint clamped to [0, 100]
  • bool → jax.random.bernoulli
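The dispatch might look roughly like the following sketch (the function name is illustrative, and the exact bounds follow the list above; note jax.random.randint treats maxval as exclusive):

```python
import jax
import jax.numpy as jnp

def random_like(key, shape_dtype):
    """Sketch: materialize a random array matching a ShapeDtypeStruct."""
    shape, dtype = shape_dtype.shape, shape_dtype.dtype
    if jnp.issubdtype(dtype, jnp.floating):
        return jax.random.normal(key, shape, dtype)
    if jnp.issubdtype(dtype, jnp.complexfloating):
        # Independent normals for the real and imaginary parts.
        kr, ki = jax.random.split(key)
        real = jax.random.normal(kr, shape)
        imag = jax.random.normal(ki, shape)
        return (real + 1j * imag).astype(dtype)
    if jnp.issubdtype(dtype, jnp.unsignedinteger):
        return jax.random.randint(key, shape, 0, 101).astype(dtype)
    if jnp.issubdtype(dtype, jnp.integer):
        return jax.random.randint(key, shape, -100, 101).astype(dtype)
    if dtype == jnp.bool_:
        return jax.random.bernoulli(key, 0.5, shape)
    raise TypeError(f"unsupported dtype: {dtype}")
```

The unsigned-integer branch must be checked before the signed one, since unsigned dtypes are also subtypes of jnp.integer.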

nn.Partitioned wrappers (from nn.with_partitioning) are automatically unwrapped before shape inspection.

.intermediates() — capturing sowed values

Call .intermediates(*args, **kwargs) to get the shapes of everything the module sows into the "plots" and "intermediates" Flax collections, rather than the forward-pass output. This mirrors what the trainer captures during the val step.

self = Mocker()
self.block = ThoughtBlock()

fake_meta = self.block.intermediates(jnp.zeros((1, 16, 512)))
# fake_meta["plots"]["new_cumulative_scores"] → random array, correct shape

The primary use case is developing and testing plot() implementations on a laptop without a checkpoint. The same fake_meta dict can be passed directly to the model's plot() method:

self = Mocker()
self.block = ThoughtBlock()

fake_intermediates = self.block.intermediates(jnp.zeros((1, 16, 512)))
figs = Thoughtbubbles.plot(fake_intermediates)
figs["analysis/cum_scores"].savefig("debug.pdf")

Why eval_shape

jax.eval_shape traces the function with abstract values — shape and dtype only, no data — without dispatching any computation. It runs in roughly 10 ms, works on CPU, needs no GPU, and composes with JAX transformations (jit, vmap, pmap, etc.). Because the shapes are traced through the real module code, the mocked outputs are guaranteed to have correct shapes even for modules whose output shapes depend on their input shapes.
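The mechanism can be seen in isolation with plain jax.eval_shape — pass abstract inputs, get abstract outputs, allocate nothing:

```python
import jax
import jax.numpy as jnp

def forward(x):
    # Stand-in for a real module's forward pass.
    return jnp.dot(x, x.T)

# Abstract input: shape and dtype only, no data.
spec = jax.ShapeDtypeStruct((8, 512), jnp.float32)

out = jax.eval_shape(forward, spec)
# out is a ShapeDtypeStruct: shape (8, 8), dtype float32.
```

The mock stubs do exactly this against the real module's apply function, then fill in random arrays matching the resulting ShapeDtypeStructs.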