base
Base self-attention module with hook-based API and KV cache support.
All Q/K/V hooks operate on tensors in (B, T, H, D) layout (batch, sequence length, heads, head dimension). Subclasses customize behavior by overriding _project_inner, preprocess_qkv, build_mask, attn, postprocess_attn, or output_proj.
The KV cache is activated by calling model.apply(..., mutable=['cache']). Without mutable=['cache'], the cache is a no-op, which is the behavior used in training mode.
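The cache behavior described above can be illustrated with a small functional sketch in plain JAX. This is not the module's implementation: `init_cache`, `update_kv_cache`, and the dict layout with `"k"`, `"v"`, and `"index"` entries are hypothetical names chosen for illustration. Passing `None` as the cache stands in for the no-cache training path.

```python
import jax
import jax.numpy as jnp


def init_cache(b, max_len, h, d, dtype=jnp.float32):
    """Hypothetical helper: allocate an empty fixed-size cache in (B, T, H, D) layout."""
    return {
        "k": jnp.zeros((b, max_len, h, d), dtype),
        "v": jnp.zeros((b, max_len, h, d), dtype),
        "index": jnp.array(0, dtype=jnp.int32),  # next write position along T
    }


def update_kv_cache(cache, k, v):
    """Write new k/v of shape (B, T_new, H, D) into the cache at the current index.

    With cache=None nothing is stored and k/v are returned unchanged,
    mirroring the "cache is a no-op" behavior described above.
    """
    if cache is None:
        return k, v, None
    idx = cache["index"]
    zero = jnp.zeros((), dtype=idx.dtype)
    start = (zero, idx, zero, zero)  # only the time axis is offset
    new_k = jax.lax.dynamic_update_slice(cache["k"], k, start)
    new_v = jax.lax.dynamic_update_slice(cache["v"], v, start)
    new_cache = {"k": new_k, "v": new_v, "index": idx + k.shape[1]}
    return new_k, new_v, new_cache
```

In a Flax module, state like this would typically live in the `cache` variable collection, and `model.apply(..., mutable=['cache'])` would return the updated collection alongside the output so it can be passed back in on the next decoding step.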
SelfAttention
Bases: Module
project(x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]
Project input to (q, k, v). Calls _project_inner.
preprocess_qkv(q: jax.Array, k: jax.Array, v: jax.Array, **kwargs: Any) -> Tuple[jax.Array, jax.Array, jax.Array]
Hook for RoPE, KV repeat, etc. Input/output: (B, T, H, D).
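As one example of what this hook might do, the sketch below repeats KV heads so that a smaller number of key/value heads matches a larger number of query heads. The function name `repeat_kv` and the `n_rep` parameter are assumptions for illustration, not part of the documented API.

```python
import jax.numpy as jnp


def repeat_kv(k, v, n_rep):
    """Repeat each KV head n_rep times along the head axis of a (B, T, H, D) tensor."""
    if n_rep == 1:
        return k, v
    # Axis 2 is the head axis in (B, T, H, D); head i becomes heads i*n_rep .. i*n_rep + n_rep - 1.
    return jnp.repeat(k, n_rep, axis=2), jnp.repeat(v, n_rep, axis=2)
```

An override of preprocess_qkv could call such a helper on k and v and return q unchanged, keeping the (B, T, H, D) contract on input and output.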
build_mask(t: int, padding_mask: Optional[jax.Array], **kwargs: Any) -> Optional[jax.Array]
Construct attention mask. Returns bool mask or None.
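A minimal sketch of a mask builder that matches this contract is shown below. It assumes a causal mask, a `padding_mask` of shape (B, T) with True marking real tokens, and a broadcastable output shape of (B, 1, T, T); none of these conventions are confirmed by the documentation above, and `causal_padding_mask` is a hypothetical name.

```python
import jax.numpy as jnp


def causal_padding_mask(t, padding_mask=None):
    """Return a bool mask where True means "query may attend to key".

    Output shape is (1, 1, T, T) without padding, or (B, 1, T, T) with it.
    """
    pos = jnp.arange(t)
    # Query position i may attend to key position j only when j <= i.
    mask = (pos[:, None] >= pos[None, :])[None, None, :, :]
    if padding_mask is not None:
        # Assumed convention: padding_mask is (B, T) bool, True for real tokens.
        mask = mask & padding_mask[:, None, None, :]
    return mask
```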
attn(q: jax.Array, k: jax.Array, v: jax.Array, mask: Optional[jax.Array] = None, **kwargs: Any) -> jax.Array
Core attention. Input q/k/v: (B, T_q/T_kv, H, D). Output: (B, T_q, H, D).
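The shape contract above can be sketched with scaled dot-product attention written directly against the (B, T, H, D) layout. This is an illustrative stand-in, not the module's own `attn`; the function name and the assumption that the mask broadcasts to (B, H, T_q, T_kv) with True meaning "attend" are hypothetical.

```python
import jax
import jax.numpy as jnp


def dot_product_attention(q, k, v, mask=None):
    """Scaled dot-product attention on (B, T, H, D) tensors.

    q: (B, T_q, H, D); k, v: (B, T_kv, H, D); returns (B, T_q, H, D).
    """
    scale = q.shape[-1] ** -0.5
    # Contract over D; scores have shape (B, H, T_q, T_kv).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if mask is not None:
        # Masked positions get the most negative finite value, so softmax gives them zero weight.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum over T_kv, returning to (B, T_q, H, D).
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

Writing the einsum against (B, T, H, D) avoids explicit transposes to (B, H, T, D), which keeps the hook inputs and outputs in the single layout the module documents.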
postprocess_attn(y: jax.Array, padding_mask: Optional[jax.Array], deterministic: bool, **kwargs: Any) -> jax.Array
Post-attention processing. Input/output: (B, T, H, D).
output_proj(y: jax.Array) -> jax.Array
Output projection. (B, T, C) → (B, T, C).