gated_delta

Gated delta-rule linear attention (Qwen 3.5 / DeltaNet-style).

The token mixer is a depthwise causal conv1d feeding into a chunked gated-delta-rule kernel, then an RMSNorm gated by a per-token z gate.
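The final stage of the mixer, an RMSNorm gated by z, can be sketched as below. This is a minimal illustration, not the module's code: the function name gated_rmsnorm, the eps value, the float32 upcast, and the choice of SiLU as the gate activation are all assumptions (SiLU follows the usual gated-RMSNorm pattern in the Hugging Face reference models).

```python
import jax
import jax.numpy as jnp


def gated_rmsnorm(x, z, weight, eps=1e-6):
    """Hypothetical sketch: RMSNorm over the last axis, then a per-token gate.

    x, z: (..., D); weight: (D,). The SiLU gate and eps are assumptions.
    """
    x32 = x.astype(jnp.float32)
    # Normalize by the root-mean-square of the feature axis.
    inv_rms = jax.lax.rsqrt(jnp.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    normed = weight * (x32 * inv_rms)
    # Gate each token's normalized features with silu(z).
    return (normed * jax.nn.silu(z.astype(jnp.float32))).astype(x.dtype)


x = jnp.ones((2, 3, 4, 8))            # (B, T, num_v_heads, head_v_dim)
w = jnp.ones((8,))
y_closed = gated_rmsnorm(x, jnp.zeros_like(x), w)      # silu(0) == 0 closes the gate
y_open = gated_rmsnorm(x, jnp.full_like(x, 20.0), w)   # silu(20) is close to 20
```

A zero gate suppresses the token's output entirely, while a large positive gate passes the normalized value through scaled by roughly z.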

Token shapes throughout (JAX-side):
  Q, K: (B, T, num_k_heads, head_k_dim)
  V: (B, T, num_v_heads, head_v_dim)
  g, beta: (B, T, num_v_heads)

Parity reference: transformers.models.qwen3_5.modeling_qwen3_5.torch_chunk_gated_delta_rule. We port the chunked variant (no recurrent single-token path), which is sufficient for full-sequence forward parity.

GatedDeltaNet

Bases: Module

Gated delta-rule linear attention token mixer.

chunk_gated_delta_rule(q: jax.Array, k: jax.Array, v: jax.Array, g: jax.Array, beta: jax.Array, chunk_size: int = 64, use_qk_l2norm: bool = True, return_state: bool = False) -> Any

Chunked gated delta-rule attention. See module docstring for shapes.

Returns an array of shape (B, T, num_v_heads, head_v_dim). When return_state is set, also returns the final recurrent state of shape (B, H, head_k_dim, head_v_dim), so a decode loop can continue token by token via recurrent_gated_delta_step. Padding positions (zero q/k/v/beta/g) leave the state unchanged, so the returned state reflects only real tokens.

recurrent_gated_delta_step(q: jax.Array, k: jax.Array, v: jax.Array, g: jax.Array, beta: jax.Array, state: jax.Array, use_qk_l2norm: bool = True) -> Tuple[jax.Array, jax.Array]

Single-token gated delta-rule recurrence (the decode counterpart of chunk_gated_delta_rule).

Shapes: q, k (B, 1, H, D_k), v (B, 1, H, D_v), g, beta (B, 1, H), state (B, H, D_k, D_v). Returns (out (B, 1, H, D_v), new_state (B, H, D_k, D_v)).

Derived as the chunk_size == 1 specialization of the chunked path's scan_body, so token-by-token decoding reproduces the chunked path's logits up to floating-point accumulation order.
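To make the recurrence concrete, here is a minimal sketch of one gated delta-rule step with the shapes listed above. It is not the module's implementation: the name gated_delta_step_sketch, the l2norm eps, the 1/sqrt(D_k) query scale, and the reading of g as a log-space decay (applied as exp(g)) are assumptions taken from the Hugging Face recurrent reference. The log-space reading is at least consistent with the documented behaviour that zero-valued padding leaves the state unchanged.

```python
import jax
import jax.numpy as jnp


def _l2norm(x, eps=1e-6):
    # Normalize the last axis; eps (an assumed value) keeps zero vectors at zero.
    return x * jax.lax.rsqrt(jnp.sum(x * x, axis=-1, keepdims=True) + eps)


def gated_delta_step_sketch(q, k, v, g, beta, state, use_qk_l2norm=True):
    """Hypothetical single-token step.

    q, k: (B, 1, H, D_k); v: (B, 1, H, D_v); g, beta: (B, 1, H);
    state: (B, H, D_k, D_v). Returns (out (B, 1, H, D_v), new_state).
    """
    q, k, v = q[:, 0], k[:, 0], v[:, 0]
    g, beta = g[:, 0], beta[:, 0]
    if use_qk_l2norm:
        q, k = _l2norm(q), _l2norm(k)
    q = q * (q.shape[-1] ** -0.5)                      # assumed 1/sqrt(D_k) scale
    state = state * jnp.exp(g)[..., None, None]        # decay the memory
    kv_mem = jnp.einsum("bhkv,bhk->bhv", state, k)     # what the memory predicts for k
    delta = (v - kv_mem) * beta[..., None]             # gated prediction error
    state = state + jnp.einsum("bhk,bhv->bhkv", k, delta)  # rank-1 delta-rule update
    out = jnp.einsum("bhkv,bhk->bhv", state, q)        # read out with the query
    return out[:, None], state


B, T, H, Dk, Dv = 1, 5, 2, 4, 3
kq, kk, kv, kg, kb = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(kq, (B, T, H, Dk))
k = jax.random.normal(kk, (B, T, H, Dk))
v = jax.random.normal(kv, (B, T, H, Dv))
g = -jax.nn.softplus(jax.random.normal(kg, (B, T, H)))   # log-decay, always <= 0
beta = jax.nn.sigmoid(jax.random.normal(kb, (B, T, H)))

state = jnp.zeros((B, H, Dk, Dv))
for t in range(T):
    s = slice(t, t + 1)
    out, state = gated_delta_step_sketch(q[:, s], k[:, s], v[:, s], g[:, s], beta[:, s], state)

# One all-zero padding token: the state should come back unchanged.
pad_out, pad_state = gated_delta_step_sketch(
    jnp.zeros((B, 1, H, Dk)), jnp.zeros((B, 1, H, Dk)), jnp.zeros((B, 1, H, Dv)),
    jnp.zeros((B, 1, H)), jnp.zeros((B, 1, H)), state,
)
```

With g = 0 the decay factor is 1, and with beta = 0 the update term vanishes, which is why the padding token is a no-op in this sketch.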

causal_conv1d_step(x_t: jax.Array, conv_state: jax.Array, weight: jax.Array) -> Tuple[jax.Array, jax.Array]

Single-step causal depthwise conv1d using a rolling input cache.

x_t (B, 1, C); conv_state (B, K-1, C) holds the previous K-1 inputs; weight (C, K). Returns (out (B, 1, C), new_conv_state (B, K-1, C)). Matches causal_depthwise_conv1d at the appended position.
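The rolling-cache idea can be sketched as follows. The helper names conv_step_sketch and full_causal_conv are hypothetical, and the sketch assumes the cache stores inputs oldest-first and that no activation is applied inside the step. It checks that stepping one token at a time from an all-zero cache reproduces a full-sequence causal depthwise convolution.

```python
import jax
import jax.numpy as jnp


def full_causal_conv(x, weight):
    # Reference: left-pad K-1 zeros, then sum K shifted, per-channel-weighted copies.
    K, T = weight.shape[1], x.shape[1]
    x_pad = jnp.pad(x, ((0, 0), (K - 1, 0), (0, 0)))
    return sum(x_pad[:, j:j + T, :] * weight[:, j] for j in range(K))


def conv_step_sketch(x_t, conv_state, weight):
    """Hypothetical step: x_t (B, 1, C); conv_state (B, K-1, C); weight (C, K)."""
    # Window of the last K inputs, oldest first (cache layout is an assumption).
    window = jnp.concatenate([conv_state, x_t], axis=1)         # (B, K, C)
    out = jnp.einsum("bkc,ck->bc", window, weight)[:, None, :]  # (B, 1, C)
    # Drop the oldest input so the cache again holds K-1 entries.
    return out, window[:, 1:, :]


B, T, C, K = 2, 6, 3, 4
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (B, T, C))
weight = jax.random.normal(kw, (C, K))

conv_state = jnp.zeros((B, K - 1, C))   # empty history behaves like zero padding
outs = []
for t in range(T):
    o, conv_state = conv_step_sketch(x[:, t:t + 1], conv_state, weight)
    outs.append(o)
stepped = jnp.concatenate(outs, axis=1)
full = full_causal_conv(x, weight)
```

After the loop, the cache holds the last K-1 inputs, so decoding can resume from a prefill by seeding conv_state with the tail of the prompt's conv input.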

causal_depthwise_conv1d(x: jax.Array, weight: jax.Array, kernel_size: int) -> jax.Array

Causal depthwise conv1d: x (B, T, C); weight (C, kernel_size).

Matches HF Conv1d(groups=C, padding=K-1) followed by [:, :, :T].
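The pad-then-truncate formulation above can be mirrored in JAX with jax.lax.conv_general_dilated. This is a sketch under stated assumptions rather than the module's code: the name causal_depthwise_conv1d_sketch is hypothetical, and the weight is assumed to be the HF (C, 1, K) kernel with the singleton axis squeezed out.

```python
import jax
import jax.numpy as jnp


def causal_depthwise_conv1d_sketch(x, weight, kernel_size):
    """Hypothetical sketch: x (B, T, C); weight (C, kernel_size). Returns (B, T, C)."""
    B, T, C = x.shape
    lhs = jnp.transpose(x, (0, 2, 1))     # (B, C, T), channels-first like torch
    rhs = weight[:, None, :]              # (C, 1, K): one filter per channel
    y = jax.lax.conv_general_dilated(
        lhs, rhs,
        window_strides=(1,),
        padding=[(kernel_size - 1, kernel_size - 1)],   # Conv1d(padding=K-1)
        dimension_numbers=("NCH", "OIH", "NCH"),
        feature_group_count=C,                          # groups=C (depthwise)
    )                                     # (B, C, T + K - 1)
    # Keep the first T positions, so output t only sees inputs t-K+1 .. t.
    return jnp.transpose(y[:, :, :T], (0, 2, 1))


# Impulse at t = 0 in both channels exposes the kernel taps in reverse time order.
x = jnp.zeros((1, 4, 2)).at[0, 0, :].set(1.0)
weight = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 1.0]])
y = causal_depthwise_conv1d_sketch(x, weight, kernel_size=3)
# channel 0 -> [3, 2, 1, 0]; channel 1 -> [1, 0, 0, 0]

# Causality check: changing the last input must not affect earlier outputs.
y_future = causal_depthwise_conv1d_sketch(x.at[0, 3, :].set(5.0), weight, kernel_size=3)
```

The last kernel tap (index K-1) multiplies the current input and tap 0 multiplies the oldest one, since both torch Conv1d and lax.conv_general_dilated compute a cross-correlation without flipping the kernel.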