attention
SelfAttention
Bases: Module
project(x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]
Project input to (q, k, v). Calls _project_inner.
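The projection internals (`_project_inner`) aren't shown; one common layout is a single fused weight split into three equal chunks. A minimal sketch, where `w_qkv` is a hypothetical parameter standing in for whatever `_project_inner` actually holds:

```python
import jax.numpy as jnp
from typing import Tuple

def project(x: jnp.ndarray, w_qkv: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    # x: (B, T, C); w_qkv: (C, 3C) hypothetical fused projection weight
    qkv = x @ w_qkv                       # (B, T, 3C)
    q, k, v = jnp.split(qkv, 3, axis=-1)  # three (B, T, C) chunks
    return q, k, v

x = jnp.ones((2, 8, 16))
q, k, v = project(x, jnp.ones((16, 48)))
```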
preprocess_qkv(q: jax.Array, k: jax.Array, v: jax.Array, **kwargs: Any) -> Tuple[jax.Array, jax.Array, jax.Array]
Hook for RoPE, KV repeat, etc. Input/output: (B, T, H, D).
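The hook body isn't specified; a plausible sketch of the RoPE case, assuming a GPT-NeoX-style rotate-half rotary embedding (the `rope` helper and `base` frequency are illustrative, not from the source). Only q and k are rotated; v carries no positional information:

```python
import jax.numpy as jnp

def rope(x: jnp.ndarray, base: float = 10000.0) -> jnp.ndarray:
    """Rotary position embedding on (B, T, H, D), D even (rotate-half variant)."""
    _, T, _, D = x.shape
    half = D // 2
    freqs = base ** (-jnp.arange(half) / half)        # (half,) per-channel frequencies
    angles = jnp.arange(T)[:, None] * freqs[None, :]  # (T, half) position * frequency
    cos = jnp.cos(angles)[None, :, None, :]           # broadcast to (1, T, 1, half)
    sin = jnp.sin(angles)[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos], axis=-1)

def preprocess_qkv(q, k, v, **kwargs):
    # Rotate q and k only; pass v through unchanged.
    return rope(q), rope(k), v

q0 = jnp.ones((1, 4, 2, 8))
oq, ok, ov = preprocess_qkv(q0, q0, q0)
```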
build_mask(t: int, padding_mask: Optional[jax.Array], **kwargs: Any) -> Optional[jax.Array]
Construct the attention mask. Returns a boolean mask, or None when no masking is needed.
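The mask construction isn't shown; a typical sketch, assuming a causal model with an optional (B, T) padding mask where True marks real tokens:

```python
import jax.numpy as jnp
from typing import Optional

def build_mask(t: int, padding_mask: Optional[jnp.ndarray]) -> Optional[jnp.ndarray]:
    # Causal mask: query i may attend to keys j <= i; shape (1, 1, T, T).
    causal = jnp.tril(jnp.ones((t, t), dtype=bool))[None, None]
    if padding_mask is None:
        return causal
    # padding_mask: (B, T), True at real tokens; broadcast over the key axis.
    return causal & padding_mask[:, None, None, :]

m = build_mask(3, None)
m2 = build_mask(3, jnp.array([[True, True, False]]))
```

The (1, 1, T, T) and (B, 1, T, T) shapes broadcast cleanly against (B, H, T_q, T_kv) attention scores.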
attn(q: jax.Array, k: jax.Array, v: jax.Array, mask: Optional[jax.Array] = None, **kwargs: Any) -> jax.Array
Core attention. Inputs q: (B, T_q, H, D); k, v: (B, T_kv, H, D). Output: (B, T_q, H, D).
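A reference sketch of the computation as plain scaled dot-product attention (the actual `attn` may instead dispatch to a fused kernel):

```python
import jax
import jax.numpy as jnp
from typing import Optional

def attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray,
         mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    # q: (B, Tq, H, D); k, v: (B, Tkv, H, D)
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        # mask broadcasts to (B, H, Tq, Tkv); False positions are suppressed
        scores = jnp.where(mask, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)  # normalize over the key axis
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

out = attn(jnp.zeros((1, 4, 2, 8)), jnp.zeros((1, 4, 2, 8)), jnp.ones((1, 4, 2, 8)))
```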
postprocess_attn(y: jax.Array, padding_mask: Optional[jax.Array], deterministic: bool, **kwargs: Any) -> jax.Array
Post-attention processing. Input/output: (B, T, H, D).
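What this step does isn't specified; a plausible sketch combining zeroing of padded positions with inverted dropout gated on `deterministic` (the `rate` and `rng` parameters are assumptions, not from the source):

```python
import jax
import jax.numpy as jnp
from typing import Optional

def postprocess_attn(y: jnp.ndarray, padding_mask: Optional[jnp.ndarray],
                     deterministic: bool, *, rng=None, rate: float = 0.1) -> jnp.ndarray:
    # Zero outputs at padded positions: (B, T) mask broadcast over (B, T, H, D).
    if padding_mask is not None:
        y = y * padding_mask[:, :, None, None]
    # Inverted dropout during training only (deterministic=True skips it).
    if not deterministic:
        keep = jax.random.bernoulli(rng, 1.0 - rate, y.shape)
        y = jnp.where(keep, y / (1.0 - rate), 0.0)
    return y

y = jnp.ones((1, 3, 2, 4))
out = postprocess_attn(y, jnp.array([[True, True, False]]), True)
```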
output_proj(y: jax.Array) -> jax.Array
Output projection. (B, T, C) → (B, T, C).
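Since the attention output is (B, T, H, D) but this method takes (B, T, C), the heads are presumably flattened to C = H * D before the call. A sketch, with `w_out` as a hypothetical projection weight:

```python
import jax.numpy as jnp

def output_proj(y: jnp.ndarray, w_out: jnp.ndarray) -> jnp.ndarray:
    # y: (B, T, C); w_out: (C, C) hypothetical output projection weight
    return y @ w_out

B, T, H, D = 2, 4, 2, 8
y_heads = jnp.ones((B, T, H, D))      # post-attention output
y = y_heads.reshape(B, T, H * D)      # merge heads: C = H * D = 16
out = output_proj(y, jnp.eye(H * D))  # identity weight for illustration
```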
ForkingAttention
Bases: RopeAttention
Self-attention with forking support, extending RopeAttention.