attention

SelfAttention

Bases: Module

project(x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]

Project the input x into query, key, and value arrays (q, k, v). Delegates to _project_inner.
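The reference does not show what _project_inner does. Here is a minimal sketch of one common approach, a single fused linear projection that is then split into heads. It assumes the input is (B, T, C), as output_proj's signature suggests, and that C is divisible by the number of heads. The function project_qkv and the weight w_qkv of shape (C, 3 * C) are hypothetical names for illustration. They are not the module's API.

```python
import jax
import jax.numpy as jnp


def project_qkv(x: jax.Array, w_qkv: jax.Array, num_heads: int):
    """Hypothetical fused q/k/v projection: (B, T, C) -> 3 x (B, T, H, D).

    Assumes w_qkv has shape (C, 3 * C) and C % num_heads == 0.
    The real _project_inner may use separate Dense layers instead.
    """
    b, t, c = x.shape
    head_dim = c // num_heads
    qkv = x @ w_qkv                        # (B, T, 3C)
    q, k, v = jnp.split(qkv, 3, axis=-1)   # each (B, T, C)

    def to_heads(a: jax.Array) -> jax.Array:
        # Split the channel axis into heads: (B, T, C) -> (B, T, H, D)
        return a.reshape(b, t, num_heads, head_dim)

    return to_heads(q), to_heads(k), to_heads(v)


x = jnp.ones((2, 5, 8))
w_qkv = jax.random.normal(jax.random.PRNGKey(0), (8, 24))
q, k, v = project_qkv(x, w_qkv, num_heads=2)
print(q.shape)  # (2, 5, 2, 4)
```

A fused projection issues one matmul instead of three. The split afterwards is just a view over the channel axis.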

preprocess_qkv(q: jax.Array, k: jax.Array, v: jax.Array, **kwargs: Any) -> Tuple[jax.Array, jax.Array, jax.Array]

Hook for transforming q, k, and v before attention, e.g. applying RoPE or repeating KV heads. Input/output: (B, T, H, D).
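As an illustration of what such a hook might do, the sketch below applies rotary position embeddings to q and k. It uses the "rotate-half" variant, one of several RoPE conventions. It then repeats k and v along the head axis for grouped-query attention. The helper names rope and repeat_kv and the n_rep parameter are hypothetical. The actual RopeAttention subclass may use a different RoPE layout or frequency base.

```python
import jax
import jax.numpy as jnp


def rope(x: jax.Array, base: float = 10000.0) -> jax.Array:
    """Rotary embedding ("rotate-half" variant) on (B, T, H, D) with even D."""
    t, d = x.shape[1], x.shape[-1]
    half = d // 2
    inv_freq = 1.0 / (base ** (jnp.arange(half) / half))    # (D/2,)
    angles = jnp.arange(t)[:, None] * inv_freq[None, :]      # (T, D/2)
    cos = jnp.cos(angles)[None, :, None, :]                  # (1, T, 1, D/2)
    sin = jnp.sin(angles)[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1[i], x2[i]) pair by a position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def repeat_kv(a: jax.Array, n_rep: int) -> jax.Array:
    """(B, T, H_kv, D) -> (B, T, H_kv * n_rep, D); query head h uses KV head h // n_rep."""
    return jnp.repeat(a, n_rep, axis=2)


def preprocess_qkv(q: jax.Array, k: jax.Array, v: jax.Array, n_rep: int = 1):
    # RoPE is applied to q and k only; v is left unrotated.
    q, k = rope(q), rope(k)
    k, v = repeat_kv(k, n_rep), repeat_kv(v, n_rep)
    return q, k, v
```

Because RoPE is a pure rotation, it leaves position 0 unchanged and preserves each head vector's norm. These properties make quick sanity checks.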

build_mask(t: int, padding_mask: Optional[jax.Array], **kwargs: Any) -> Optional[jax.Array]

Construct the attention mask for sequence length t, optionally incorporating padding_mask. Returns a boolean mask or None.
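The mask shape and polarity are not specified here. A minimal sketch follows, assuming True means "query may attend to key", padding_mask is (B, T) with True for real tokens, and the result broadcasts to (B, H, T, T). The causal flag is a hypothetical parameter added for illustration. It shows when None would naturally be returned.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def build_mask(t: int, padding_mask: Optional[jax.Array], causal: bool = True) -> Optional[jax.Array]:
    """Boolean mask (True = may attend) that broadcasts to (B, H, T, T)."""
    mask = None
    if causal:
        # Lower-triangular: query i sees keys 0..i.
        mask = jnp.tril(jnp.ones((t, t))).astype(bool)[None, None]   # (1, 1, T, T)
    if padding_mask is not None:
        # Block attention *to* padded key positions.
        key_ok = padding_mask.astype(bool)[:, None, None, :]         # (B, 1, 1, T)
        mask = key_ok if mask is None else mask & key_ok
    return mask


# One sequence of length 3 whose last token is padding.
m = build_mask(3, jnp.array([[True, True, False]]))
# m[0, 0] rows: [T, F, F], [T, T, F], [T, T, F]
```

Returning None when nothing needs masking lets the attention kernel skip the masking step entirely.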

attn(q: jax.Array, k: jax.Array, v: jax.Array, mask: Optional[jax.Array] = None, **kwargs: Any) -> jax.Array

Core attention computation. Inputs: q of shape (B, T_q, H, D); k and v of shape (B, T_kv, H, D). Output: (B, T_q, H, D).
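For reference, the shapes above match standard scaled dot-product attention, softmax(q kᵀ / √D) v, computed per head. The sketch below is a plain-einsum version and not necessarily how this module implements attn. It assumes a boolean mask that broadcasts to (B, H, T_q, T_kv), with True marking allowed positions. Recent JAX releases also provide jax.nn.dot_product_attention, which uses the same (B, T, H, D) layout.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def attn(q: jax.Array, k: jax.Array, v: jax.Array, mask: Optional[jax.Array] = None) -> jax.Array:
    """Scaled dot-product attention. q: (B, T_q, H, D); k, v: (B, T_kv, H, D)."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale    # (B, H, T_q, T_kv)
    if mask is not None:
        # Masked-out positions get the most negative finite value, so softmax gives them ~0 weight.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)                # each row sums to 1
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)         # (B, T_q, H, D)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 3, 8))
k = jax.random.normal(kk, (2, 4, 3, 8))
v = jax.random.normal(kv, (2, 4, 3, 8))
causal = jnp.tril(jnp.ones((4, 4))).astype(bool)[None, None]   # (1, 1, T, T)
y = attn(q, k, v, causal)
```

Under a causal mask, the first query can only see the first key. Its output is therefore exactly v at position 0, which is a handy correctness check.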

postprocess_attn(y: jax.Array, padding_mask: Optional[jax.Array], deterministic: bool, **kwargs: Any) -> jax.Array

Post-attention processing. Input/output: (B, T, H, D).
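The signature suggests this step uses padding_mask and a deterministic flag, but its body isn't documented. A hypothetical sketch under two assumptions follows. It zeroes outputs at padded query positions and applies inverted dropout only when deterministic is False. The rate and rng parameters are illustrative additions, not part of the documented signature.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def postprocess_attn(
    y: jax.Array,
    padding_mask: Optional[jax.Array],
    deterministic: bool,
    rate: float = 0.1,              # hypothetical dropout rate
    rng: Optional[jax.Array] = None,  # hypothetical PRNG key for dropout
) -> jax.Array:
    """Hypothetical post-attention step on (B, T, H, D)."""
    if padding_mask is not None:
        # padding_mask: (B, T), nonzero for real tokens; zero padded positions.
        y = y * padding_mask[:, :, None, None].astype(y.dtype)
    if not deterministic and rate > 0.0:
        # Inverted dropout: keep with prob 1 - rate and rescale so the expectation is unchanged.
        keep = jax.random.bernoulli(rng, 1.0 - rate, y.shape)
        y = jnp.where(keep, y / (1.0 - rate), 0.0)
    return y
```

Gating on deterministic follows the usual Flax convention of disabling stochastic layers at evaluation time.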

output_proj(y: jax.Array) -> jax.Array

Output projection. (B, T, C) → (B, T, C).
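postprocess_attn returns (B, T, H, D), but output_proj takes (B, T, C). Some head-merging reshape with C = H * D therefore presumably happens in between, although this reference doesn't say where. A minimal sketch of that merge followed by a linear output projection is shown below. merge_heads and the weight and bias names w_o and b_o are hypothetical.

```python
import jax
import jax.numpy as jnp


def merge_heads(y: jax.Array) -> jax.Array:
    """(B, T, H, D) -> (B, T, H * D); assumes H * D == C."""
    b, t, h, d = y.shape
    return y.reshape(b, t, h * d)


def output_proj(y: jax.Array, w_o: jax.Array, b_o: jax.Array) -> jax.Array:
    """Linear map (B, T, C) -> (B, T, C) with assumed weight (C, C) and bias (C,)."""
    return y @ w_o + b_o


y = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 2, 4))   # (B, T, H, D)
w_o = jax.random.normal(jax.random.PRNGKey(1), (8, 8))
out = output_proj(merge_heads(y), w_o, jnp.zeros(8))         # (2, 5, 8)
```

The merge concatenates heads along the channel axis, so channels D*h through D*(h+1) - 1 come from head h.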

ForkingAttention

Bases: RopeAttention

Self-attention with forking support, extending RopeAttention.