models
GPT
Bases: Module
embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any
Compute token and positional embeddings given inputs.
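A minimal sketch of what a token-plus-position embedding step can look like, assuming learned lookup tables for both. The names `wte` and `wpe` and the sizes are illustrative assumptions, not the module's actual parameters.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes; the real values come from the model config.
vocab_size, block_size, n_embd = 11, 8, 4

k_tok, k_pos = jax.random.split(jax.random.PRNGKey(0))
wte = jax.random.normal(k_tok, (vocab_size, n_embd)) * 0.02  # token table (assumed name)
wpe = jax.random.normal(k_pos, (block_size, n_embd)) * 0.02  # learned positional table (assumed)

idx = jnp.array([[1, 4, 2], [3, 3, 0]])   # token indices, shape (B, T)
tok = wte[idx]                             # (B, T, n_embd)
pos = wpe[jnp.arange(idx.shape[1])]        # (T, n_embd), broadcast over the batch
x = tok + pos                              # embeddings handed to the decoder stack
```

The same token at two different positions differs only by the positional rows, which is what lets later layers distinguish order.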
decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any
Compute decoded residual channels given embeddings.
unembed(x: jax.Array) -> Any
Compute output distribution.
loss(logits: jax.Array, targets: jax.Array) -> jax.Array
Compute cross-entropy loss given logits and targets.
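As a rough illustration, the following sketch computes a cross-entropy that skips target positions marked -1. Averaging over the remaining valid positions is an assumption, and `masked_cross_entropy` is a hypothetical helper rather than the method's actual body.

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits: jax.Array, targets: jax.Array) -> jax.Array:
    """Mean cross-entropy over positions whose target is not -1 (assumed convention)."""
    valid = targets != -1                              # (B, T) boolean mask
    safe_targets = jnp.where(valid, targets, 0)        # keep gather indices in range
    log_probs = jax.nn.log_softmax(logits, axis=-1)    # (B, T, V)
    picked = jnp.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    nll = -picked * valid                              # zero out ignored positions
    return nll.sum() / jnp.maximum(valid.sum(), 1)     # avoid dividing by zero


logits = jnp.zeros((2, 3, 5))                          # uniform prediction over 5 tokens
targets = jnp.array([[1, 2, -1], [0, -1, -1]])
loss_value = masked_cross_entropy(logits, targets)
```

With uniform logits every valid position contributes log(5), so the ignored positions do not change the mean.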
__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | Array | Input token indices of shape (B, T). | required |
| targets | Optional[Array] | Target token indices of shape (B, T). Use -1 to ignore positions. | None |
| padding_mask | Optional[Array] | Boolean tensor of shape (B, T). True for valid tokens, False for padding tokens. | None |
| deterministic | bool | If False, applies dropout. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| logits | Array | Output logits of shape (B, T, vocab_size). |
| loss | Optional[Array] | Cross-entropy loss if targets provided, else None. |
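The sketch below shows one way to prepare `idx`, `targets`, and `padding_mask` for a padded batch. It assumes the caller supplies targets already shifted by one position; `PAD_ID` and `lengths` are hypothetical names introduced only for this example.

```python
import jax.numpy as jnp

PAD_ID = 0  # hypothetical padding token id, not taken from the source

# Two sequences of true lengths 4 and 2, padded to T = 4.
idx = jnp.array([[5, 7, 9, 3],
                 [4, 8, PAD_ID, PAD_ID]])
lengths = jnp.array([4, 2])
positions = jnp.arange(idx.shape[1])[None, :]          # (1, T)

# True for valid tokens, False for padding tokens, shape (B, T).
padding_mask = positions < lengths[:, None]

# Next-token targets: shift left by one, then mark positions without a valid target as -1.
shifted = jnp.concatenate([idx[:, 1:], jnp.full((idx.shape[0], 1), -1)], axis=1)
has_target = positions < (lengths[:, None] - 1)
targets = jnp.where(has_target, shifted, -1)
```

These three arrays match the shapes documented for the `idx`, `targets`, and `padding_mask` parameters.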
Hybrid
Bases: GPT
Hybrid Transformer + Mamba language model.
mamba_layers controls which layers use Mamba blocks:
- "even": even-indexed layers (0, 2, 4, ...) are Mamba
- "odd": odd-indexed layers (1, 3, 5, ...) are Mamba
- comma-separated indices: e.g. "0,2,4,6" for explicit control
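The three accepted forms of `mamba_layers` can be sketched as a small parser. `parse_mamba_layers` and its `n_layer` argument are hypothetical names for illustration; the library's actual parsing and validation may differ.

```python
from typing import Set


def parse_mamba_layers(spec: str, n_layer: int) -> Set[int]:
    """Return the layer indices that would use Mamba blocks (hypothetical helper)."""
    spec = spec.strip()
    if spec == "even":
        return set(range(0, n_layer, 2))
    if spec == "odd":
        return set(range(1, n_layer, 2))
    # Otherwise treat the spec as comma-separated explicit indices.
    indices = {int(part) for part in spec.split(",") if part.strip()}
    out_of_range = sorted(i for i in indices if not 0 <= i < n_layer)
    if out_of_range:
        raise ValueError(f"layer indices out of range for n_layer={n_layer}: {out_of_range}")
    return indices
```

For example, `parse_mamba_layers("0,2,4,6", 8)` selects the same layers as `"even"` does for an 8-layer model; the remaining layers would be Transformer blocks.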
LaCT
Mamba
Bases: GPT
Mamba-2 language model — SSM-only, no attention.
Overrides GPT's setup/embed/decode to use MambaBlock layers
and skip positional embeddings. loss and unembed are
inherited unchanged.
embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any
Token embeddings only — no positional encoding needed for SSMs.
Scratchbubbles
Bases: Thoughtbubbles
left_attn_average(input_seq_len: int, x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array) -> jax.Array
Averaging in which the rightmost token of each idx acts as the query in a small masked attention restricted to its left children.
unembed_forked(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, input_seq_len: int) -> jax.Array
Compute output logits for forked tokens using residual averaging.
Example inputs:
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (job.args.per_device_batch_size, job.args.block_size * 2, job.model.n_embd))
cumulative_scores = jax.random.uniform(key, (job.args.per_device_batch_size, job.args.block_size * 2))
token_index = jnp.arange(job.args.block_size)[None, :].repeat(2, axis=1).repeat(job.args.per_device_batch_size, axis=0)
input_seq_len = job.args.block_size
plot(intermediates: Any) -> Dict[str, Any]
staticmethod
Build figures from the given intermediates (intermediates -> [figure]).
Thoughtbubbles
Bases: GPT
Thoughtbubbles model with token forking support. Extends GPT with the ability to fork tokens during processing.
plot(intermediates: Any) -> Dict[str, Any]
staticmethod
Build figures from the given intermediates (intermediates -> [figure]).
embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, jax.Array, jax.Array]
Compute token embeddings and initialize forking state. Overrides GPT.embed to return forking state.
decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any
Process through transformer blocks with forking. Overrides GPT.decode to handle forking state.
residual_average(input_size: int, residuals: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array) -> jax.Array
Average residuals weighted by cumulative scores.
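One plausible reading of score-weighted residual averaging is sketched below: forked residuals that share a `token_index` are combined into a single vector per original position, weighted by their cumulative scores. The function name, the `eps` guard, and the exact normalization are assumptions, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def weighted_group_average(input_size, residuals, cumulative_scores, token_index, eps=1e-8):
    """Score-weighted mean of forked residuals per original position (illustrative).

    residuals: (B, F, C); cumulative_scores: (B, F);
    token_index: (B, F) integers in [0, input_size). Returns (B, input_size, C).
    """
    one_hot = jax.nn.one_hot(token_index, input_size, dtype=residuals.dtype)  # (B, F, T)
    weights = one_hot * cumulative_scores[..., None]                          # (B, F, T)
    numer = jnp.einsum("bft,bfc->btc", weights, residuals)                    # (B, T, C)
    denom = weights.sum(axis=1)[..., None]                                    # (B, T, 1)
    return numer / jnp.maximum(denom, eps)                                    # guard empty groups


residuals = jnp.array([[[1.0], [3.0], [10.0]]])   # three forked tokens, C = 1
scores = jnp.array([[1.0, 3.0, 2.0]])
token_index = jnp.array([[0, 0, 1]])              # the first two forks map to position 0
averaged = weighted_group_average(2, residuals, scores, token_index)
```

Here position 0 receives (1·1 + 3·3) / (1 + 3) = 2.5 and position 1 keeps its single residual, 10.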
unembed(x: jax.Array) -> jax.Array
Standard unembed - used when no forking state provided.
unembed_forked(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, input_seq_len: int) -> jax.Array
Compute output logits for forked tokens using residual averaging.
__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]
Forward pass through Thoughtbubbles.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | Array | Input token indices of shape (B, T). | required |
| targets | Optional[Array] | Target token indices of shape (B, T). Use -1 to ignore positions. | None |
| padding_mask | Optional[Array] | Boolean tensor of shape (B, T). True for valid tokens, False for padding tokens. | None |
| deterministic | bool | If False, applies dropout. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| logits | Array | Output logits of shape (B, T, vocab_size). |
| loss | Optional[Array] | Cross-entropy loss if targets provided, else None. |