block

ThoughtBlock

Bases: Block

Transformer block for Thoughtbubbles with cumulative score support. Extends Block by weighting attention and MLP outputs by cumulative scores.

__call__(x: jax.Array, **kwargs: Any) -> Union[jax.Array, ForkingOutput]

Forward pass with cumulative score weighting.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor of shape (B, T, C). | required |
| **kwargs | Any | In forking mode, must include cumulative_scores and token_index. | {} |

Returns:

| Type | Description |
| --- | --- |
| Union[Array, ForkingOutput] | If forking: Tuple of (output, cumulative_scores, token_index). |
| Union[Array, ForkingOutput] | If not forking: Just the output tensor. |
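The cumulative score weighting described above can be pictured with a minimal JAX sketch. It is not the library's implementation: it assumes the scores are log-probabilities that are exponentiated into weights, that the weights scale each residual branch, and it leaves out layer normalization. `weighted_block`, `attn_fn`, and `mlp_fn` are hypothetical names used only for illustration.

```python
import jax.numpy as jnp


def weighted_block(x, cumulative_scores, attn_fn, mlp_fn):
    """Hypothetical sketch: scale each residual update by exp(cumulative_scores)."""
    # Assumption: scores are log-probabilities, so exp() maps them into (0, 1].
    w = jnp.exp(cumulative_scores)[..., None]  # (B, T, 1), broadcasts over C
    x = x + w * attn_fn(x)  # weighted attention branch
    x = x + w * mlp_fn(x)   # weighted MLP branch
    return x


# Identity functions stand in for the attention and MLP sublayers.
x = jnp.ones((1, 2, 3))
scores = jnp.log(jnp.array([[1.0, 0.5]]))
out = weighted_block(x, scores, attn_fn=lambda h: h, mlp_fn=lambda h: h)
```

Under these assumptions, a token with weight 1.0 receives the full residual updates, while a token with weight 0.5 receives half of each update.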

ForkingBlock

Bases: ThoughtBlock

Transformer block with forking capability. Extends ThoughtBlock by adding token forking based on learned scores.

clipped_logsigmoid(x: jax.Array, min_val: float = -20.0) -> jax.Array [staticmethod]

Compute log-sigmoid with clipping for numerical stability.
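One plausible reading of this helper is sketched below. It assumes the clipping is applied to the output of the log-sigmoid, as a lower bound at `min_val`; the source does not say whether the library clips the input or the output, so treat this as an illustration only.

```python
import jax
import jax.numpy as jnp


def clipped_logsigmoid(x: jax.Array, min_val: float = -20.0) -> jax.Array:
    # log(sigmoid(x)) is always <= 0 and tends to -inf as x -> -inf.
    # Assumption: clamp from below so sums of log-scores stay finite.
    return jnp.maximum(jax.nn.log_sigmoid(x), min_val)


logits = jnp.array([-100.0, 0.0, 100.0])
clipped = clipped_logsigmoid(logits)
```

With this reading, very negative logits saturate at `min_val` instead of producing arbitrarily large negative log-probabilities.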

fork(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, padding_mask: Optional[jax.Array], input_seq_len: Optional[int] = None) -> Tuple[jax.Array, jax.Array, jax.Array]

Top-k forking: doubles tokens then selects top-k to keep.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor of shape (B, T, C). | required |
| cumulative_scores | Array | Log-probability scores of shape (B, T). | required |
| token_index | Array | Original token indices of shape (B, T). | required |
| padding_mask | Optional[Array] | Boolean tensor of shape (B, T). True marks valid positions. | required |
| input_seq_len | Optional[int] | Original input sequence length, used for the ratio calculation. | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Array, Array, Array] | Tuple of (forked_x, new_scores, new_token_indices). |
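The "double, then keep top-k" idea can be sketched in JAX as follows. This is an illustration under stated assumptions, not the library's `fork`: `fork_sketch`, `keep_logp`, `fork_logp`, and the explicit budget `k` are hypothetical, padding is ignored, and the source does not show how the real budget is derived from `input_seq_len`.

```python
import jax
import jax.numpy as jnp


def fork_sketch(x, cumulative_scores, token_index, keep_logp, fork_logp, k):
    """Hypothetical top-k forking: duplicate every token, keep the k best."""
    # Double the candidates: each token appears once as itself, once as a fork.
    cand_x = jnp.concatenate([x, x], axis=1)                          # (B, 2T, C)
    cand_scores = jnp.concatenate(
        [cumulative_scores + keep_logp, cumulative_scores + fork_logp], axis=1
    )                                                                 # (B, 2T)
    cand_index = jnp.concatenate([token_index, token_index], axis=1)  # (B, 2T)

    # Keep the k highest-scoring candidates in each batch row.
    _, top = jax.lax.top_k(cand_scores, k)                            # (B, k)

    # Reorder the survivors by original token index to restore sequence order.
    picked = jnp.take_along_axis(cand_index, top, axis=1)
    top = jnp.take_along_axis(top, jnp.argsort(picked, axis=1), axis=1)

    new_x = jnp.take_along_axis(cand_x, top[..., None], axis=1)       # (B, k, C)
    new_scores = jnp.take_along_axis(cand_scores, top, axis=1)
    new_index = jnp.take_along_axis(cand_index, top, axis=1)
    return new_x, new_scores, new_index


x = jnp.arange(6.0).reshape(1, 3, 2)
scores = jnp.zeros((1, 3))
index = jnp.array([[0, 1, 2]])
keep_logp = jnp.array([[-0.1, -0.1, -0.1]])
fork_logp = jnp.array([[-3.0, -0.5, -4.0]])
new_x, new_scores, new_index = fork_sketch(x, scores, index, keep_logp, fork_logp, k=4)
```

In this toy input only the second token has a fork score high enough to survive, so its index appears twice in `new_index` while the other tokens appear once.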

__call__(x: jax.Array, **kwargs: Any) -> Union[jax.Array, ForkingOutput]

Forward pass: fork first, then apply standard block operations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor of shape (B, T, C). | required |
| **kwargs | Any | Keyword arguments: cumulative_scores, token_index, padding_mask, deterministic, input_seq_len. | {} |

Returns:

| Type | Description |
| --- | --- |
| Union[Array, ForkingOutput] | Tuple of (output, cumulative_scores, token_index). |