lact

Pure-functional primitives for LaCT fast weights.

LaCT (arXiv:2505.23884) maintains a small SwiGLU MLP f_W(x) = W2 (silu(W1 x) * W3 x) whose weights W = (W1, W2, W3) are gradient-stepped once per chunk of b tokens on the inner loss L(W) = - sum_i eta_i * <f_W(k_i), v_i>. These functions are deliberately Flax-free: they operate on plain JAX arrays / NamedTuples so the training path can call jax.grad(inner_loss) directly and the inference path can plug into a Flax self.variable("fast_weights", ...) cell.

Per-sequence W is the unit of computation — every batch element runs its own chunked scan. We expose single-sequence functions and a jax.vmap-batched wrapper to keep both shapes legible.

FastWeights

Bases: NamedTuple

Per-sequence fast-weight matrices for a single LaCT block.

Shapes (no batch axis, single sequence):

W1: (h, d), input-side gate projection
W2: (d, h), output projection
W3: (h, d), input-side up projection

where d is the model hidden dim and h is the fast-weight intermediate dim. Batched usage adds a leading (B, ...) axis via jax.vmap.
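The shape convention above can be illustrated with a minimal sketch. The class below is a stand-in that only mirrors the documented field names and shapes, and `squared_norm` is a hypothetical single-sequence function introduced purely to show how jax.vmap adds the leading batch axis; neither is the module's actual code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    """Stand-in container mirroring the documented fields and shapes."""
    W1: jax.Array  # (h, d) input-side gate projection
    W2: jax.Array  # (d, h) output projection
    W3: jax.Array  # (h, d) input-side up projection


def squared_norm(W: FastWeights) -> jax.Array:
    """Hypothetical single-sequence function: total squared weight norm."""
    return sum(jnp.sum(w ** 2) for w in W)


B, d, h = 4, 8, 16  # illustrative sizes only
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W_batched = FastWeights(
    W1=jax.random.normal(k1, (B, h, d)),
    W2=jax.random.normal(k2, (B, d, h)),
    W3=jax.random.normal(k3, (B, h, d)),
)

# vmap maps over the leading axis of every field, so the function body
# only ever sees unbatched (h, d) / (d, h) matrices.
per_sequence = jax.vmap(squared_norm)(W_batched)  # shape (B,)
```

Because a NamedTuple is a JAX pytree, no custom registration is needed for jax.vmap or jax.grad to traverse it.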

FastMomentum

Bases: NamedTuple

Optimizer momentum buffers, shape-matched to FastWeights.

fw_zeros_like(W: FastWeights) -> FastMomentum

Zero momentum buffers with the same shape/dtype as W.

apply_fw(W: FastWeights, x: jax.Array) -> jax.Array

SwiGLU forward through fast weights.

Parameters:

Name Type Description Default
W FastWeights

FastWeights for a single sequence (W1/W3: (h, d), W2: (d, h)).

required
x Array

(..., d) input.

required

Returns:

Type Description
Array

(..., d) output W2 @ (silu(W1 @ x) * (W3 @ x)).
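A minimal sketch of the forward pass described above, assuming the weights are applied to row-vector inputs via transposes so that any leading axes on x are preserved. `apply_fw_sketch` is a hypothetical name and the FastWeights class is a stand-in; the real implementation may differ in details such as dtype handling.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array  # (h, d)
    W2: jax.Array  # (d, h)
    W3: jax.Array  # (h, d)


def apply_fw_sketch(W: FastWeights, x: jax.Array) -> jax.Array:
    """SwiGLU forward: W2 @ (silu(W1 @ x) * (W3 @ x)) for each (..., d) row."""
    gate = jax.nn.silu(x @ W.W1.T)  # (..., h)
    up = x @ W.W3.T                 # (..., h)
    return (gate * up) @ W.W2.T     # (..., d)


d, h = 8, 16  # illustrative sizes only
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
W = FastWeights(
    W1=jax.random.normal(k1, (h, d)),
    W2=jax.random.normal(k2, (d, h)),
    W3=jax.random.normal(k3, (h, d)),
)
x = jax.random.normal(kx, (5, d))
y = apply_fw_sketch(W, x)  # shape (5, d)
```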

inner_loss(W: FastWeights, k: jax.Array, v: jax.Array, eta: jax.Array) -> jax.Array

Negative-dot-product inner loss for one chunk of one sequence.

Parameters:

Name Type Description Default
W FastWeights

FastWeights for this sequence.

required
k Array

(b, d) keys for b tokens.

required
v Array

(b, d) values for b tokens.

required
eta Array

(b,) per-token inner learning rates (zero on padded positions).

required

Returns:

Type Description
Array

Scalar - sum_i eta_i * <f_W(k_i), v_i>. No lower bound by construction; the L2 row-norm in update_step is what keeps the iterate stable across chunks.
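The loss and its gradient can be sketched as follows. This is an illustration under the documented formula, not the module's code: `inner_loss_sketch` and the local `apply_fw` are hypothetical re-implementations. It also shows why zeroing eta on padded positions is enough to mask them.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array  # (h, d)
    W2: jax.Array  # (d, h)
    W3: jax.Array  # (h, d)


def apply_fw(W: FastWeights, x: jax.Array) -> jax.Array:
    return (jax.nn.silu(x @ W.W1.T) * (x @ W.W3.T)) @ W.W2.T


def inner_loss_sketch(W, k, v, eta):
    """- sum_i eta_i * <f_W(k_i), v_i> over the b tokens of one chunk."""
    pred = apply_fw(W, k)                              # (b, d)
    return -jnp.sum(eta * jnp.sum(pred * v, axis=-1))  # scalar


b, d, h = 4, 8, 16  # illustrative sizes only
k1, k2, k3, kk, kv = jax.random.split(jax.random.PRNGKey(0), 5)
W = FastWeights(
    W1=0.1 * jax.random.normal(k1, (h, d)),
    W2=0.1 * jax.random.normal(k2, (d, h)),
    W3=0.1 * jax.random.normal(k3, (h, d)),
)
k = jax.random.normal(kk, (b, d))
v = jax.random.normal(kv, (b, d))
eta = jnp.array([0.1, 0.1, 0.1, 0.0])  # last position treated as padding

loss = inner_loss_sketch(W, k, v, eta)
# jax.grad differentiates w.r.t. the first argument and returns a pytree
# with the same structure, i.e. a FastWeights of gradients.
grads = jax.grad(inner_loss_sketch)(W, k, v, eta)
```

A token with eta = 0 contributes nothing to the sum, so it also contributes nothing to the gradient.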

l2_row_norm(W: FastWeights, eps: float = 1e-06) -> FastWeights

Renormalize each row to unit L2 along the input axis.

Acts as weight decay and bounds the iterate after every chunk update. "Row" here means the input-axis vector: for W1 (h, d) we normalize axis=-1 (so each of the h hidden neurons keeps a unit-norm input-side weight); same convention for W2 (d, h) and W3 (h, d).
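A minimal sketch of the renormalization, assuming eps is added to the norm in the denominator (the exact placement of eps is not stated above and is an assumption here). `l2_row_norm_sketch` is a hypothetical name.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array  # (h, d)
    W2: jax.Array  # (d, h)
    W3: jax.Array  # (h, d)


def l2_row_norm_sketch(W: FastWeights, eps: float = 1e-6) -> FastWeights:
    """Scale every row (axis=-1 vector) of every matrix to unit L2 norm."""
    def normalize(w):
        norm = jnp.linalg.norm(w, axis=-1, keepdims=True)
        return w / (norm + eps)  # assumption: eps guards against zero rows
    return FastWeights(*(normalize(w) for w in W))


d, h = 8, 16  # illustrative sizes only
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W = FastWeights(
    W1=jax.random.normal(k1, (h, d)),
    W2=jax.random.normal(k2, (d, h)),
    W3=jax.random.normal(k3, (h, d)),
)
W_normed = l2_row_norm_sketch(W)
row_norms = [jnp.linalg.norm(w, axis=-1) for w in W_normed]
```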

muon_newton_schulz(M: jax.Array, n_iters: int = 5, eps: float = 1e-07) -> jax.Array

Newton-Schulz iteration (5 steps by default) that maps M = U S V^T to approximately U V^T.

Implements the quintic from the Muon paper (arXiv:2502.16982): per iteration X ← a·X + b·(XX^T)X + c·(XX^T)^2 X with (a, b, c) = (3.4445, -4.7750, 2.0315) and the input first rescaled to Frobenius norm 1. Cheap because it only needs matmuls in the smaller of the two matrix dims (we transpose tall matrices so XX^T is the smaller side).

Parameters:

Name Type Description Default
M Array

2D matrix.

required
n_iters int

number of Newton-Schulz iterations.

5

Returns:

Type Description
Array

Matrix of the same shape, with singular values all ≈ 1.
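The iteration described above can be sketched as below. The coefficients and the Frobenius rescaling come from the description; adding eps to the norm is an assumption about how the eps argument is used, and `newton_schulz_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def newton_schulz_sketch(M: jax.Array, n_iters: int = 5, eps: float = 1e-7) -> jax.Array:
    """Approximately orthogonalize a 2D matrix with the quintic iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = M / (jnp.linalg.norm(M) + eps)  # Frobenius norm for a 2D input
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T  # work on the wide orientation so X @ X.T is the small side
    for _ in range(n_iters):
        A = X @ X.T
        # a*X + b*(XX^T)X + c*(XX^T)^2 X, factored to share the matmuls
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X


M = jax.random.normal(jax.random.PRNGKey(0), (10, 6))  # tall: uses the transpose path
X = newton_schulz_sketch(M)
singular_values = jnp.linalg.svd(X, compute_uv=False)
```

With these coefficients the singular values do not converge exactly to 1; they settle in a band around 1, which is why the return description says "≈ 1".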

update_step(W: FastWeights, M: FastMomentum, g: FastWeights, optimizer: str, beta: float) -> Tuple[FastWeights, FastMomentum]

One inner optimizer step for a single sequence.

Parameters:

Name Type Description Default
W FastWeights

current fast weights.

required
M FastMomentum

momentum (zeros when optimizer == "gd").

required
g FastWeights

gradient ∂inner_loss/∂W for this chunk.

required
optimizer str

one of "gd", "momentum", "muon".

required
beta float

momentum coefficient (ignored for "gd").

required

Returns:

Type Description
Tuple[FastWeights, FastMomentum]

(W_new, M_new). W_new is renormalized after the step via l2_row_norm regardless of optimizer.
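One way the three optimizer branches could fit together is sketched below. The update rules are assumptions, since the text above only names the optimizers: "gd" subtracts the gradient directly (the step size is taken to live in eta inside the loss), "momentum" uses M ← beta·M + g, and "muon" orthogonalizes that momentum before subtracting it. The FastMomentum field names and all helper functions are hypothetical.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array
    W2: jax.Array
    W3: jax.Array


class FastMomentum(NamedTuple):  # field names are hypothetical
    M1: jax.Array
    M2: jax.Array
    M3: jax.Array


def l2_row_norm(W, eps=1e-6):
    return FastWeights(*(w / (jnp.linalg.norm(w, axis=-1, keepdims=True) + eps) for w in W))


def newton_schulz(M, n_iters=5, eps=1e-7):
    a, b, c = 3.4445, -4.7750, 2.0315
    X = M / (jnp.linalg.norm(M) + eps)
    tall = X.shape[0] > X.shape[1]
    X = X.T if tall else X
    for _ in range(n_iters):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if tall else X


def update_step_sketch(W, M, g, optimizer: str, beta: float) -> Tuple[FastWeights, FastMomentum]:
    if optimizer == "gd":
        step, M_new = tuple(g), M  # assumption: momentum passes through untouched
    elif optimizer in ("momentum", "muon"):
        M_new = FastMomentum(*(beta * m + gw for m, gw in zip(M, g)))
        step = tuple(M_new)
        if optimizer == "muon":
            step = tuple(newton_schulz(s) for s in step)
    else:
        raise ValueError(f"unknown optimizer: {optimizer!r}")
    W_new = FastWeights(*(w - s for w, s in zip(W, step)))
    return l2_row_norm(W_new), M_new  # renormalize regardless of optimizer


d, h = 8, 16  # illustrative sizes only
keys = jax.random.split(jax.random.PRNGKey(0), 6)
shapes = [(h, d), (d, h), (h, d)]
W = FastWeights(*(jax.random.normal(k, s) for k, s in zip(keys[:3], shapes)))
g = FastWeights(*(jax.random.normal(k, s) for k, s in zip(keys[3:], shapes)))
M = FastMomentum(*(jnp.zeros(s) for s in shapes))
results = {opt: update_step_sketch(W, M, g, opt, 0.9) for opt in ("gd", "momentum", "muon")}
```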

chunked_update_and_apply_single(W0: FastWeights, M0: FastMomentum, K: jax.Array, V: jax.Array, eta: jax.Array, Q: jax.Array, chunk_size: int, optimizer: str, beta: float, apply_then_update: bool) -> Tuple[jax.Array, FastWeights, FastMomentum]

Run the chunked TTT inner loop for one sequence.

Parameters:

Name Type Description Default
W0 FastWeights

initial fast weights (typically the slow params).

required
M0 FastMomentum

initial momentum (zeros when optimizer == "gd").

required
K Array

(T, d) keys the chunk's inner loss is computed against.

required
V Array

(T, d) values the chunk's inner loss is computed against.

required
Q Array

(T, d) queries.

required
eta Array

(T,) per-token learning rates. Pre-zero padded positions before calling.

required
chunk_size int

tokens per inner-loop chunk (b in the paper).

required
optimizer str

"gd" | "momentum" | "muon".

required
beta float

momentum coefficient.

required
apply_then_update bool

when True, the chunk's queries see the pre-update W (shifted block-causal order, LM default); when False, queries see the post-update W (full block-causal order).

required

Returns:

Type Description
Tuple[Array, FastWeights, FastMomentum]

(out, W_final, M_final): out is (T, d), the TTT-layer output before any residual / projection. W_final/M_final are the state after the last chunk's update, used by the inference path to persist across forward calls.
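The chunked loop and the effect of apply_then_update can be sketched with jax.lax.scan. This is a simplified illustration, not the documented function: it assumes T is a multiple of chunk_size, uses plain gradient descent only (so momentum is omitted), and all names are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array
    W2: jax.Array
    W3: jax.Array


def apply_fw(W, x):
    return (jax.nn.silu(x @ W.W1.T) * (x @ W.W3.T)) @ W.W2.T


def inner_loss(W, k, v, eta):
    return -jnp.sum(eta * jnp.sum(apply_fw(W, k) * v, axis=-1))


def l2_row_norm(W, eps=1e-6):
    return FastWeights(*(w / (jnp.linalg.norm(w, axis=-1, keepdims=True) + eps) for w in W))


def chunked_sketch(W0, K, V, eta, Q, chunk_size, apply_then_update):
    T, d = K.shape
    n = T // chunk_size  # assumption: T % chunk_size == 0
    split = lambda a: a.reshape((n, chunk_size) + a.shape[1:])

    def step(W, chunk):
        k, v, e, q = chunk
        g = jax.grad(inner_loss)(W, k, v, e)
        W_new = l2_row_norm(FastWeights(*(w - gw for w, gw in zip(W, g))))
        # True: queries see the pre-update W; False: the post-update W.
        out = apply_fw(W if apply_then_update else W_new, q)
        return W_new, out

    W_final, out = jax.lax.scan(step, W0, (split(K), split(V), split(eta), split(Q)))
    return out.reshape(T, d), W_final


T, d, h, chunk = 8, 8, 16, 4  # illustrative sizes only
keys = jax.random.split(jax.random.PRNGKey(0), 6)
W0 = FastWeights(
    W1=jax.random.normal(keys[0], (h, d)),
    W2=jax.random.normal(keys[1], (d, h)),
    W3=jax.random.normal(keys[2], (h, d)),
)
K, V, Q = (jax.random.normal(k, (T, d)) for k in keys[3:])
eta = jnp.full((T,), 0.1)

out_shifted, W_final = chunked_sketch(W0, K, V, eta, Q, chunk, apply_then_update=True)
out_full, _ = chunked_sketch(W0, K, V, eta, Q, chunk, apply_then_update=False)
```

In the shifted order the first chunk's output depends only on W0, which is what keeps a token from seeing keys and values from its own chunk.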

chunked_update_and_apply(W0: FastWeights, M0: FastMomentum, K: jax.Array, V: jax.Array, eta: jax.Array, Q: jax.Array, chunk_size: int, optimizer: str, beta: float, apply_then_update: bool) -> Tuple[jax.Array, FastWeights, FastMomentum]

Batched wrapper around chunked_update_and_apply_single.

Parameters:

Name Type Description Default
W0 FastWeights

batched initial FastWeights with leading batch axis.

required
M0 FastMomentum

batched initial FastMomentum with leading batch axis.

required
K Array

(B, T, d) keys.

required
V Array

(B, T, d) values.

required
Q Array

(B, T, d) queries.

required
eta Array

(B, T) per-token learning rates.

required
chunk_size int

see chunked_update_and_apply_single.

required
optimizer str

see chunked_update_and_apply_single.

required
beta float

see chunked_update_and_apply_single.

required
apply_then_update bool

see chunked_update_and_apply_single.

required

Returns:

Type Description
Tuple[Array, FastWeights, FastMomentum]

(out (B, T, d), W_final (B, ...), M_final (B, ...)).
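A plausible shape for such a wrapper is sketched below: the non-array arguments are bound with functools.partial so that jax.vmap only maps over the array and pytree arguments. `single_stub` is a hypothetical stand-in with the documented argument order that skips the real inner loop, and reusing FastWeights as the momentum container is a simplification for this sketch.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array
    W2: jax.Array
    W3: jax.Array


def single_stub(W0, M0, K, V, eta, Q, chunk_size, optimizer, beta, apply_then_update):
    """Hypothetical stand-in: applies W0 to Q and returns the state unchanged."""
    out = (jax.nn.silu(Q @ W0.W1.T) * (Q @ W0.W3.T)) @ W0.W2.T
    return out, W0, M0


def batched_sketch(W0, M0, K, V, eta, Q, chunk_size, optimizer, beta, apply_then_update):
    # Bind static (non-array) arguments first; vmap then maps axis 0 of the rest.
    fn = functools.partial(
        single_stub,
        chunk_size=chunk_size,
        optimizer=optimizer,
        beta=beta,
        apply_then_update=apply_then_update,
    )
    return jax.vmap(fn)(W0, M0, K, V, eta, Q)


B, T, d, h = 2, 8, 8, 16  # illustrative sizes only
keys = jax.random.split(jax.random.PRNGKey(0), 6)
W0 = FastWeights(
    W1=jax.random.normal(keys[0], (B, h, d)),
    W2=jax.random.normal(keys[1], (B, d, h)),
    W3=jax.random.normal(keys[2], (B, h, d)),
)
M0 = FastWeights(*(jnp.zeros_like(w) for w in W0))  # simplified momentum stand-in
K, V, Q = (jax.random.normal(k, (B, T, d)) for k in keys[3:])
eta = jnp.full((B, T), 0.1)

out, W_final, M_final = batched_sketch(W0, M0, K, V, eta, Q, 4, "gd", 0.9, True)
```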

batch_broadcast(W: FastWeights, batch_size: int) -> FastWeights

Broadcast a single-sequence FastWeights to (B, ...).

batch_zeros_momentum(W_batched: FastWeights) -> FastMomentum

Zero momentum matching a batched FastWeights pytree.
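The two batch helpers above can be sketched in a few lines. This assumes the broadcast is a plain jnp.broadcast_to over every field and that FastMomentum has one buffer per weight matrix; the FastMomentum field names and both function names are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FastWeights(NamedTuple):
    W1: jax.Array
    W2: jax.Array
    W3: jax.Array


class FastMomentum(NamedTuple):  # field names are hypothetical
    M1: jax.Array
    M2: jax.Array
    M3: jax.Array


def batch_broadcast_sketch(W: FastWeights, batch_size: int) -> FastWeights:
    """Give every field a leading (B, ...) axis holding identical copies."""
    return jax.tree_util.tree_map(
        lambda w: jnp.broadcast_to(w, (batch_size,) + w.shape), W
    )


def batch_zeros_momentum_sketch(W_batched: FastWeights) -> FastMomentum:
    """Zero buffers with the same batched shape and dtype as the weights."""
    return FastMomentum(*(jnp.zeros_like(w) for w in W_batched))


B, d, h = 3, 8, 16  # illustrative sizes only
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W = FastWeights(
    W1=jax.random.normal(k1, (h, d)),
    W2=jax.random.normal(k2, (d, h)),
    W3=jax.random.normal(k3, (h, d)),
)
W_b = batch_broadcast_sketch(W, B)
M_b = batch_zeros_momentum_sketch(W_b)
```

Every sequence in the batch starts from the same weights and then diverges as its own chunks are processed.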