scratching

ScratchSparseCrossAttention

Bases: ForkingAttention

attn(q: jax.Array, k: jax.Array, v: jax.Array, mask: Optional[jax.Array] = None, **kwargs: Any) -> jax.Array

Cross-attention with causal mask based on token positions, not array indices.

The inherited ForkingAttention.attn applies a standard lower-triangular mask over array indices. That is incorrect for cross-attention, where Q (after top-k selection) and K (the original sequence) map array indices to token positions differently. This override builds the causal mask from the actual token positions, so that a query at original position p can only attend to keys at positions <= p.
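A minimal sketch of the idea, not the actual implementation of this method: the helper names `position_causal_mask` and `masked_cross_attention`, the explicit `q_positions` / `k_positions` arguments, and the unbatched single-head shapes are all assumptions made for illustration. It shows how a mask built from token positions differs from a lower-triangular mask over array indices.

```python
import jax
import jax.numpy as jnp


def position_causal_mask(q_positions: jax.Array, k_positions: jax.Array) -> jax.Array:
    """Hypothetical helper: boolean mask of shape (Tq, Tk).

    Entry [i, j] is True when key j sits at or before query i in the
    original sequence, i.e. k_positions[j] <= q_positions[i].
    """
    return k_positions[None, :] <= q_positions[:, None]


def masked_cross_attention(
    q: jax.Array,            # (Tq, d)  queries kept after top-k selection
    k: jax.Array,            # (Tk, d)  keys for the original sequence
    v: jax.Array,            # (Tk, dv) values for the original sequence
    q_positions: jax.Array,  # (Tq,)    original token position of each query
    k_positions: jax.Array,  # (Tk,)    original token position of each key
) -> jax.Array:
    """Hypothetical single-head, unbatched attention using the position mask."""
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    mask = position_causal_mask(q_positions, k_positions)
    # Use a large finite negative value rather than -inf so the softmax
    # stays finite even if a row happens to be fully masked.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v


# Two queries survive top-k selection; they came from positions 1 and 3.
q_positions = jnp.array([1, 3])
k_positions = jnp.arange(4)

position_mask = position_causal_mask(q_positions, k_positions)
index_mask = jnp.tril(jnp.ones((2, 4), dtype=bool))  # mask over array indices
```

Here the query stored at array index 0 originated at position 1, so `position_mask` lets it see keys 0 and 1, whereas `index_mask` would restrict it to key 0 alone. Likewise the second query (position 3) sees all four keys under `position_mask` but only two under `index_mask`.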