base

a very basic trainer

BaseTrainer(spec: ExecutionSpec)

Bases: RestoreableJob[C], Generic[C, M]

Generic pretrainer for GPT-style models.

Build a basic trainer

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `spec` | `ExecutionSpec` | execution specification | required |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | if topology is not provided in spec |

done: bool property

check if training is done

sharded_init(model: Any, key: jax.Array, *args: Any, mesh: jax.sharding.Mesh, **kwargs: Any) -> Tuple[Any, Any] staticmethod

Initialize model params sharded across mesh from the start.

Uses eval_shape to compute sharding spec without materializing params, then JITs init with out_shardings so arrays land on correct devices directly.

Returns (variables, var_sharding).
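The eval_shape / out_shardings pattern described above can be sketched as follows. This is a minimal, standalone illustration (the function and variable names are assumptions, not the trainer's actual code), using a fully replicated `PartitionSpec` where a real trainer would derive per-parameter specs:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def sharded_init(init_fn, key, mesh):
    # Compute parameter shapes/dtypes without materializing any arrays.
    abstract = jax.eval_shape(init_fn, key)
    # Map every leaf to a sharding; here everything is replicated,
    # a real trainer would derive per-parameter PartitionSpecs.
    shardings = jax.tree_util.tree_map(
        lambda _: NamedSharding(mesh, PartitionSpec()), abstract)
    # JIT the initializer with out_shardings so parameters land on the
    # correct devices directly, never gathering on a single host first.
    variables = jax.jit(init_fn, out_shardings=shardings)(key)
    return variables, shardings

# Usage on a trivial one-device mesh:
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
init_fn = lambda key: {"w": jax.random.normal(key, (4, 4))}
params, spec = sharded_init(init_fn, jax.random.PRNGKey(0), mesh)
```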

evaluator() -> Optional[Evaluator[M]]

define what evaluator to use

optimizer() -> str | optax.GradientTransformation classmethod

return either an optimizer from the optimizer library, or a custom optax optimizer

schedule() -> Optional[str | optax._src.base.Schedule] classmethod

return either a learning rate schedule, a schedule name from the library, or nothing to use a constant lr

batch(slice: str = 'train') -> PyTree[np.ndarray]

get the next batch from the dataset strategy
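Since `train_step` expects arrays with a leading accumulation axis, a batch of `S * B` sequences can be folded into the expected `(S, B, T)` layout with a single reshape. A minimal sketch (shapes and names are assumptions for illustration):

```python
import numpy as np

S, B, T = 4, 8, 128  # accumulation steps, micro-batch size, sequence length
flat = np.zeros((S * B, T), dtype=np.int32)  # one flat batch of token ids
micro = flat.reshape(S, B, T)  # leading axis now indexes micro-batches
```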

train_step(state: train_state.TrainState, batch: PyTree[jax.Array], key: jax.Array, accumulate_steps: int) -> Tuple[train_state.TrainState, jax.Array, Any] classmethod

Compute gradients over S micro-batches and apply one optimizer step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `TrainState` | Current training state | required |
| `batch` | `PyTree[Array]` | `(x, y, padding_mask)`, each with shape `(S, B, T)`; S = accumulation steps, B = batch size, T = sequence length | required |
| `key` | `Array` | PRNG key for dropout | required |
| `accumulate_steps` | `int` | Number of micro-batches (S) | required |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[TrainState, Array, Any]` | `(updated_state, loss, meta)`, where `meta` is the last micro-batch's metadata |
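Gradient accumulation over the leading `S` axis is commonly implemented with `jax.lax.scan`, so peak memory stays at a single micro-batch's footprint. A self-contained sketch of that pattern, assuming a generic `loss_fn(params, micro_batch)` (the names and toy loss are illustrative, not this trainer's internals):

```python
import jax
import jax.numpy as jnp

def accumulate_grads(loss_fn, params, batch, accumulate_steps):
    # Sum loss and gradients across the leading micro-batch axis,
    # then divide once at the end.
    def step(carry, micro):
        grad_acc, loss_acc = carry
        loss, grads = jax.value_and_grad(loss_fn)(params, micro)
        grad_acc = jax.tree_util.tree_map(jnp.add, grad_acc, grads)
        return (grad_acc, loss_acc + loss), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    init = (zeros, jnp.zeros((), jnp.float32))
    (grads, loss), _ = jax.lax.scan(step, init, batch)
    scale = 1.0 / accumulate_steps
    grads = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return grads, loss * scale

# Toy check: linear loss, two micro-batches of all-ones and all-threes.
params = {"w": jnp.ones((3,))}
loss_fn = lambda p, x: jnp.sum(p["w"] * x)
batch = jnp.stack([jnp.ones(3), 3.0 * jnp.ones(3)])  # shape (S=2, 3)
grads, loss = accumulate_grads(loss_fn, params, batch, 2)
```

The averaged gradients would then feed a single `state.apply_gradients(...)` call to complete one optimizer step.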

val_step(state: train_state.TrainState, batch: PyTree[jax.Array]) -> Tuple[jax.Array, jax.Array, Any] classmethod

Compute validation loss over S micro-batches.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `TrainState` | Current training state | required |
| `batch` | `PyTree[Array]` | `(x, y, padding_mask)`, each with shape `(S, B, T)`; S = accumulation steps, B = batch size, T = sequence length | required |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Array, Array, Any]` | `(loss_sum, token_count, last_meta)` |
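Returning `(loss_sum, token_count)` rather than a per-batch mean lets the caller compute a token-weighted validation loss across batches of unequal size. A minimal sketch of that aggregation (the helper name and example numbers are assumptions):

```python
import math

def aggregate_val(results):
    """Combine (loss_sum, token_count) pairs from many val_step calls
    into a token-weighted mean loss and the corresponding perplexity."""
    total_loss = sum(loss for loss, _ in results)
    total_tokens = sum(tokens for _, tokens in results)
    mean_loss = total_loss / total_tokens
    return mean_loss, math.exp(mean_loss)

# e.g. two validation batches of 100 tokens each:
mean_loss, ppl = aggregate_val([(120.0, 100), (80.0, 100)])
```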

save(suffix: Path) -> None

final save at the end of training

load(suffix: Path) -> None

load from a checkpoint, if available

run() -> None

main entry point to run training, called on all nodes