masks
¶
causal_mask(seq_len: int) -> jnp.ndarray
¶
Boolean causal mask of shape (1, 1, T, T), where T = seq_len; True=keep.
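A minimal jax.numpy sketch of such a mask, written for illustration rather than copied from this module: query row i keeps key columns j <= i, and two leading singleton axes let the mask broadcast over (batch, heads, T, T) attention logits. `apply_mask` is a hypothetical helper, not part of the documented API, showing one common way to consume a True=keep mask before a softmax.

```python
import jax
import jax.numpy as jnp


def causal_mask(seq_len: int) -> jnp.ndarray:
    """Boolean causal mask of shape (1, 1, T, T); True means keep."""
    q = jnp.arange(seq_len)[:, None]  # query positions, shape (T, 1)
    k = jnp.arange(seq_len)[None, :]  # key positions, shape (1, T)
    mask = k <= q  # row i keeps keys j <= i (lower triangle incl. diagonal)
    # Leading singleton axes broadcast over (batch, heads, T, T) logits.
    return mask[None, None, :, :]


def apply_mask(logits: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical helper: masked-out logits get the dtype's most negative
    # value, so they contribute ~0 probability after softmax.
    return jnp.where(mask, logits, jnp.finfo(logits.dtype).min)


m = causal_mask(4)
probs = jax.nn.softmax(apply_mask(jnp.zeros((1, 1, 4, 4)), m), axis=-1)
```

Building the mask from an index comparison rather than a materialized triangle keeps the same code shape for the sliding-window variant, which only adds a lower bound on the distance.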
sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray
¶
Sliding-window causal mask of shape (1, 1, T, T).
Query at row i attends to keys j with i - window < j <= i,
i.e. the last `window` tokens up to and including the query itself.
An earlier version had the sign of `dist` flipped, which turned this
into a future-only mask; that silently leaked target tokens into
LaCT's SWA sublayer and collapsed the loss to near zero.
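A minimal jax.numpy sketch of the rule above, not this module's actual implementation. It assumes `dist` means query index minus key index (the text names `dist` but does not define it). Rewriting i - window < j <= i as 0 <= dist < window makes the sign explicit: using j - i instead would keep the query and the tokens after it, which is the future-leaking bug described above.

```python
import jax.numpy as jnp


def sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Sliding-window causal mask of shape (1, 1, T, T); True means keep."""
    q = jnp.arange(seq_len)[:, None]  # query index i, shape (T, 1)
    k = jnp.arange(seq_len)[None, :]  # key index j, shape (1, T)
    # Assumed definition: dist > 0 for past keys, 0 for self, < 0 for future.
    dist = q - k
    # i - window < j <= i  is equivalent to  0 <= dist < window.
    mask = (dist >= 0) & (dist < window)
    return mask[None, None, :, :]


m = sliding_window_mask(5, 2)  # each row keeps itself and one past token
```

When `window >= seq_len` the lower bound never bites and the result equals the plain causal mask, which makes a cheap sanity check.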
cache_mask(max_length: int, cache_index: jnp.ndarray) -> jnp.ndarray
¶
Boolean mask for a KV cache: attend only to positions < cache_index.
Returns a (1, 1, 1, max_length) bool mask, True=keep.
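A minimal sketch of such a cache mask in jax.numpy, assuming `cache_index` is a scalar array holding the number of cache slots already filled (the text fixes the output shape but not how `cache_index` is produced). `cache_mask_jit` is a hypothetical name used only to show that the function traces under `jax.jit` when `max_length` is marked static.

```python
import jax
import jax.numpy as jnp


def cache_mask(max_length: int, cache_index: jnp.ndarray) -> jnp.ndarray:
    """(1, 1, 1, max_length) bool mask; True for cache slots < cache_index."""
    positions = jnp.arange(max_length)  # slot indices 0 .. max_length - 1
    # Assumes cache_index is a scalar: the count of slots already written.
    mask = positions < cache_index
    # Broadcasts against (batch, heads, q_len, max_length) logits.
    return mask[None, None, None, :]


# cache_index is only used in a jnp comparison (no Python branching), so it
# can be a traced value; max_length sets the array shape and must be static.
cache_mask_jit = jax.jit(cache_mask, static_argnums=0)
```

Keeping `cache_index` as an array rather than a Python int avoids recompiling the decode step every time the cache grows by one token.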