# GPT

Bases: `Module`
### `embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any`

Compute token and positional embeddings given inputs.
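A minimal sketch of the token-plus-position sum this method computes, assuming learned embedding tables `tok_emb` of shape `(vocab_size, d_model)` and `pos_emb` of shape `(block_size, d_model)` (names hypothetical; in the actual module these are parameters, and dropout may follow when `deterministic` is False):

```python
import jax.numpy as jnp

def embed_sketch(idx, tok_emb, pos_emb):
    # idx: (B, T) integer token indices.
    # tok_emb[idx] gathers per-token embeddings -> (B, T, d_model);
    # pos_emb[:T] is (T, d_model) and broadcasts over the batch dimension.
    T = idx.shape[1]
    return tok_emb[idx] + pos_emb[:T]
```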
### `decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any`

Compute decoded residual channels given embeddings.
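The decoder blocks themselves are not reproduced here, but the way a causal mask combines with the optional `padding_mask` can be sketched (an assumption about the masking convention, following the True-for-valid semantics documented for `__call__`):

```python
import jax.numpy as jnp

def attention_mask(T, padding_mask=None):
    # Causal part: query position t may attend to key positions <= t.
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))
    if padding_mask is None:
        return causal[None, :, :]  # (1, T, T), broadcasts over batch
    # Additionally forbid attending to padded key positions
    # (padding_mask is True for valid tokens, False for padding).
    return causal[None, :, :] & padding_mask[:, None, :]
```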
### `unembed(x: jax.Array) -> Any`

Compute output distribution.
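`unembed` is a linear projection from the residual stream to vocabulary logits; a sketch assuming an unembedding matrix `w_unembed` of shape `(d_model, vocab_size)` (hypothetical name; some GPT variants tie this to the transpose of the token embedding table):

```python
import jax.numpy as jnp

def unembed_sketch(x, w_unembed):
    # x: (B, T, d_model) -> logits: (B, T, vocab_size)
    return x @ w_unembed
```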
### `loss(logits: jax.Array, targets: jax.Array) -> jax.Array`

Compute cross-entropy loss given logits and targets.
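A sketch of cross-entropy with the ignore-index convention documented for `__call__` (targets equal to -1 contribute nothing to the mean; the helper name is hypothetical):

```python
import jax
import jax.numpy as jnp

def cross_entropy_sketch(logits, targets):
    # logits: (B, T, vocab_size); targets: (B, T), -1 marks ignored positions.
    valid = targets != -1
    # Clamp ignored targets to a legal index so the gather stays in range.
    safe = jnp.where(valid, targets, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, safe[..., None], axis=-1)[..., 0]
    # Average over valid positions only.
    return jnp.sum(nll * valid) / jnp.maximum(jnp.sum(valid), 1)
```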
### `__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]`
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `idx` | `Array` | Input token indices of shape (B, T). | *required* |
| `targets` | `Optional[Array]` | Target token indices of shape (B, T). Use -1 to ignore positions. | `None` |
| `padding_mask` | `Optional[Array]` | Boolean array of shape (B, T): True for valid tokens, False for padding tokens. | `None` |
| `deterministic` | `bool` | If False, applies dropout. | `False` |
Returns:

| Name | Type | Description |
|---|---|---|
| `logits` | `Array` | Output logits of shape (B, T, vocab_size). |
| `loss` | `Optional[Array]` | Cross-entropy loss if targets provided, else None. |
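Since `targets` marks ignored positions with -1 while `padding_mask` marks padding with False, the two conventions can be kept consistent with a small helper (hypothetical, not part of the class):

```python
import jax.numpy as jnp

def ignore_padded_targets(targets, padding_mask):
    # Set targets at padded positions to -1 so the loss skips them.
    return jnp.where(padding_mask, targets, -1)
```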