scratching
ScratchingBlock
Bases: ForkingBlock
fork(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, padding_mask: Optional[jax.Array], input_seq_len: Optional[int] = None) -> Tuple[jax.Array, jax.Array, jax.Array]
Top-k forking: doubles tokens then selects top-k to keep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape (B, T, C). | required |
| `cumulative_scores` | `Array` | Log-probability scores of shape (B, T). | required |
| `token_index` | `Array` | Original token indices of shape (B, T). | required |
| `padding_mask` | `Optional[Array]` | Boolean tensor of shape (B, T). True for valid. | required |
| `input_seq_len` | `Optional[int]` | Original input sequence length for ratio calc. | `None` |
Mocks
    x = jnp.ones((4, config.architecture.block_size, config.architecture.n_embd))  # (B=4, T=block_size, C=n_embd)
    cumulative_scores = jnp.zeros((4, config.architecture.block_size))  # (B=4, T=block_size)
    token_index = jnp.tile(jnp.arange(config.architecture.block_size), (4, 1))  # (B=4, T=block_size)
    padding_mask = jnp.ones((4, config.architecture.block_size), dtype=bool)  # (B=4, T=block_size)
Returns:
| Type | Description |
|---|---|
| `Tuple[Array, Array, Array]` | Tuple of (forked_x, new_scores, new_token_indices). |
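The "double, then keep top-k" idea from the summary can be illustrated with a minimal JAX sketch. This is not the `ScratchingBlock.fork` implementation: the function name `topk_fork_sketch`, the per-token `copy_logp` term that scores the duplicated copy, and the explicit `k` argument are all assumptions made for illustration. The source does not say how the copies are scored or how `k` is derived from `input_seq_len`.

```python
import jax
import jax.numpy as jnp


def topk_fork_sketch(x, cumulative_scores, token_index, copy_logp, k, padding_mask=None):
    """Hypothetical top-k fork: duplicate every token, then keep the k best per row.

    x: (B, T, C), cumulative_scores: (B, T), token_index: (B, T).
    copy_logp: (B, T) assumed log-probability added to the score of each copy.
    """
    B, T = cumulative_scores.shape

    # Double along T so that token t occupies candidate slots 2t (original) and 2t+1 (copy).
    doubled_x = jnp.repeat(x, 2, axis=1)                      # (B, 2T, C)
    doubled_index = jnp.repeat(token_index, 2, axis=1)        # (B, 2T)
    doubled_scores = jnp.stack(
        [cumulative_scores, cumulative_scores + copy_logp], axis=-1
    ).reshape(B, 2 * T)                                       # (B, 2T)

    if padding_mask is not None:
        # Padding positions (mask == False) should never be selected.
        doubled_mask = jnp.repeat(padding_mask, 2, axis=1)
        doubled_scores = jnp.where(doubled_mask, doubled_scores, -jnp.inf)

    # Select the k highest-scoring candidates, then sort the kept slots
    # so the surviving tokens stay in sequence order.
    _, keep = jax.lax.top_k(doubled_scores, k)                # (B, k)
    keep = jnp.sort(keep, axis=-1)

    forked_x = jnp.take_along_axis(doubled_x, keep[..., None], axis=1)    # (B, k, C)
    new_scores = jnp.take_along_axis(doubled_scores, keep, axis=1)        # (B, k)
    new_token_indices = jnp.take_along_axis(doubled_index, keep, axis=1)  # (B, k)
    return forked_x, new_scores, new_token_indices
```

In this sketch, choosing `k` equal to `T` keeps the sequence length fixed, so a token survives twice only when its copy outscores some other token's original.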