base

Base self-attention module with hook-based API and KV cache support.

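All Q/K/V hooks operate on tensors in (B, T, H, D) layout (B = batch, T = sequence length, H = number of heads, D = per-head dimension). Subclasses customize behavior by overriding any of _project_inner, preprocess_qkv, build_mask, attn, postprocess_attn, and output_proj.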
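The following is a minimal sketch of how these hooks can compose. It assumes they run in the order listed below (project, preprocess_qkv, build_mask, attn, postprocess_attn, output_proj) and that heads are merged back to (B, T, C) between postprocess_attn and output_proj. MiniSelfAttention and CausalAttention are hypothetical stand-ins written in plain JAX, not the library's SelfAttention. The subclass overrides only build_mask to obtain causal attention.

```python
import jax
import jax.numpy as jnp


class MiniSelfAttention:
    """Hypothetical stand-in mirroring the documented hook order (not the library class)."""

    def __init__(self, num_heads, head_dim, key):
        self.h, self.d = num_heads, head_dim
        c = num_heads * head_dim
        k1, k2 = jax.random.split(key)
        self.w_qkv = jax.random.normal(k1, (c, 3 * c)) / jnp.sqrt(c)  # fused Q/K/V weights
        self.w_out = jax.random.normal(k2, (c, c)) / jnp.sqrt(c)

    def _project_inner(self, x):
        b, t, _ = x.shape
        q, k, v = jnp.split(x @ self.w_qkv, 3, axis=-1)  # each (B, T, C)
        return tuple(a.reshape(b, t, self.h, self.d) for a in (q, k, v))

    def project(self, x):
        return self._project_inner(x)

    def preprocess_qkv(self, q, k, v, **kwargs):
        return q, k, v  # identity by default; RoPE etc. would go here

    def build_mask(self, t, padding_mask, **kwargs):
        return None  # no masking by default

    def attn(self, q, k, v, mask=None, **kwargs):
        logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(self.d)
        if mask is not None:
            logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
        return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

    def postprocess_attn(self, y, padding_mask, deterministic, **kwargs):
        return y

    def output_proj(self, y):
        return y @ self.w_out  # (B, T, C) -> (B, T, C)

    def __call__(self, x, padding_mask=None, deterministic=True):
        b, t, _ = x.shape
        q, k, v = self.project(x)
        q, k, v = self.preprocess_qkv(q, k, v)
        mask = self.build_mask(t, padding_mask)
        y = self.attn(q, k, v, mask=mask)
        y = self.postprocess_attn(y, padding_mask, deterministic)
        return self.output_proj(y.reshape(b, t, self.h * self.d))  # merge heads


class CausalAttention(MiniSelfAttention):
    def build_mask(self, t, padding_mask, **kwargs):
        # True = may attend; broadcasts over (B, H, T_q, T_kv).
        return jnp.tril(jnp.ones((t, t), dtype=bool))[None, None]
```

Because only one hook changes, the projection, attention math and output path are shared. Zeroing the last token in this sketch leaves every earlier output unchanged.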

The KV cache is activated by calling model.apply(..., mutable=['cache']). Without mutable=['cache'], the cache is a no-op, which is the intended training-mode behavior.

SelfAttention

Bases: Module

project(x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]

Project input to (q, k, v). Calls _project_inner.

preprocess_qkv(q: jax.Array, k: jax.Array, v: jax.Array, **kwargs: Any) -> Tuple[jax.Array, jax.Array, jax.Array]

Hook for RoPE, KV repeat, etc. Input/output: (B, T, H, D).
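As an illustration of what this hook might do, here is a minimal sketch of rotary position embeddings plus KV head repetition (for grouped-query attention) on (B, T, H, D) tensors. The function names, the base of 10000, and the interleaved-pair RoPE convention are assumptions. Some implementations instead rotate the two halves of D, and this sketch does not claim to match the library.

```python
import jax
import jax.numpy as jnp


def apply_rope(x, positions, base=10000.0):
    """Interleaved-pair RoPE. x: (B, T, H, D) with D even; positions: (T,)."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (jnp.arange(0, d, 2) / d))   # (D/2,)
    angles = positions[:, None] * inv_freq[None, :]         # (T, D/2)
    cos = jnp.cos(angles)[None, :, None, :]                  # (1, T, 1, D/2)
    sin = jnp.sin(angles)[None, :, None, :]
    x1, x2 = x[..., ::2], x[..., 1::2]                       # pairs (even, odd)
    rot = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rot.reshape(x.shape)                               # re-interleave pairs


def repeat_kv(x, n_rep):
    """(B, T, H_kv, D) -> (B, T, H_kv * n_rep, D); query head i uses KV head i // n_rep."""
    return jnp.repeat(x, n_rep, axis=2)


def preprocess_qkv(q, k, v, positions, n_rep):
    # Hypothetical hook body: rotate q and k, then expand KV heads to match q.
    q, k = apply_rope(q, positions), apply_rope(k, positions)
    return q, repeat_kv(k, n_rep), repeat_kv(v, n_rep)
```

RoPE is a pure rotation of each pair, so it preserves vector norms. The dot product between a rotated query and key depends only on their position difference, which is the property that makes it a relative position encoding.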

build_mask(t: int, padding_mask: Optional[jax.Array], **kwargs: Any) -> Optional[jax.Array]

Construct attention mask. Returns bool mask or None.
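Here is a minimal sketch of a mask builder that combines a causal mask with a key padding mask. It assumes True means "may attend", that padding_mask has shape (B, T) with truthy values for real tokens, and a returned shape that broadcasts over (B, H, T_q, T_kv). The causal flag is a hypothetical parameter added for illustration.

```python
import jax.numpy as jnp


def build_mask(t, padding_mask=None, causal=True):
    """Return a bool mask broadcastable to (B, H, T, T), or None for no masking."""
    mask = None
    if causal:
        mask = jnp.tril(jnp.ones((t, t), dtype=bool))[None, None]  # (1, 1, T, T)
    if padding_mask is not None:
        pad = padding_mask[:, None, None, :].astype(bool)           # (B, 1, 1, T): masks keys
        mask = pad if mask is None else (mask & pad)
    return mask
```

Returning None when neither condition applies lets attn skip the masking step entirely.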

attn(q: jax.Array, k: jax.Array, v: jax.Array, mask: Optional[jax.Array] = None, **kwargs: Any) -> jax.Array

Core attention. Inputs: q of shape (B, T_q, H, D) and k, v of shape (B, T_kv, H, D). Output: (B, T_q, H, D).
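T_q and T_kv differ during cached decoding, where a single new query attends over a cache buffer. The block below is a minimal scaled dot-product sketch of that shape contract, assuming a boolean mask in which True means "may attend". decode_mask is a hypothetical helper that hides cache slots beyond the current position.

```python
import jax
import jax.numpy as jnp


def attn(q, k, v, mask=None):
    """q: (B, T_q, H, D); k, v: (B, T_kv, H, D) -> (B, T_q, H, D)."""
    scale = 1.0 / jnp.sqrt(jnp.asarray(q.shape[-1], q.dtype))
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale   # (B, H, T_q, T_kv)
    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


def decode_mask(pos, t_kv):
    # A query at absolute position `pos` may see cache slots 0..pos only.
    return (jnp.arange(t_kv) <= pos)[None, None, None, :]  # (1, 1, 1, T_kv)
```

With this mask, attending over a 16-slot cache at position 5 gives the same result as attending over only the first 6 keys, because masked logits contribute zero weight after the softmax.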

postprocess_attn(y: jax.Array, padding_mask: Optional[jax.Array], deterministic: bool, **kwargs: Any) -> jax.Array

Post-attention processing. Input/output: (B, T, H, D).
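The signature suggests this hook uses padding_mask and deterministic, but what the base class does here is not stated. The following hypothetical sketch shows one plausible use: zeroing outputs at padded query positions and applying inverted dropout only when deterministic is False. The rate and rng parameters are assumptions.

```python
import jax
import jax.numpy as jnp


def postprocess_attn(y, padding_mask, deterministic, rate=0.1, rng=None):
    """Hypothetical hook body. y: (B, T, H, D); padding_mask: (B, T), truthy = real token."""
    if padding_mask is not None:
        y = y * padding_mask[:, :, None, None].astype(y.dtype)  # zero padded positions
    if not deterministic and rate > 0.0:
        # Inverted dropout: keep with prob 1 - rate and rescale so the expectation is unchanged.
        keep = jax.random.bernoulli(rng, 1.0 - rate, y.shape)
        y = jnp.where(keep, y / (1.0 - rate), 0.0)
    return y
```
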

output_proj(y: jax.Array) -> jax.Array

Output projection. (B, T, C) → (B, T, C).