
models


GPT

Bases: Module

embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any

Compute token and positional embeddings given inputs.
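A minimal sketch of what a token-plus-position embedding step computes. It assumes a GPT-2-style learned absolute position table; the names `embed_sketch`, `wte` and `wpe` are hypothetical, and dropout (controlled by `deterministic` in the real method) is omitted.

```python
import jax
import jax.numpy as jnp

def embed_sketch(idx, wte, wpe):
    """Token + learned absolute positional embeddings (hypothetical helper).

    idx: (B, T) int token ids; wte: (vocab_size, n_embd); wpe: (block_size, n_embd).
    """
    T = idx.shape[1]
    tok = wte[idx]                # (B, T, n_embd): one row of wte per token id
    pos = wpe[jnp.arange(T)]      # (T, n_embd): one row per position 0..T-1
    return tok + pos[None, :, :]  # broadcast the position rows over the batch

key_tok, key_pos = jax.random.split(jax.random.PRNGKey(0))
wte = jax.random.normal(key_tok, (100, 16))
wpe = jax.random.normal(key_pos, (32, 16))
idx = jnp.array([[1, 2, 3], [4, 5, 6]])
x = embed_sketch(idx, wte, wpe)
```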

decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any

Compute decoded residual channels given embeddings.

unembed(x: jax.Array) -> Any

Compute output distribution.

loss(logits: jax.Array, targets: jax.Array) -> jax.Array

Compute cross-entropy loss given logits and targets.
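The `__call__` documentation states that target positions set to -1 are ignored. A minimal sketch of a masked mean cross-entropy honoring that convention; `masked_cross_entropy` is a hypothetical name, and averaging over valid tokens (rather than summing) is an assumption.

```python
import jax
import jax.numpy as jnp

def masked_cross_entropy(logits, targets):
    """Mean token-level cross-entropy, skipping positions where targets == -1."""
    valid = targets != -1                             # (B, T) bool
    safe_targets = jnp.where(valid, targets, 0)       # -1 is not a valid class index
    log_probs = jax.nn.log_softmax(logits, axis=-1)   # (B, T, V)
    nll = -jnp.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    nll = jnp.where(valid, nll, 0.0)                  # zero out ignored positions
    return nll.sum() / jnp.maximum(valid.sum(), 1)    # avoid dividing by zero

logits = jnp.zeros((1, 3, 4))   # uniform distribution over 4 classes
loss = masked_cross_entropy(logits, jnp.array([[0, -1, 2]]))
```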

__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]

Parameters:

- idx (Array, required): Input token indices of shape (B, T).
- targets (Optional[Array], default None): Target token indices of shape (B, T). Use -1 to ignore positions.
- padding_mask (Optional[Array], default None): Boolean tensor of shape (B, T); True for valid tokens, False for padding tokens.
- deterministic (bool, default False): If False, applies dropout.

Returns:

- logits (Array): Output logits of shape (B, T, vocab_size).
- loss (Optional[Array]): Cross-entropy loss if targets are provided, else None.
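How `padding_mask` enters attention is not documented here. One common pattern, sketched under that assumption, is to AND a causal mask with a key-side padding mask so no query attends to a padded key; `attention_mask_sketch` is a hypothetical helper.

```python
import jax.numpy as jnp

def attention_mask_sketch(padding_mask):
    """(B, T) bool -> (B, 1, T, T) bool; True where query i may attend to key j."""
    T = padding_mask.shape[1]
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))  # key j <= query i
    key_ok = padding_mask[:, None, None, :]          # drop padded keys, every query
    return causal[None, None, :, :] & key_ok

pm = jnp.array([[True, True, False]])   # last position is padding
m = attention_mask_sketch(pm)
```

The extra singleton axis lets the mask broadcast over attention heads.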

Hybrid

Bases: GPT

Hybrid Transformer + Mamba language model.

mamba_layers controls which layers use Mamba blocks:

- "even": even-indexed layers (0, 2, 4, ...) are Mamba
- "odd": odd-indexed layers (1, 3, 5, ...) are Mamba
- comma-separated indices, e.g. "0,2,4,6", for explicit control
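A minimal sketch of how a `mamba_layers` string could be resolved into layer indices. The function name `parse_mamba_layers`, the range check, and the assumption that every other layer is a Transformer block are illustrative, not taken from the implementation.

```python
def parse_mamba_layers(spec: str, n_layer: int) -> set[int]:
    """Resolve a mamba_layers spec into the set of Mamba layer indices (hypothetical)."""
    spec = spec.strip()
    if spec == "even":
        return set(range(0, n_layer, 2))
    if spec == "odd":
        return set(range(1, n_layer, 2))
    # Otherwise treat it as explicit comma-separated indices, e.g. "0,2,4,6".
    indices = {int(s) for s in spec.split(",") if s.strip()}
    bad = sorted(i for i in indices if not 0 <= i < n_layer)
    if bad:
        raise ValueError(f"layer indices out of range: {bad}")
    return indices

mamba = parse_mamba_layers("odd", 4)
layout = ["mamba" if i in mamba else "attention" for i in range(4)]
```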

LaCT

Bases: GPT

LaCT language model: every layer combines sliding-window attention (SWA) with chunked test-time training (TTT).
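A minimal sketch of the mask behind sliding-window attention: each query sees itself and the previous `window - 1` keys. Whether LaCT's window length counts the current token is an assumption here, as is the helper name `sliding_window_causal_mask`.

```python
import jax.numpy as jnp

def sliding_window_causal_mask(T: int, window: int):
    """(T, T) bool mask: query i attends to keys j with i - window < j <= i."""
    i = jnp.arange(T)[:, None]   # query positions as a column
    j = jnp.arange(T)[None, :]   # key positions as a row
    return (j <= i) & (j > i - window)

m = sliding_window_causal_mask(5, 2)
```

Unlike a full causal mask, the number of allowed keys per query is capped at `window`, so attention cost grows linearly in T.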

embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any

Token embeddings only; RoPE handles positions inside each block.
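Because positions are injected by rotary embeddings inside each block, `embed` can skip a position table. A minimal sketch of RoPE in the "rotate-half" layout with the usual base of 10000; the layout, base, and `apply_rope` name are assumptions, not this model's confirmed settings.

```python
import jax.numpy as jnp

def apply_rope(x, base: float = 10000.0):
    """Rotary position embedding on x of shape (B, T, H, D) with even D.

    Dimension k in [0, D/2) is rotated together with dimension k + D/2.
    """
    T, D = x.shape[1], x.shape[-1]
    half = D // 2
    inv_freq = 1.0 / (base ** (jnp.arange(half) / half))  # (D/2,) per-pair frequencies
    angles = jnp.arange(T)[:, None] * inv_freq[None, :]     # (T, D/2) angle per position
    cos = jnp.cos(angles)[None, :, None, :]                 # broadcast over B and H
    sin = jnp.sin(angles)[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

x = jnp.ones((1, 4, 2, 8))
y = apply_rope(x)
```

Each pair is rotated rather than shifted, so vector norms are preserved and query/key dot products depend only on relative offsets.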

Mamba

Bases: GPT

Mamba-2 language model: SSM-only, no attention.

Overrides GPT's setup, embed, and decode to use MambaBlock layers and skip positional embeddings; loss and unembed are inherited unchanged.

embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any

Token embeddings only; SSMs need no positional encoding.
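SSMs need no positional encoding because order is built into the recurrence: each step decays the previous state before adding the new input. A minimal single-head sketch with a scalar per-step decay in the spirit of Mamba-2; it omits discretization, convolution, gating, and the chunked kernel a real MambaBlock would use, and `ssm_scan` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def ssm_scan(x, a, B, C):
    """Sequential selective-SSM recurrence for one head (simplified sketch).

    x: (T, P) inputs, a: (T,) per-step decay, B: (T, N), C: (T, N).
    h_t = a_t * h_{t-1} + outer(B_t, x_t);  y_t = C_t @ h_t  ->  y: (T, P)
    """
    def step(h, inputs):
        a_t, B_t, C_t, x_t = inputs
        h = a_t * h + jnp.outer(B_t, x_t)   # (N, P) state summarizing the prefix
        return h, C_t @ h                   # emit y_t of shape (P,)

    h0 = jnp.zeros((B.shape[1], x.shape[1]), dtype=x.dtype)
    _, y = jax.lax.scan(step, h0, (a, B, C, x))
    return y

y = ssm_scan(jnp.ones((3, 2)), jnp.full((3,), 0.5), jnp.ones((3, 1)), jnp.ones((3, 1)))
```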

Scratchbubbles

Bases: Thoughtbubbles

left_attn_average(input_seq_len: int, x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array) -> jax.Array

Averaging in which the rightmost token for each input index acts as the query of a small masked attention restricted to that index's left children.

unembed_forked(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, input_seq_len: int) -> jax.Array

Compute output logits for forked tokens using residual averaging.

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    x = jax.random.uniform(key, (job.args.per_device_batch_size, job.args.block_size * 2, job.model.n_embd))  # forked residuals
    cumulative_scores = jax.random.uniform(key, (job.args.per_device_batch_size, job.args.block_size * 2))
    token_index = jnp.arange(job.args.block_size)[None, :].repeat(2, axis=1).repeat(job.args.per_device_batch_size, axis=0)  # [0, 0, 1, 1, ...] per row
    input_seq_len = job.args.block_size

plot(intermediates: Any) -> Dict[str, Any]  (staticmethod)

Builds figures from captured intermediates (intermediates -> [figure]).

Thoughtbubbles

Bases: GPT

Thoughtbubbles model: extends GPT with the ability to fork tokens during processing.

plot(intermediates: Any) -> Dict[str, Any]  (staticmethod)

Builds figures from captured intermediates (intermediates -> [figure]).

embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, jax.Array, jax.Array]

Compute token embeddings and initialize forking state. Overrides GPT.embed to return forking state.

decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any

Process through transformer blocks with forking. Overrides GPT.decode to handle forking state.

residual_average(input_size: int, residuals: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array) -> jax.Array

Average residuals weighted by cumulative scores.
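A minimal sketch of a score-weighted average that collapses forked positions back onto input positions. It assumes residuals of shape (B, N, D), non-negative linear-space scores of shape (B, N), and `token_index` of shape (B, N) mapping each forked position to an input position in [0, input_size); these shapes are inferred from the Scratchbubbles example inputs, and `residual_average_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp

def residual_average_sketch(input_size, residuals, cumulative_scores, token_index):
    """Per input position, weighted mean of residuals that share its token_index."""
    def one(res, w, idx):
        num = jax.ops.segment_sum(res * w[:, None], idx, num_segments=input_size)
        den = jax.ops.segment_sum(w, idx, num_segments=input_size)
        return num / jnp.maximum(den, 1e-9)[:, None]   # guard empty segments
    return jax.vmap(one)(residuals, cumulative_scores, token_index)

residuals = jnp.array([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0], [7.0, 7.0]]])
scores = jnp.array([[1.0, 1.0, 1.0, 3.0]])
token_index = jnp.array([[0, 0, 1, 1]])
out = residual_average_sketch(2, residuals, scores, token_index)
```

`jax.ops.segment_sum` needs a static `num_segments`, which is why `input_size` is a plain int.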

unembed(x: jax.Array) -> jax.Array

Standard unembedding, used when no forking state is provided.

unembed_forked(x: jax.Array, cumulative_scores: jax.Array, token_index: jax.Array, input_seq_len: int) -> jax.Array

Compute output logits for forked tokens using residual averaging.

__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]

Forward pass through Thoughtbubbles.

Parameters:

- idx (Array, required): Input token indices of shape (B, T).
- targets (Optional[Array], default None): Target token indices of shape (B, T). Use -1 to ignore positions.
- padding_mask (Optional[Array], default None): Boolean tensor of shape (B, T); True for valid tokens, False for padding tokens.
- deterministic (bool, default False): If False, applies dropout.

Returns:

- logits (Array): Output logits of shape (B, T, vocab_size).
- loss (Optional[Array]): Cross-entropy loss if targets are provided, else None.

Marin

Bases: Module

Marin model (marin-community/marin-8b-base).

Architecturally identical to Llama 3: the same decoder blocks, GQA, SiLU MLP, RMSNorm, and RoPE. Defaults reflect the Marin-8B configuration.
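Grouped-query attention (GQA) is the piece that most distinguishes these blocks from plain multi-head attention: several query heads share each key/value head. A minimal causal sketch that expands KV heads with `jnp.repeat`; head counts here are toy values, not the Marin-8B defaults, and `gqa_attention` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def gqa_attention(q, k, v):
    """Causal grouped-query attention.

    q: (B, T, n_head, D); k, v: (B, T, n_kv_head, D), n_head divisible by n_kv_head.
    """
    group = q.shape[2] // k.shape[2]
    k = jnp.repeat(k, group, axis=2)   # query heads 0..group-1 share KV head 0, etc.
    v = jnp.repeat(v, group, axis=2)
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    T = q.shape[1]
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(causal, scores, -jnp.inf)      # block attention to future keys
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 3, 4, 8))
k = jax.random.normal(kk, (1, 3, 2, 8))
v = jax.random.normal(kv, (1, 3, 2, 8))
out = gqa_attention(q, k, v)
```

Sharing KV heads shrinks the key/value cache by the group factor while keeping the full number of query heads.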