masks
causal_mask(seq_len: int) -> jnp.ndarray
Boolean causal mask of shape (1, 1, T, T), where T = seq_len; True = keep (query position i may attend to key positions j ≤ i).
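As an illustration only, a causal mask with this shape and convention can be built from two broadcast index ranges. This is a minimal sketch, not this module's implementation; the name `causal_mask_sketch` is hypothetical.

```python
import jax.numpy as jnp

def causal_mask_sketch(seq_len: int) -> jnp.ndarray:
    # Hypothetical re-implementation for illustration.
    i = jnp.arange(seq_len)[:, None]  # query positions as a column
    j = jnp.arange(seq_len)[None, :]  # key positions as a row
    mask = j <= i                     # (T, T) bool, lower triangle incl. diagonal
    # Add leading batch and head axes -> (1, 1, T, T) so it broadcasts
    # against attention scores of shape (B, H, T, T).
    return mask[None, None, :, :]
```

A mask like this is typically applied with `jnp.where(mask, scores, jnp.finfo(scores.dtype).min)` before the softmax, so masked-out keys receive effectively zero weight.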
sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray
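Boolean sliding-window causal mask of shape (1, 1, T, T), where T = seq_len; True = keep. Like causal_mask, but each query attends only to a limited number of recent positions, set by window.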
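A sketch of one common sliding-window convention follows. It ASSUMES that query i may see keys j with i - window < j ≤ i (at most `window` keys, including itself); the actual sliding_window_mask may instead include `window + 1` keys, so check its source before relying on the boundary. The name `sliding_window_mask_sketch` is hypothetical.

```python
import jax.numpy as jnp

def sliding_window_mask_sketch(seq_len: int, window: int) -> jnp.ndarray:
    # Hypothetical re-implementation for illustration.
    i = jnp.arange(seq_len)[:, None]  # query positions
    j = jnp.arange(seq_len)[None, :]  # key positions
    causal = j <= i                   # no attending to future keys
    # ASSUMPTION: window counts the query itself, so keys older than
    # i - window + 1 are dropped.
    local = j > i - window
    return (causal & local)[None, None, :, :]  # (1, 1, T, T) bool
```

With `window >= seq_len` the window constraint never bites and the result equals the plain causal mask.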
cache_mask(max_length: int, cache_index: jnp.ndarray) -> jnp.ndarray
Boolean mask for the KV cache: attend only to cache positions < cache_index.
Returns a (1, 1, 1, max_length) bool mask; True = keep.
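For illustration, the cache mask reduces to a single comparison against a position range. This minimal sketch ASSUMES cache_index is a scalar integer array (shape ()) counting the filled cache slots; the name `cache_mask_sketch` is hypothetical. Because the comparison uses no Python control flow on cache_index, it also works when cache_index is a traced value under `jax.jit`, provided max_length is static.

```python
import jax.numpy as jnp

def cache_mask_sketch(max_length: int, cache_index: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical re-implementation for illustration.
    # ASSUMPTION: cache_index is a scalar; slots 0..cache_index-1 hold valid keys/values.
    positions = jnp.arange(max_length)  # (max_length,)
    mask = positions < cache_index      # True for filled slots only
    # Shape (1, 1, 1, max_length): broadcasts over batch, heads, and the
    # single query position of one decoding step.
    return mask[None, None, None, :]
```

The singleton query axis reflects incremental decoding, where one new token attends over the whole preallocated cache and the empty slots are masked out.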