block
ThoughtBlock
Bases: Block
Transformer block for Thoughtbubbles with support for cumulative scores. Extends Block by weighting the attention and MLP outputs of each token by its cumulative score.
__call__(x: jax.Array, **kwargs: Any) -> Union[jax.Array, ForkingOutput]
Forward pass with cumulative score weighting.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape (B, T, C). | required |
| `**kwargs` | `Any` | Must include `cumulative_scores` and `token_index` in forking mode. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Union[Array, ForkingOutput]` | If forking: tuple of (output, cumulative_scores, token_index). |
| `Union[Array, ForkingOutput]` | If not forking: just the output tensor. |
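A minimal sketch of the score-weighted forward pass, under stated assumptions: cumulative scores are log-probabilities (as the `fork` parameters describe them) that are exponentiated into per-token weights in [0, 1], and non-forking mode is signalled by `cumulative_scores` being absent. `thought_block_call`, `attn_fn` and `mlp_fn` are hypothetical stand-ins rather than the module's API; normalization, dropout and the real attention computation are omitted.

```python
from typing import Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp

# Hypothetical stand-in type for a sublayer (attention or MLP).
Sublayer = Callable[[jax.Array], jax.Array]


def thought_block_call(
    x: jax.Array,
    attn_fn: Sublayer,
    mlp_fn: Sublayer,
    cumulative_scores: Optional[jax.Array] = None,
    token_index: Optional[jax.Array] = None,
) -> Union[jax.Array, Tuple[jax.Array, jax.Array, jax.Array]]:
    """Hypothetical sketch: x is (B, T, C); cumulative_scores is (B, T) log-probs."""
    if cumulative_scores is None:
        # Non-forking mode (assumed): ordinary residual block, output only.
        x = x + attn_fn(x)
        return x + mlp_fn(x)
    # Forking mode: turn log-probabilities into weights in [0, 1] and broadcast
    # them over the channel dimension so each token's updates are scaled.
    weight = jnp.exp(cumulative_scores)[..., None]
    x = x + weight * attn_fn(x)
    x = x + weight * mlp_fn(x)
    return x, cumulative_scores, token_index
```

In this sketch the weight scales each update rather than the residual stream itself, so a token whose cumulative probability approaches zero adds almost nothing new but still carries its existing representation forward.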
ForkingBlock
Bases: ThoughtBlock
Transformer block with forking capability. Extends ThoughtBlock by adding token forking based on learned scores.
clipped_logsigmoid(x: jax.Array, min_val: float = -20.0) -> jax.Array
staticmethod
Compute log-sigmoid with clipping for numerical stability.
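A one-line sketch, assuming the clip is a floor applied to the log-sigmoid output (clipping the input at `min_val` first would behave almost identically, since log σ(-20) ≈ -20). It relies on the real `jax.nn.log_sigmoid`, which already avoids overflow for large |x|; in this sketch the floor bounds how much a single step can lower a cumulative score.

```python
import jax
import jax.numpy as jnp


def clipped_logsigmoid(x: jax.Array, min_val: float = -20.0) -> jax.Array:
    # jax.nn.log_sigmoid computes log(sigmoid(x)) without overflow; the result
    # is then floored at min_val (assumed to be an output clip).
    return jnp.maximum(jax.nn.log_sigmoid(x), min_val)


vals = clipped_logsigmoid(jnp.array([-100.0, 0.0, 100.0]))
# vals is approximately [-20.0, -0.6931, 0.0]
```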
fork(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, padding_mask: Optional[jax.Array], input_seq_len: Optional[int] = None) -> Tuple[jax.Array, jax.Array, jax.Array]
Top-k forking: duplicates every token, then selects the top-k tokens to keep.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape (B, T, C). | required |
| `cumulative_scores` | `Array` | Log-probability scores of shape (B, T). | required |
| `token_index` | `Array` | Original token indices of shape (B, T). | required |
| `padding_mask` | `Optional[Array]` | Boolean tensor of shape (B, T); True marks valid (non-padding) positions. | required |
| `input_seq_len` | `Optional[int]` | Original input sequence length, used for the ratio calculation. | `None` |
Returns:
| Type | Description |
|---|---|
| `Tuple[Array, Array, Array]` | Tuple of (forked_x, new_scores, new_token_indices). |
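The doubling-then-selection step can be sketched as follows. Several details are assumptions, not taken from the source: per-token `fork_logits` are passed in directly (in the block they would come from the learned scores), the original copy keeps its cumulative score while the forked copy adds a clipped log-sigmoid of its logit, `k` is given directly rather than derived from `input_seq_len`, and survivors are re-sorted into sequence order. `fork_topk` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def fork_topk(x, cumulative_scores, token_index, fork_logits, padding_mask, k):
    """Hypothetical top-k forking sketch (not the library's implementation).

    x: (B, T, C); cumulative_scores, token_index, fork_logits: (B, T);
    padding_mask: (B, T) bool, True for valid; k: streams to keep (k <= 2T).
    """
    B, T, C = x.shape
    # Assumed scoring: the original copy keeps its score, the forked copy adds
    # a clipped log-sigmoid of its fork logit (a log-probability <= 0).
    fork_scores = cumulative_scores + jnp.maximum(jax.nn.log_sigmoid(fork_logits), -20.0)
    # Double the streams, interleaving each fork right after its parent: (B, 2T).
    doubled_x = jnp.repeat(x, 2, axis=1)
    doubled_scores = jnp.stack([cumulative_scores, fork_scores], axis=2).reshape(B, 2 * T)
    doubled_index = jnp.repeat(token_index, 2, axis=1)
    doubled_mask = jnp.repeat(padding_mask, 2, axis=1)
    # Padded streams get -inf so top_k skips them while valid ones remain.
    masked = jnp.where(doubled_mask, doubled_scores, -jnp.inf)
    _, keep = jax.lax.top_k(masked, k)
    # Sort survivors so they stay in their original sequence order.
    keep = jnp.sort(keep, axis=1)
    gather_x = jnp.broadcast_to(keep[..., None], (B, k, C))
    new_x = jnp.take_along_axis(doubled_x, gather_x, axis=1)
    new_scores = jnp.take_along_axis(doubled_scores, keep, axis=1)
    new_index = jnp.take_along_axis(doubled_index, keep, axis=1)
    return new_x, new_scores, new_index
```

Under `jax.jit`, `k` must be a static Python int, because `jax.lax.top_k` fixes the output shape.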
__call__(x: jax.Array, **kwargs: Any) -> Union[jax.Array, ForkingOutput]
Forward pass: fork first, then apply standard block operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape (B, T, C). | required |
| `**kwargs` | `Any` | Keyword arguments: `cumulative_scores`, `token_index`, `padding_mask`, `deterministic`, `input_seq_len`. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Union[Array, ForkingOutput]` | Tuple of (output, cumulative_scores, token_index). |
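The fork-then-block ordering, and how the returned (output, cumulative_scores, token_index) tuple feeds the next block, can be sketched with pluggable functions. `forking_block_call`, `run_stack`, `fork_fn` and `block_fn` are hypothetical; initializing scores to zero (probability 1) and token indices to 0..T-1 before the first block is an assumption, and `padding_mask` and `deterministic` are omitted for brevity.

```python
from typing import Callable, Sequence, Tuple

import jax
import jax.numpy as jnp

# Hypothetical alias mirroring the documented return: (output, scores, indices).
ForkingOutput = Tuple[jax.Array, jax.Array, jax.Array]
ForkFn = Callable[[jax.Array, jax.Array, jax.Array], ForkingOutput]
BlockFn = Callable[[jax.Array, jax.Array], jax.Array]


def forking_block_call(x, cumulative_scores, token_index,
                       fork_fn: ForkFn, block_fn: BlockFn) -> ForkingOutput:
    # 1. Fork first: the number of streams may change from T to T'.
    x, cumulative_scores, token_index = fork_fn(x, cumulative_scores, token_index)
    # 2. Then apply the score-weighted block operations to the forked streams.
    x = block_fn(x, cumulative_scores)
    return x, cumulative_scores, token_index


def run_stack(x: jax.Array, fork_fn: ForkFn,
              block_fns: Sequence[BlockFn]) -> ForkingOutput:
    B, T, _ = x.shape
    # Assumed initial state: log-probability 0 and identity token indices.
    scores = jnp.zeros((B, T))
    index = jnp.broadcast_to(jnp.arange(T), (B, T))
    for block_fn in block_fns:
        # Each block's outputs become the next block's inputs.
        x, scores, index = forking_block_call(x, scores, index, fork_fn, block_fn)
    return x, scores, index
```

For example, a `fork_fn` that duplicates every stream and halves its probability turns T = 2 streams into 8 after two blocks, with `token_index` still pointing each stream back to its source token.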