base
a very basic trainer
BaseTrainer(spec: ExecutionSpec)
Bases: RestoreableJob[C], Generic[C, M]
Generic pretrainer for GPT-style models.
Build a basic trainer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| spec | ExecutionSpec | Execution specification | required |

Raises:

| Type | Description |
|---|---|
| AssertionError | if topology is not provided in spec |
done: bool
property
Check whether training is done.
sharded_init(model: Any, key: jax.Array, *args: Any, mesh: jax.sharding.Mesh, **kwargs: Any) -> Tuple[Any, Any]
staticmethod
Initialize model params sharded across mesh from the start.
Uses eval_shape to compute sharding spec without materializing params, then JITs init with out_shardings so arrays land on correct devices directly.
Returns (variables, var_sharding).
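The eval_shape → out_shardings pattern is standard JAX; below is a minimal, self-contained sketch of it. The Dense model, the single "data" mesh axis, and the fully replicated partitioning are illustrative assumptions, not the trainer's internals.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding, PartitionSpec

model = nn.Dense(features=128)           # stand-in model, not the trainer's
key = jax.random.PRNGKey(0)
abstract_input = jax.ShapeDtypeStruct((8, 128), jnp.float32)
mesh = Mesh(jax.devices(), axis_names=("data",))

# 1) Compute the parameter PyTree's shapes/dtypes without allocating memory.
abstract_vars = jax.eval_shape(model.init, key, abstract_input)

# 2) Pick a sharding for every leaf (fully replicated here, for simplicity).
var_sharding = jax.tree_util.tree_map(
    lambda _: NamedSharding(mesh, PartitionSpec()), abstract_vars
)

# 3) JIT init with out_shardings so parameters materialize directly on the
#    target devices instead of landing on a single host first.
init_fn = jax.jit(model.init, out_shardings=var_sharding)
variables = init_fn(key, jnp.zeros((8, 128), jnp.float32))
```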
evaluator() -> Optional[Evaluator[M]]
Define which evaluator to use; return None for no evaluation.
optimizer() -> str | optax.GradientTransformation
classmethod
Return either the name of an optimizer from the optimizer library or a custom optax optimizer.
schedule() -> Optional[str | optax.Schedule]
classmethod
Return a custom optax learning-rate schedule, the name of a schedule from the library, or None to use a constant learning rate.
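A hypothetical subclass showing both hooks. The class name and hyperparameters are illustrative, and how the trainer wires the schedule into the optimizer is not documented here; only the return types are the documented contract.

```python
import optax

class GPTTrainer(BaseTrainer):
    @classmethod
    def optimizer(cls) -> str | optax.GradientTransformation:
        # Could also return a library name such as "adamw" (assumed to
        # exist in the optimizer library); shown here as a custom chain.
        return optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(learning_rate=3e-4),
        )

    @classmethod
    def schedule(cls) -> optax.Schedule | None:
        # A custom optax schedule; returning None keeps a constant lr.
        return optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=3e-4,
            warmup_steps=1_000,
            decay_steps=100_000,
        )
```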
batch(slice: str = 'train') -> PyTree[np.ndarray]
Get the next batch from the dataset strategy.
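A shape-contract sketch, assuming a constructed trainer instance and the (x, y, padding_mask) layout documented for train_step below:

```python
# `trainer` is a constructed BaseTrainer subclass instance (hypothetical);
# leaves follow the (S, B, T) convention documented for train_step.
x, y, padding_mask = trainer.batch(slice="train")
assert x.shape == y.shape == padding_mask.shape  # each (S, B, T)
```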
train_step(state: train_state.TrainState, batch: PyTree[jax.Array], key: jax.Array, accumulate_steps: int) -> Tuple[train_state.TrainState, jax.Array, Any]
classmethod
Compute gradients over S micro-batches and apply one optimizer step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | TrainState | Current training state | required |
| batch | PyTree[Array] | (x, y, padding_mask), each with shape (S, B, T), where S = accumulation steps, B = batch size, T = sequence length | required |
| key | Array | PRNG key for dropout | required |
| accumulate_steps | int | Number of micro-batches (S) | required |

Returns:

| Type | Description |
|---|---|
| Tuple[TrainState, Array, Any] | (updated_state, loss, meta), where meta is the last micro-batch's metadata |
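The accumulate-then-apply structure can be sketched with jax.lax.scan over the leading S axis. The loss below is a stand-in masked regression rather than the trainer's LM loss, and the final apply_gradients comment assumes a flax TrainState; everything here is illustrative, not the trainer's actual implementation.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y, mask):
    # Stand-in masked loss; the real trainer computes an LM loss.
    pred = x * params["w"] + params["b"]
    per_token = (pred - y) ** 2
    return (per_token * mask).sum() / jnp.maximum(mask.sum(), 1.0)

def accumulate(params, batch, accumulate_steps):
    # batch leaves carry a leading S axis; scan over it, summing loss and
    # gradients, then average so the step matches one large batch.
    def micro_step(carry, micro):
        x, y, mask = micro
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y, mask)
        loss_sum, grad_sum = carry
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    init = (jnp.float32(0.0), jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss_sum, grad_sum), _ = jax.lax.scan(micro_step, init, batch)
    scale = 1.0 / accumulate_steps
    return loss_sum * scale, jax.tree_util.tree_map(lambda g: g * scale, grad_sum)

S, B, T = 4, 8, 16
params = {"w": jnp.float32(1.0), "b": jnp.float32(0.0)}
batch = (jnp.ones((S, B, T)), jnp.ones((S, B, T)), jnp.ones((S, B, T)))
loss, grads = accumulate(params, batch, accumulate_steps=S)
# One optimizer step would then be: state = state.apply_gradients(grads=grads)
```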
val_step(state: train_state.TrainState, batch: PyTree[jax.Array]) -> Tuple[jax.Array, jax.Array, Any]
classmethod
Compute validation loss over S micro-batches.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | TrainState | Current training state | required |
| batch | PyTree[Array] | (x, y, padding_mask), each with shape (S, B, T), where S = accumulation steps, B = batch size, T = sequence length | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Array, Array, Any] | (loss_sum, token_count, last_meta) |
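Returning (loss_sum, token_count) rather than a mean lets the caller form an exact token-weighted average across all validation batches. A hypothetical aggregation loop, where TrainerCls, state, and val_batches stand in for objects the surrounding job provides:

```python
loss_total, token_total = 0.0, 0.0
for batch in val_batches:
    loss_sum, token_count, _meta = TrainerCls.val_step(state, batch)
    loss_total += float(loss_sum)
    token_total += float(token_count)
mean_val_loss = loss_total / max(token_total, 1.0)  # token-weighted mean
```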
save(suffix: Path) -> None
Final save at the end of training.
load(suffix: Path) -> None
Load from a checkpoint, if available.
run() -> None
Main entry point to run training; called on all nodes.
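Putting the pieces together, a hypothetical driver reusing the GPTTrainer sketch above. The ExecutionSpec field name `topology` is inferred from the constructor's documented AssertionError, `my_topology` is a placeholder, and whether run() already checkpoints internally is not documented.

```python
from pathlib import Path

spec = ExecutionSpec(topology=my_topology)  # a topology must be provided
trainer = GPTTrainer(spec)

trainer.load(Path("latest"))  # resume from a checkpoint, if available
trainer.run()                 # main training loop; called on all nodes
trainer.save(Path("final"))   # final save at the end of training
```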