lact
Pure-functional primitives for LaCT fast weights.
LaCT (arXiv:2505.23884) maintains a small SwiGLU MLP f_W(x) = W2 (silu(W1 x) * W3 x)
whose weights W = (W1, W2, W3) are gradient-stepped once per chunk of b tokens
on the inner loss L(W) = - sum_i eta_i * <f_W(k_i), v_i>. These functions are
deliberately Flax-free: they operate on plain JAX arrays / NamedTuples so the
training path can call jax.grad(inner_loss) directly and the inference path can
plug into a Flax self.variable("fast_weights", ...) cell.
Per-sequence W is the unit of computation — every batch element runs its own
chunked scan. We expose single-sequence functions and a jax.vmap-batched
wrapper to keep both shapes legible.
FastWeights
Bases: NamedTuple
Per-sequence fast-weight matrices for a single LaCT block.
Shapes (no batch axis, single sequence):
W1: (h, d) — input-side gate projection
W2: (d, h) — output projection
W3: (h, d) — input-side up projection
where d is the model hidden dim and h is the fast-weight intermediate dim.
Batched usage adds a leading (B, ...) axis via jax.vmap.
FastMomentum
Bases: NamedTuple
Optimizer momentum buffers, shape-matched to FastWeights.
fw_zeros_like(W: FastWeights) -> FastMomentum
Zero momentum buffers with the same shape/dtype as W.
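As an illustration of the two containers and the zero-momentum helper, here is a minimal sketch. The `FastWeights` field names follow the shapes documented above; the `FastMomentum` field names (`M1`, `M2`, `M3`) and the sizes `d = 8`, `h = 16` are assumptions made for this example only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array  # (h, d) input-side gate projection
    W2: jax.Array  # (d, h) output projection
    W3: jax.Array  # (h, d) input-side up projection


class FastMomentum(NamedTuple):
    # Field names are hypothetical; the docs only say the buffers are
    # shape-matched to FastWeights.
    M1: jax.Array
    M2: jax.Array
    M3: jax.Array


def fw_zeros_like(W: FastWeights) -> FastMomentum:
    # One zero buffer per weight matrix, same shape and dtype.
    return FastMomentum(*(jnp.zeros_like(w) for w in W))


d, h = 8, 16  # illustrative sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W = FastWeights(
    W1=jax.random.normal(k1, (h, d)),
    W2=jax.random.normal(k2, (d, h)),
    W3=jax.random.normal(k3, (h, d)),
)
M = fw_zeros_like(W)
```

Because both containers are NamedTuples, JAX treats them as pytrees, so they can be passed through `jax.grad`, `jax.vmap`, and `jax.lax.scan` without extra registration.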
apply_fw(W: FastWeights, x: jax.Array) -> jax.Array
SwiGLU forward through fast weights.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| W | FastWeights | FastWeights for a single sequence (W1/W3: (h, d), W2: (d, h)). | required |
| x | Array | (..., d) input. | required |

Returns:
| Type | Description |
|---|---|
| Array | (..., d) output |
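The SwiGLU forward f_W(x) = W2 (silu(W1 x) * W3 x) can be written with row-vector inputs as follows. This is a sketch under the documented shapes, not the library's code; the name `apply_fw_sketch` and the example sizes are assumptions.

```python
import jax
import jax.numpy as jnp


def apply_fw_sketch(W1, W2, W3, x):
    # x: (..., d); W1, W3: (h, d); W2: (d, h).
    gate = jax.nn.silu(x @ W1.T)   # (..., h)
    up = x @ W3.T                  # (..., h)
    return (gate * up) @ W2.T      # (..., d)


d, h = 8, 16  # illustrative sizes
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W1 = jax.random.normal(k1, (h, d))
W2 = jax.random.normal(k2, (d, h))
W3 = jax.random.normal(k3, (h, d))
x = jax.random.normal(k4, (5, d))
y = apply_fw_sketch(W1, W2, W3, x)  # shape (5, d)
```

Multiplying by the transposed matrices keeps the leading `...` axes of `x` untouched, which is what lets the same function serve a single token, a chunk, or a whole sequence.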
inner_loss(W: FastWeights, k: jax.Array, v: jax.Array, eta: jax.Array) -> jax.Array
Negative-dot-product inner loss for one chunk of one sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| W | FastWeights | FastWeights for this sequence. | required |
| k | Array | (b, d) keys for | required |
| v | Array | (b, d) values for | required |
| eta | Array | (b,) per-token inner learning rates (zero on padded positions). | required |

Returns:
| Type | Description |
|---|---|
| Array | Scalar |
| Array | the L2 row-norm in |
| Array | chunks. |
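To make the inner loss L(W) = - sum_i eta_i * <f_W(k_i), v_i> concrete, the sketch below computes it for one chunk and differentiates it with `jax.grad`. The weights are passed as a plain tuple `(W1, W2, W3)`, and the function name and sizes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def inner_loss_sketch(W, k, v, eta):
    W1, W2, W3 = W
    pred = (jax.nn.silu(k @ W1.T) * (k @ W3.T)) @ W2.T   # f_W(k_i), shape (b, d)
    per_token = jnp.sum(pred * v, axis=-1)               # <f_W(k_i), v_i>, shape (b,)
    return -jnp.sum(eta * per_token)                     # scalar


b, d, h = 4, 8, 16  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 5)
W = (
    jax.random.normal(keys[0], (h, d)),
    jax.random.normal(keys[1], (d, h)),
    jax.random.normal(keys[2], (h, d)),
)
k = jax.random.normal(keys[3], (b, d))
v = jax.random.normal(keys[4], (b, d))
eta = jnp.array([0.1, 0.1, 0.1, 0.0])  # last position is padding

loss = inner_loss_sketch(W, k, v, eta)
grads = jax.grad(inner_loss_sketch)(W, k, v, eta)  # tuple shaped like W
```

Since each token's term is scaled by its own `eta`, a position with `eta = 0` contributes nothing to the loss or to the gradient, which is why padded positions are zeroed rather than masked separately.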
l2_row_norm(W: FastWeights, eps: float = 1e-06) -> FastWeights
Renormalize each row to unit L2 along the input axis.
Acts as weight decay and bounds the iterate after every chunk update. "Row" here means the input-axis vector: for W1 (h, d) we normalize axis=-1 (so each of the h hidden neurons keeps a unit-norm input-side weight); same convention for W2 (d, h) and W3 (h, d).
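A minimal sketch of the row renormalization described above, applied to one matrix. Where exactly `eps` enters (added to the norm here) is an assumption; the helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def row_normalize_sketch(w, eps=1e-6):
    # Normalize along the last axis so every row has (approximately) unit L2 norm.
    norm = jnp.linalg.norm(w, axis=-1, keepdims=True)
    return w / (norm + eps)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W1 = 3.0 * jax.random.normal(k1, (16, 8))  # (h, d)
W2 = 3.0 * jax.random.normal(k2, (8, 16))  # (d, h)
W1_n = row_normalize_sketch(W1)
W2_n = row_normalize_sketch(W2)
```

The same call works for all three matrices because the convention is always `axis=-1`, regardless of whether the matrix is (h, d) or (d, h).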
muon_newton_schulz(M: jax.Array, n_iters: int = 5, eps: float = 1e-07) -> jax.Array
5-iteration Newton-Schulz polynomial that maps M ≈ U S V^T → U V^T.
Implements the quintic from the Muon paper (arXiv:2502.16982): per iteration
X ← a·X + b·(XX^T)X + c·(XX^T)^2 X with (a, b, c) = (3.4445, -4.7750,
2.0315) and the input first rescaled to Frobenius norm 1. Cheap because it
only needs matmuls in the smaller of the two matrix dims (we transpose tall
matrices so XX^T is the smaller side).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| M | Array | 2D matrix. | required |
| n_iters | int | number of Newton-Schulz iterations. | 5 |

Returns:
| Type | Description |
|---|---|
| Array | Matrix of the same shape, with singular values all ≈ 1. |
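The iteration described above can be sketched directly from its formula. The code below uses the stated coefficients, Frobenius rescaling, and the transpose trick for tall matrices; the function name and the way `eps` guards the division are assumptions.

```python
import jax
import jax.numpy as jnp


def newton_schulz_sketch(M, n_iters=5, eps=1e-7):
    a, b, c = 3.4445, -4.7750, 2.0315
    tall = M.shape[0] > M.shape[1]
    X = M.T if tall else M                 # make X wide so X X^T is the small side
    X = X / (jnp.linalg.norm(X) + eps)     # Frobenius norm 1
    for _ in range(n_iters):
        A = X @ X.T
        # X <- a X + b (X X^T) X + c (X X^T)^2 X
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if tall else X


M = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
X = newton_schulz_sketch(M)
singular_values = jnp.linalg.svd(X, compute_uv=False)
```

With these coefficients the singular values do not converge exactly to 1; they settle into a band around 1, which is the sense of "≈ 1" in the return description.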
update_step(W: FastWeights, M: FastMomentum, g: FastWeights, optimizer: str, beta: float) -> Tuple[FastWeights, FastMomentum]
One inner optimizer step for a single sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| W | FastWeights | current fast weights. | required |
| M | FastMomentum | momentum (zeros when | required |
| g | FastWeights | gradient | required |
| optimizer | str | one of | required |
| beta | float | momentum coefficient (ignored for | required |

Returns:
| Type | Description |
|---|---|
| FastWeights | |
| FastMomentum | |
chunked_update_and_apply_single(W0: FastWeights, M0: FastMomentum, K: jax.Array, V: jax.Array, eta: jax.Array, Q: jax.Array, chunk_size: int, optimizer: str, beta: float, apply_then_update: bool) -> Tuple[jax.Array, FastWeights, FastMomentum]
Run the chunked TTT inner loop for one sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| W0 | FastWeights | initial fast weights (typically the slow params). | required |
| M0 | FastMomentum | initial momentum (zeros when | required |
| K | Array | (T, d) keys the chunk's inner loss is computed against. | required |
| V | Array | (T, d) values the chunk's inner loss is computed against. | required |
| Q | Array | (T, d) queries. | required |
| eta | Array | (T,) per-token learning rates. Pre-zero padded positions before calling. | required |
| chunk_size | int | tokens per inner-loop chunk ( | required |
| optimizer | str | | required |
| beta | float | momentum coefficient. | required |
| apply_then_update | bool | when True, the chunk's queries see the pre-update W (shifted block-causal order, LM default); when False, queries see the post-update W (full block-causal order). | required |

Returns:
| Type | Description |
|---|---|
| Array | |
| FastWeights | residual / projection. |
| FastMomentum | chunk's update, used by the inference path to persist across forward calls. |
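The chunked loop and the effect of `apply_then_update` can be illustrated with a simplified sketch built on `jax.lax.scan`. It makes several assumptions that are not taken from the library: weights are a plain tuple, the update is a plain gradient step followed by row normalization (no momentum and no `optimizer` switch), and T is a multiple of `chunk_size`.

```python
import jax
import jax.numpy as jnp


def apply_fw(W, x):
    W1, W2, W3 = W
    return (jax.nn.silu(x @ W1.T) * (x @ W3.T)) @ W2.T


def inner_loss(W, k, v, eta):
    return -jnp.sum(eta * jnp.sum(apply_fw(W, k) * v, axis=-1))


def row_normalize(w, eps=1e-6):
    return w / (jnp.linalg.norm(w, axis=-1, keepdims=True) + eps)


def chunked_loop_sketch(W0, K, V, eta, Q, chunk_size, apply_then_update):
    T, d = K.shape
    n = T // chunk_size  # assumes T is a multiple of chunk_size
    chunks = (
        K.reshape(n, chunk_size, d),
        V.reshape(n, chunk_size, d),
        Q.reshape(n, chunk_size, d),
        eta.reshape(n, chunk_size),
    )

    def step(W, chunk):
        k, v, q, e = chunk
        g = jax.grad(inner_loss)(W, k, v, e)
        # Simplified update: gradient step, then row renormalization.
        W_new = tuple(row_normalize(w - gw) for w, gw in zip(W, g))
        # Queries see either the pre-update or the post-update weights.
        out = apply_fw(W if apply_then_update else W_new, q)
        return W_new, out

    W_final, outs = jax.lax.scan(step, W0, chunks)
    return outs.reshape(T, d), W_final


T, d, h, chunk = 8, 8, 16, 4  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 6)
W0 = (
    0.1 * jax.random.normal(keys[0], (h, d)),
    0.1 * jax.random.normal(keys[1], (d, h)),
    0.1 * jax.random.normal(keys[2], (h, d)),
)
K = jax.random.normal(keys[3], (T, d))
V = jax.random.normal(keys[4], (T, d))
Q = jax.random.normal(keys[5], (T, d))
eta = jnp.full((T,), 0.01)

out_pre, W_pre = chunked_loop_sketch(W0, K, V, eta, Q, chunk, True)
out_post, W_post = chunked_loop_sketch(W0, K, V, eta, Q, chunk, False)
```

In this sketch the flag changes only which weights the queries read; the sequence of weight updates, and therefore the final weights, is the same in both modes.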
chunked_update_and_apply(W0: FastWeights, M0: FastMomentum, K: jax.Array, V: jax.Array, eta: jax.Array, Q: jax.Array, chunk_size: int, optimizer: str, beta: float, apply_then_update: bool) -> Tuple[jax.Array, FastWeights, FastMomentum]
Batched wrapper around chunked_update_and_apply_single.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| W0 | FastWeights | batched initial FastWeights with leading batch axis. | required |
| M0 | FastMomentum | batched initial FastMomentum with leading batch axis. | required |
| K | Array | (B, T, d) keys. | required |
| V | Array | (B, T, d) values. | required |
| Q | Array | (B, T, d) queries. | required |
| eta | Array | (B, T) per-token learning rates. | required |
| chunk_size | int | see | required |
| optimizer | str | see | required |
| beta | float | see | required |
| apply_then_update | bool | see | required |

Returns:
| Type | Description |
|---|---|
| Tuple[Array, FastWeights, FastMomentum] | |
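One common way to build such a batched wrapper is to close over the non-array arguments and `jax.vmap` the single-sequence function over the leading axis of every array argument. The sketch below shows that pattern with a stand-in single-sequence function; `single_seq_sketch`, `batched_sketch`, and the `scale` argument are hypothetical and only play the role of the real function and its static settings.

```python
import jax
import jax.numpy as jnp


def single_seq_sketch(W, x, scale):
    # Stand-in for a single-sequence function: W has no batch axis, x is (T, d).
    W1, W2, W3 = W
    return scale * ((jax.nn.silu(x @ W1.T) * (x @ W3.T)) @ W2.T)


def batched_sketch(W, X, scale):
    # Static (non-array) settings are captured by closure; arrays are mapped
    # over their leading batch axis.
    fn = lambda w, x: single_seq_sketch(w, x, scale)
    return jax.vmap(fn)(W, X)


B, T, d, h = 3, 6, 8, 16  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 4)
W = (
    jax.random.normal(keys[0], (B, h, d)),
    jax.random.normal(keys[1], (B, d, h)),
    jax.random.normal(keys[2], (B, h, d)),
)
X = jax.random.normal(keys[3], (B, T, d))
Y = batched_sketch(W, X, 0.5)  # (B, T, d)
```

Each batch element gets its own weights here, matching the per-sequence fast-weight design: nothing is shared across the batch axis inside the mapped function.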
batch_broadcast(W: FastWeights, batch_size: int) -> FastWeights
Broadcast a single-sequence FastWeights to (B, ...).
batch_zeros_momentum(W_batched: FastWeights) -> FastMomentum
Zero momentum matching a batched FastWeights pytree.
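Both batching helpers can be sketched with `jax.tree_util.tree_map`, which applies a function to every array in a pytree. The versions below are assumptions about one reasonable implementation: they work on any pytree of arrays (a plain tuple here) and return the same container type they receive, whereas the documented `batch_zeros_momentum` returns a FastMomentum.

```python
import jax
import jax.numpy as jnp


def batch_broadcast_sketch(W, batch_size):
    # Add a leading batch axis of size B to every leaf without copying data eagerly.
    return jax.tree_util.tree_map(
        lambda w: jnp.broadcast_to(w, (batch_size,) + w.shape), W
    )


def batch_zeros_sketch(W_batched):
    # Zero buffers with the same shapes and dtypes as the batched weights.
    return jax.tree_util.tree_map(jnp.zeros_like, W_batched)


d, h, B = 8, 16, 4  # illustrative sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W = (
    jax.random.normal(k1, (h, d)),
    jax.random.normal(k2, (d, h)),
    jax.random.normal(k3, (h, d)),
)
W_batched = batch_broadcast_sketch(W, B)
M_batched = batch_zeros_sketch(W_batched)
```

Every batch element starts from identical weights after the broadcast; they diverge only once each sequence applies its own chunk updates.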