
scratching


ScratchingBlock

Bases: ForkingBlock

fork(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, padding_mask: Optional[jax.Array], input_seq_len: Optional[int] = None) -> Tuple[jax.Array, jax.Array, jax.Array]

Top-k forking: duplicates every token (doubling the sequence length), then keeps only the top-k tokens.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, T, C). | required |
| cumulative_scores | Array | Log-probability scores of shape (B, T). | required |
| token_index | Array | Original token indices of shape (B, T). | required |
| padding_mask | Optional[Array] | Boolean tensor of shape (B, T); True marks a valid (non-padding) token. | required |
| input_seq_len | Optional[int] | Original input sequence length, used in the ratio calculation. | None |
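The parameter table fixes a shape contract: `x` is (B, T, C), and every other array is (B, T). The helper below is a hypothetical sketch (`check_fork_inputs` is not part of this module) that asserts that contract before calling `fork`. It assumes that a `padding_mask` passed as `None` is simply skipped.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def check_fork_inputs(
    x: jax.Array,
    cumulative_scores: jax.Array,
    token_index: jax.Array,
    padding_mask: Optional[jax.Array],
) -> None:
    """Hypothetical helper: assert the (B, T, C) / (B, T) shape contract of fork()."""
    assert x.ndim == 3, "x must have shape (B, T, C)"
    B, T, _ = x.shape
    assert cumulative_scores.shape == (B, T), "cumulative_scores must be (B, T)"
    assert token_index.shape == (B, T), "token_index must be (B, T)"
    if padding_mask is not None:  # assumption: None means "no padding"
        assert padding_mask.shape == (B, T), "padding_mask must be (B, T)"
        assert padding_mask.dtype == bool, "padding_mask must be boolean"


# Usage with inputs shaped like the mocks below (B=2, T=5, C=3):
x = jnp.ones((2, 5, 3))
scores = jnp.zeros((2, 5))
idx = jnp.tile(jnp.arange(5), (2, 1))
mask = jnp.ones((2, 5), dtype=bool)
check_fork_inputs(x, scores, idx, mask)  # passes silently
```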
Mocks

```python
import jax.numpy as jnp

# `config` is the project's configuration object (not defined on this page).
x = jnp.ones((4, config.architecture.block_size, config.architecture.n_embd))  # (B=4, T=block_size, C=n_embd)
cumulative_scores = jnp.zeros((4, config.architecture.block_size))  # (B=4, T=block_size)
token_index = jnp.tile(jnp.arange(config.architecture.block_size), (4, 1))  # (B=4, T=block_size)
padding_mask = jnp.ones((4, config.architecture.block_size), dtype=bool)  # (B=4, T=block_size), all tokens valid
```

Returns:

| Type | Description |
|---|---|
| Tuple[Array, Array, Array] | Tuple of (forked_x, new_scores, new_token_indices). |
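The summary "doubles tokens then selects top-k to keep" can be illustrated with a minimal, standalone JAX sketch. This is not the `ScratchingBlock` implementation. The sketch assumes several details: each copy of a token receives its own branch log-probability (the hypothetical `branch_logp` input, shape (B, T, 2)), which is added to the parent's cumulative score; padded tokens are excluded by setting their scores to `-inf`; and a hypothetical `k` is passed in directly, whereas the real method may derive it from `input_seq_len`.

```python
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


def topk_fork_sketch(
    x: jax.Array,                       # (B, T, C)
    cumulative_scores: jax.Array,       # (B, T) log-probabilities
    token_index: jax.Array,             # (B, T) original token positions
    padding_mask: Optional[jax.Array],  # (B, T), True = valid token
    branch_logp: jax.Array,             # (B, T, 2) HYPOTHETICAL per-copy log-probs
    k: int,                             # HYPOTHETICAL number of tokens to keep
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    B, T, _ = x.shape
    # 1. Double: each token becomes two adjacent copies -> length 2T.
    x2 = jnp.repeat(x, 2, axis=1)                # (B, 2T, C)
    idx2 = jnp.repeat(token_index, 2, axis=1)    # (B, 2T)
    # Copy score = parent cumulative score + that copy's branch log-prob.
    # Reshaping (B, T, 2) -> (B, 2T) matches jnp.repeat's copy ordering.
    scores2 = (cumulative_scores[..., None] + branch_logp).reshape(B, 2 * T)
    if padding_mask is not None:
        # Padded copies can only be chosen if fewer than k valid copies exist.
        mask2 = jnp.repeat(padding_mask, 2, axis=1)
        scores2 = jnp.where(mask2, scores2, -jnp.inf)
    # 2. Select: positions of the k highest-scoring copies in each row.
    _, top_pos = jax.lax.top_k(scores2, k)       # (B, k), sorted by score
    top_pos = jnp.sort(top_pos, axis=-1)         # restore sequence order
    rows = jnp.arange(B)[:, None]                # (B, 1) for batched gather
    forked_x = x2[rows, top_pos]                 # (B, k, C)
    new_scores = scores2[rows, top_pos]          # (B, k)
    new_token_index = idx2[rows, top_pos]        # (B, k)
    return forked_x, new_scores, new_token_index


# Tiny example: B=1, T=3, C=1; the last token is padding, keep k=3.
x = jnp.arange(3.0).reshape(1, 3, 1)
cumulative_scores = jnp.zeros((1, 3))
token_index = jnp.arange(3)[None, :]
padding_mask = jnp.array([[True, True, False]])
branch_logp = jnp.log(jnp.array([[[0.9, 0.1], [0.6, 0.4], [0.5, 0.5]]]))
forked_x, new_scores, new_idx = topk_fork_sketch(
    x, cumulative_scores, token_index, padding_mask, branch_logp, k=3)
```

In this example, the kept copies come from token 0 (score log 0.9) and from both copies of token 1 (log 0.6 and log 0.4). Both copies of the padded token are excluded. Re-sorting `top_pos` is a design choice in this sketch that keeps the surviving tokens in left-to-right order; an implementation that does not need positional order could skip it.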