masks

causal_mask(seq_len: int) -> jnp.ndarray

Boolean causal mask of shape (1, 1, T, T), where T = seq_len. True = keep.
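As an illustration of the shape and the True=keep convention described above, here is a minimal sketch. `causal_mask_sketch` is a hypothetical stand-in written for this example, not the library's implementation, and it assumes the usual definition in which query position i may attend to key position j when j <= i.

```python
import jax
import jax.numpy as jnp


def causal_mask_sketch(seq_len: int) -> jnp.ndarray:
    """Hypothetical re-implementation for illustration only."""
    idx = jnp.arange(seq_len)
    # Rows index queries, columns index keys: keep key j when j <= query i.
    mask = idx[:, None] >= idx[None, :]
    # Add leading batch and head axes so the mask broadcasts over (B, H, T, T).
    return mask[None, None, :, :]


mask = causal_mask_sketch(4)

# Typical use: disable masked-out scores before the softmax.
scores = jnp.zeros((1, 1, 4, 4))
masked_scores = jnp.where(mask, scores, -jnp.inf)
weights = jax.nn.softmax(masked_scores, axis=-1)
```

Because the two leading axes have size 1, the same mask broadcasts against attention scores of any batch size and head count.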

sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray

Sliding-window causal mask of shape (1, 1, T, T), where T = seq_len.
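The source does not say whether the window counts the current position, so the sketch below picks one common convention and should be read as an assumption: query i keeps keys j with i - window < j <= i, meaning the window includes the current token. `sliding_window_mask_sketch` is a hypothetical name for this example, and the True=keep convention is borrowed from the other masks on this page.

```python
import jax.numpy as jnp


def sliding_window_mask_sketch(seq_len: int, window: int) -> jnp.ndarray:
    """Hypothetical re-implementation; the window convention is an assumption."""
    idx = jnp.arange(seq_len)
    q = idx[:, None]  # query positions (rows)
    k = idx[None, :]  # key positions (columns)
    causal = k <= q               # never look at future positions
    in_window = k > q - window    # assumed: current token plus window - 1 previous ones
    return (causal & in_window)[None, None, :, :]


sw_mask = sliding_window_mask_sketch(seq_len=4, window=2)
```

Under this convention, setting `window >= seq_len` reduces the result to the plain causal mask.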

cache_mask(max_length: int, cache_index: jnp.ndarray) -> jnp.ndarray

Boolean mask for KV cache: attend only to positions < cache_index.

Returns (1, 1, 1, max_length) bool mask, True=keep.
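A minimal sketch of the behaviour described above, assuming `cache_index` is a scalar integer array holding the number of cache slots filled so far. `cache_mask_sketch` is a hypothetical name for this example, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def cache_mask_sketch(max_length: int, cache_index: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical re-implementation for illustration only."""
    positions = jnp.arange(max_length)
    # Keep only cache slots strictly before cache_index.
    mask = positions < cache_index
    # Shape (1, 1, 1, max_length): broadcasts over batch, heads, and one query step.
    return mask.reshape(1, 1, 1, max_length)


kv_mask = cache_mask_sketch(6, jnp.array(3))

# max_length fixes the output shape, so it is marked static under jit;
# cache_index can stay a traced value and change from step to step.
jitted = jax.jit(cache_mask_sketch, static_argnums=0)
kv_mask_jit = jitted(6, jnp.array(5))
```

Since the output shape depends only on `max_length`, the mask can be recomputed at every decoding step without triggering recompilation when `cache_index` changes.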