gated_delta
Gated delta-rule linear attention (Qwen 3.5 / DeltaNet-style).
The token mixer is a depthwise causal conv1d feeding into a chunked
gated-delta-rule kernel, then an RMSNorm gated by a per-token z gate.
Token shapes throughout (JAX-side):
Q, K: (B, T, num_k_heads, head_k_dim)
V: (B, T, num_v_heads, head_v_dim)
g, beta: (B, T, num_v_heads)
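Parity reference: transformers.models.qwen3_5.modeling_qwen3_5.torch_chunk_gated_delta_rule.
We port HF's chunked variant, not its recurrent single-token path; the
chunked kernel is sufficient for full-sequence forward parity, and the
decode step, :func:`recurrent_gated_delta_step`, is derived from it instead.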
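For reference, the per-head recurrence that the kernel computes can be written as follows. This is a sketch of our reading of the HF reference, not a transcription of it: S_t is the head_k_dim x head_v_dim state, g_t is a log-decay (applied as exp(g_t)), the 1/sqrt(d_k) query scale is assumed, and the small epsilon inside the L2 normalization (use_qk_l2norm=True) is omitted.

```latex
% One head, token t; S_t \in \mathbb{R}^{d_k \times d_v}, with S_0 = 0 or a carried state.
\begin{aligned}
\hat{k}_t &= \frac{k_t}{\lVert k_t \rVert_2}, \qquad
\hat{q}_t = \frac{q_t}{\sqrt{d_k}\,\lVert q_t \rVert_2} \\
\tilde{S}_t &= e^{g_t}\, S_{t-1} \\
S_t &= \tilde{S}_t + \beta_t\, \hat{k}_t \bigl( v_t - \tilde{S}_t^{\top} \hat{k}_t \bigr)^{\top} \\
o_t &= S_t^{\top} \hat{q}_t
\end{aligned}
```

With g_t = 0 and beta_t = 0 the update reduces to S_t = S_{t-1}. With beta_t = 1 and a unit key it gives S_t^T k̂_t = v_t: the delta rule overwrites the value stored at k̂_t rather than accumulating onto it.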
chunk_gated_delta_rule(q: jax.Array, k: jax.Array, v: jax.Array, g: jax.Array, beta: jax.Array, chunk_size: int = 64, use_qk_l2norm: bool = True, return_state: bool = False) -> Any
Chunked gated delta-rule attention. See the module docstring for shapes.
Returns (B, T, num_v_heads, head_v_dim). When return_state is True, it
also returns the final recurrent state (B, H, head_k_dim, head_v_dim)
so a decode loop can continue token-by-token via
:func:`recurrent_gated_delta_step`. Padding positions (q, k, v, beta, and g
all zero) leave the state unchanged: g = 0 gives a decay factor of
exp(0) = 1 and beta = 0 suppresses the write, so the returned state
reflects only real tokens.
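A naive token-by-token reference is handy for testing the chunked kernel and for seeing what return_state hands to a decode loop. The sketch below is not the chunked algorithm. It is a hypothetical `naive_gated_delta_rule` that scans over time with jax.lax.scan. It assumes q and k already have the same head count as v (any GQA head expansion is done by the caller), and it assumes an L2-norm epsilon of 1e-6.

```python
import jax
import jax.numpy as jnp


def l2norm(x, eps=1e-6):
    # Assumed normalization: x * rsqrt(sum(x^2) + eps); a zero vector stays zero.
    return x * jax.lax.rsqrt(jnp.sum(x * x, axis=-1, keepdims=True) + eps)


def naive_gated_delta_rule(q, k, v, g, beta, initial_state=None, use_qk_l2norm=True):
    """Hypothetical sequential reference (not the chunked kernel).

    q, k: (B, T, H, Dk); v: (B, T, H, Dv); g, beta: (B, T, H).
    Returns (out (B, T, H, Dv), final_state (B, H, Dk, Dv)).
    """
    B, T, H, Dk = k.shape
    Dv = v.shape[-1]
    if use_qk_l2norm:
        q, k = l2norm(q), l2norm(k)
    q = q * (Dk ** -0.5)  # assumed 1/sqrt(Dk) query scale
    if initial_state is None:
        initial_state = jnp.zeros((B, H, Dk, Dv), v.dtype)

    def step(S, xs):
        q_t, k_t, v_t, g_t, b_t = xs                 # (B, H, D*) and (B, H)
        S = S * jnp.exp(g_t)[..., None, None]         # gated decay of the state
        pred = jnp.einsum("bhk,bhkv->bhv", k_t, S)    # value currently stored at k_t
        delta = (v_t - pred) * b_t[..., None]         # beta-scaled delta-rule error
        S = S + jnp.einsum("bhk,bhv->bhkv", k_t, delta)
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t, S)     # read with the updated state
        return S, o_t

    # Put time on the leading axis so lax.scan iterates over tokens.
    xs = tuple(jnp.moveaxis(a, 1, 0) for a in (q, k, v, g, beta))
    final_state, out = jax.lax.scan(step, initial_state, xs)
    return jnp.moveaxis(out, 0, 1), final_state
```

Two consequences follow from this form. Splitting a sequence and passing the first part's final state as initial_state reproduces the full-sequence result. Appending all-zero padding leaves both the real-token outputs and the final state unchanged.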
recurrent_gated_delta_step(q: jax.Array, k: jax.Array, v: jax.Array, g: jax.Array, beta: jax.Array, state: jax.Array, use_qk_l2norm: bool = True) -> Tuple[jax.Array, jax.Array]
Single-token gated delta-rule recurrence (the decode counterpart of
:func:`chunk_gated_delta_rule`).
Shapes: q, k (B, 1, H, D_k), v (B, 1, H, D_v),
g, beta (B, 1, H), state (B, H, D_k, D_v).
Returns (out (B, 1, H, D_v), new_state (B, H, D_k, D_v)).
It is derived as the chunk_size == 1 specialization of the chunked
kernel's internal scan_body, so token-by-token decoding reproduces the
chunked path's logits up to floating-point accumulation order.
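The single-token update can be sketched as below. `gated_delta_step` is a hypothetical stand-in for recurrent_gated_delta_step, written from the recurrence as we understand it: decay the state by exp(g), read the value stored at the normalized key, write the beta-scaled error back, then read with the 1/sqrt(D_k)-scaled query. The L2-norm epsilon of 1e-6 is an assumption.

```python
import jax
import jax.numpy as jnp


def gated_delta_step(q, k, v, g, beta, state, use_qk_l2norm=True, eps=1e-6):
    """Hypothetical single-token step.

    q, k: (B, 1, H, Dk); v: (B, 1, H, Dv); g, beta: (B, 1, H); state: (B, H, Dk, Dv).
    Returns (out (B, 1, H, Dv), new_state (B, H, Dk, Dv)).
    """
    # Drop the length-1 time axis.
    q, k, v, g, beta = (a[:, 0] for a in (q, k, v, g, beta))
    if use_qk_l2norm:
        q = q * jax.lax.rsqrt(jnp.sum(q * q, axis=-1, keepdims=True) + eps)
        k = k * jax.lax.rsqrt(jnp.sum(k * k, axis=-1, keepdims=True) + eps)
    q = q * (q.shape[-1] ** -0.5)                     # assumed 1/sqrt(Dk) query scale
    S = state * jnp.exp(g)[..., None, None]           # gated decay
    pred = jnp.einsum("bhk,bhkv->bhv", k, S)          # value stored at key k
    delta = (v - pred) * beta[..., None]              # delta-rule correction
    S = S + jnp.einsum("bhk,bhv->bhkv", k, delta)     # rank-1 write
    out = jnp.einsum("bhk,bhkv->bhv", q, S)
    return out[:, None], S
```

As a sanity check on the design, consider beta = 1: reading the new state at the normalized key returns v (up to the epsilon), which is the delta rule's overwrite behavior. An all-zero token returns the state unchanged.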
causal_conv1d_step(x_t: jax.Array, conv_state: jax.Array, weight: jax.Array) -> Tuple[jax.Array, jax.Array]
Single-step causal depthwise conv1d using a rolling input cache.
x_t (B, 1, C); conv_state (B, K-1, C) holds the previous K-1 inputs;
weight (C, K). Returns (out (B, 1, C), new_conv_state (B, K-1, C)).
Matches :func:`causal_depthwise_conv1d` at the appended position.
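A rolling-cache step can be sketched as follows. `conv1d_step` is a hypothetical name. The sketch assumes conv_state stores the previous K-1 inputs oldest first, so that weight[:, K-1] multiplies the current input, as in the PyTorch padding=K-1 convention.

```python
import jax.numpy as jnp


def conv1d_step(x_t, conv_state, weight):
    """Hypothetical rolling-cache step.

    x_t: (B, 1, C); conv_state: (B, K-1, C), assumed oldest input first; weight: (C, K).
    Returns (out (B, 1, C), new_conv_state (B, K-1, C)).
    """
    window = jnp.concatenate([conv_state, x_t], axis=1)      # (B, K, C): the last K inputs
    out = jnp.einsum("bkc,ck->bc", window, weight)[:, None]  # per-channel dot with its taps
    return out, window[:, 1:]                                # drop the oldest input
```

Starting from a zero conv_state plays the role of the K-1 zeros of left padding. Stepping through a sequence from zeros therefore reproduces the full-sequence causal convolution position by position.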
causal_depthwise_conv1d(x: jax.Array, weight: jax.Array, kernel_size: int) -> jax.Array
Causal depthwise conv1d: x (B, T, C); weight (C, kernel_size).
Matches the HF reference's nn.Conv1d(groups=C, padding=K-1), applied in
(B, C, T) layout and truncated with [:, :, :T], where K = kernel_size.
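PyTorch's Conv1d is a cross-correlation (no kernel flip). Its right-hand padding only affects outputs at positions >= T, which the [:, :, :T] slice discards. The same result therefore comes from padding K-1 zeros on the left and summing K shifted copies of the input. The sketch below shows this in (B, T, C) layout. Its name, `causal_depthwise_conv1d_ref`, is hypothetical, and it reads K from weight.shape instead of taking kernel_size.

```python
import jax.numpy as jnp


def causal_depthwise_conv1d_ref(x, weight):
    """Hypothetical reference. x: (B, T, C); weight: (C, K).

    out[:, t, c] = sum_j weight[c, j] * x[:, t + j - (K - 1), c], with x = 0 before t = 0.
    """
    T = x.shape[1]
    K = weight.shape[1]
    x_pad = jnp.pad(x, ((0, 0), (K - 1, 0), (0, 0)))  # K-1 zeros on the left only
    # Tap j sees the input shifted so that tap K-1 lines up with the current token.
    return sum(x_pad[:, j:j + T] * weight[:, j] for j in range(K))
```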