
base

GPT

Bases: Module

embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any

Compute token and positional embeddings for the input token indices.
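As a rough illustration of what `embed` computes, the sketch below sums a token-embedding lookup with a learned positional embedding, the usual GPT scheme. The tables `wte` and `wpe`, their shapes, and the name `embed_sketch` are hypothetical and are not taken from this module's parameters. Dropout, which `deterministic` controls, is omitted.

```python
import jax
import jax.numpy as jnp


def embed_sketch(idx, wte, wpe):
    """Hypothetical sketch: token embeddings plus learned position embeddings.

    idx: int array of shape (B, T)
    wte: assumed token table of shape (vocab_size, C)
    wpe: assumed position table of shape (block_size, C), block_size >= T
    Dropout is omitted.
    """
    _, T = idx.shape
    tok = wte[idx]                # gather one row per token -> (B, T, C)
    pos = wpe[jnp.arange(T)]      # rows for positions 0..T-1 -> (T, C)
    return tok + pos[None, :, :]  # broadcast positions over the batch -> (B, T, C)


key_tok, key_pos = jax.random.split(jax.random.PRNGKey(0))
vocab_size, block_size, C = 11, 8, 4
wte = jax.random.normal(key_tok, (vocab_size, C))
wpe = jax.random.normal(key_pos, (block_size, C))

idx = jnp.array([[1, 2, 3], [4, 5, 6]])  # (B=2, T=3)
x = embed_sketch(idx, wte, wpe)
print(x.shape)  # (2, 3, 4)
```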

decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any

Compute the decoded residual channels from the embeddings, optionally masking out padding tokens.
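`decode` accepts an optional `padding_mask` (True for valid tokens). A common way to use such a mask in a decoder-only transformer is to combine it with a causal mask before the attention softmax. The sketch below assumes that approach; it is not taken from this module's source, and `attention_mask_sketch` and `masked_softmax` are hypothetical helpers. Flax also provides `flax.linen.make_causal_mask` and `flax.linen.make_attention_mask` for the same purpose.

```python
import jax
import jax.numpy as jnp


def attention_mask_sketch(padding_mask):
    """Hypothetical: combine a causal mask with a key-padding mask.

    padding_mask: bool (B, T), True for valid tokens.
    Returns bool (B, 1, T, T); entry [b, 0, t, s] is True when query t may
    attend to key s. The singleton axis broadcasts over attention heads.
    """
    _, T = padding_mask.shape
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))  # key s <= query t
    keys_valid = padding_mask[:, None, None, :]       # hide padded keys
    return causal[None, None, :, :] & keys_valid


def masked_softmax(scores, mask):
    # Large negative fill instead of -inf so a fully masked row yields no NaNs.
    scores = jnp.where(mask, scores, -1e9)
    return jax.nn.softmax(scores, axis=-1)


padding_mask = jnp.array([[True, True, False]])  # last position is padding
mask = attention_mask_sketch(padding_mask)
scores = jnp.zeros((1, 2, 3, 3))                  # dummy scores (B, heads, T, T)
weights = masked_softmax(scores, mask)
print(mask[0, 0].astype(int))
```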

unembed(x: jax.Array) -> Any

Compute the output distribution over the vocabulary from the decoded residual channels.
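A minimal sketch of an unembedding step, assuming the output projection is tied to the token-embedding table. This page does not say whether `GPT.unembed` ties weights, applies a final layer norm, or returns logits rather than probabilities; `unembed_sketch` and `wte` are hypothetical. The final softmax turns the logits into a distribution over the vocabulary at each position.

```python
import jax
import jax.numpy as jnp


def unembed_sketch(x, wte):
    """Hypothetical: project residual channels onto the vocabulary.

    Assumes tied weights (projection matrix = token-embedding table wte).
    x: (B, T, C), wte: (vocab_size, C) -> logits (B, T, vocab_size)
    """
    return jnp.einsum("btc,vc->btv", x, wte)


x = jax.random.normal(jax.random.PRNGKey(1), (2, 3, 4))
wte = jax.random.normal(jax.random.PRNGKey(2), (11, 4))
logits = unembed_sketch(x, wte)
probs = jax.nn.softmax(logits, axis=-1)  # per-position distribution over the vocabulary
```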

loss(logits: jax.Array, targets: jax.Array) -> jax.Array

Compute cross-entropy loss given logits and targets.
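A sketch of token-level cross-entropy that follows the `targets` convention documented for `__call__`, where positions with target -1 are ignored. It assumes `loss` uses the same convention and averages over the non-ignored positions; the actual reduction may differ. `cross_entropy_sketch` is a hypothetical name. If you prefer a library call for the per-position term, Optax provides `optax.softmax_cross_entropy_with_integer_labels`.

```python
import jax
import jax.numpy as jnp


def cross_entropy_sketch(logits, targets):
    """Mean cross-entropy over positions whose target is not -1.

    logits: (B, T, V) float; targets: (B, T) int, -1 marks ignored positions.
    """
    valid = targets != -1
    safe = jnp.where(valid, targets, 0)             # placeholder index, masked out below
    logp = jax.nn.log_softmax(logits, axis=-1)      # (B, T, V)
    nll = -jnp.take_along_axis(logp, safe[..., None], axis=-1)[..., 0]  # (B, T)
    nll = jnp.where(valid, nll, 0.0)
    return nll.sum() / jnp.maximum(valid.sum(), 1)  # average over valid positions only


logits = jnp.zeros((1, 3, 4))       # uniform over 4 tokens
targets = jnp.array([[0, -1, 2]])   # middle position ignored
loss = cross_entropy_sketch(logits, targets)
print(loss)  # log(4), about 1.386
```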

__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]

Compute output logits and, if targets are provided, the cross-entropy loss.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| idx | Array | Input token indices of shape (B, T). | required |
| targets | Optional[Array] | Target token indices of shape (B, T). Use -1 to ignore positions. | None |
| padding_mask | Optional[Array] | Boolean tensor of shape (B, T): True for valid tokens, False for padding tokens. | None |
| deterministic | bool | If False, applies dropout. | False |
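The parameters above state only that `targets` has shape (B, T) and that -1 marks ignored positions. Under the usual next-token-prediction setup (an assumption here), targets are the inputs shifted left by one, with -1 wherever the current token is padding or has no valid next token. `make_targets_sketch` is a hypothetical helper, not part of this module.

```python
import jax.numpy as jnp


def make_targets_sketch(idx, padding_mask):
    """Hypothetical next-token targets: targets[:, t] = idx[:, t + 1], or -1.

    A position is ignored (-1) if it is padding or has no valid next token.
    idx: int (B, T); padding_mask: bool (B, T), True for valid tokens.
    """
    B = idx.shape[0]
    next_tok = jnp.concatenate([idx[:, 1:], jnp.full((B, 1), -1, idx.dtype)], axis=1)
    next_valid = jnp.concatenate([padding_mask[:, 1:], jnp.zeros((B, 1), dtype=bool)], axis=1)
    keep = padding_mask & next_valid
    return jnp.where(keep, next_tok, -1)


idx = jnp.array([[5, 6, 7, 0]])                       # trailing 0 is a pad token
padding_mask = jnp.array([[True, True, True, False]])
targets = make_targets_sketch(idx, padding_mask)
print(targets)  # values: [[6, 7, -1, -1]]
```

The resulting `targets` and `padding_mask` line up position by position with `idx`, matching the (B, T) shapes listed above.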

Returns:

| Name | Type | Description |
|------|------|-------------|
| logits | Array | Output logits of shape (B, T, vocab_size). |
| loss | Optional[Array] | Cross-entropy loss if targets are provided, else None. |