training

BaseTrainer(spec: ExecutionSpec)

Bases: RestoreableJob[C], Generic[C, M]

Generic pretrainer for GPT-style models.

Build a basic trainer.

Parameters:

    spec (ExecutionSpec): execution specification. Required.

Raises:

    AssertionError: if topology is not provided in spec.

done: bool property

Check whether training is done.

sharded_init(model: Any, key: jax.Array, *args: Any, mesh: jax.sharding.Mesh, **kwargs: Any) -> Tuple[Any, Any] staticmethod

Initialize model params sharded across mesh from the start.

Uses eval_shape to compute sharding spec without materializing params, then JITs init with out_shardings so arrays land on correct devices directly.

Returns (variables, var_sharding).
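The eval_shape-then-jit pattern can be sketched in isolation. The `init_fn` below and the fully replicated `PartitionSpec` are illustrative assumptions, not the trainer's actual partitioning rules:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical init function standing in for model.init.
def init_fn(key, x):
    return {"w": jax.random.normal(key, (x.shape[-1], 4))}

def sharded_init(init_fn, key, x, mesh):
    # 1. Compute parameter shapes/dtypes without allocating any memory.
    abstract = jax.eval_shape(init_fn, key, x)
    # 2. Pick a sharding per leaf (fully replicated here for simplicity;
    #    a real trainer would derive a spec from the mesh topology).
    var_sharding = jax.tree_util.tree_map(
        lambda _: NamedSharding(mesh, PartitionSpec()), abstract)
    # 3. JIT init with out_shardings so arrays land on the right devices
    #    directly, never materializing an unsharded copy first.
    variables = jax.jit(init_fn, out_shardings=var_sharding)(key, x)
    return variables, var_sharding

mesh = Mesh(np.array(jax.devices()), ("data",))
variables, var_sharding = sharded_init(
    init_fn, jax.random.PRNGKey(0), jnp.ones((2, 3)), mesh)
```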

evaluator() -> Optional[Evaluator[M]]

Define which evaluator to use, or return None for no evaluation.

optimizer() -> str | optax.GradientTransformation classmethod

Return either an optimizer name from the optimizer library, or a custom optax optimizer.

schedule() -> Optional[str | optax._src.base.Schedule] classmethod

Return a learning-rate schedule, a schedule name from the library, or None to use a constant learning rate.

batch(slice: str = 'train') -> PyTree[np.ndarray]

Get the next batch from the dataset strategy for the given slice.

train_step(state: train_state.TrainState, batch: PyTree[jax.Array], key: jax.Array, accumulate_steps: int) -> Tuple[train_state.TrainState, jax.Array, Any] classmethod

Compute gradients over S micro-batches and apply one optimizer step.

Parameters:

    state (TrainState): current training state. Required.
    batch (PyTree[Array]): (x, y, padding_mask), each with shape (S, B, T), where S = accumulation steps, B = batch size, T = sequence length. Required.
    key (Array): PRNG key for dropout. Required.
    accumulate_steps (int): number of micro-batches (S). Required.

Returns:

    Tuple[TrainState, Array, Any]: (updated_state, loss, meta), where meta is the last micro-batch's metadata.

val_step(state: train_state.TrainState, batch: PyTree[jax.Array]) -> Tuple[jax.Array, jax.Array, Any] classmethod

Compute validation loss over S micro-batches.

Parameters:

    state (TrainState): current training state. Required.
    batch (PyTree[Array]): (x, y, padding_mask), each with shape (S, B, T), where S = accumulation steps, B = batch size, T = sequence length. Required.

Returns:

    Tuple[Array, Array, Any]: (loss_sum, token_count, last_meta).
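Returning a sum and a token count, rather than a pre-averaged loss, lets the caller form a padding-aware mean across micro-batches and devices. A minimal sketch of that bookkeeping (function names are illustrative):

```python
import jax.numpy as jnp

# Per-token NLLs and a padding mask in, (loss_sum, token_count) out.
def masked_loss(per_token_nll, padding_mask):
    loss_sum = jnp.sum(per_token_nll * padding_mask)
    token_count = jnp.sum(padding_mask)
    return loss_sum, token_count

# The caller then divides once, after all accumulation.
def per_token_loss(loss_sum, token_count):
    return loss_sum / jnp.maximum(token_count, 1)  # guard all-padding batches
```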

save(suffix: Path) -> None

Final save at the end of training.

load(suffix: Path) -> None

Load from a checkpoint, if one is available.

run() -> None

Main entry point to run training; called on all nodes.

KLDivergenceTrainer(spec: ExecutionSpec)

Bases: BaseTrainer[C, M], Generic[C, M]

Two-stage trainer with KL-divergence penalty.

  • Stage 1 – standard language-model pretraining (cross-entropy only).
  • Stage 2 – pretraining loss plus beta * KL(policy || reference) where the reference policy is a frozen snapshot taken at the stage boundary.

The KL penalty is approximated as the difference in per-token NLL between the current model and the reference model on the same batch: kl_penalty = policy_loss - sg(reference_loss).
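The stated approximation translates directly into a loss term; the function and argument names below are hypothetical, but the formula is the one given above:

```python
import jax
import jax.numpy as jnp

def stage2_loss(policy_loss, reference_loss, beta):
    # KL penalty approximated as the per-token NLL gap; stop_gradient
    # plays the role of sg(), so no gradients flow into the frozen
    # reference snapshot.
    kl_penalty = policy_loss - jax.lax.stop_gradient(reference_loss)
    return policy_loss + beta * kl_penalty
```

Because the reference term is inside stop_gradient, stage 2 optimizes the pretraining loss scaled by (1 + beta) gradient-wise, while the logged loss still reflects the divergence from the snapshot.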

KLDivergenceTrainerConfig(
    batch_size: int = field('training/batch_size', default=512),
    per_device_batch_size: int = field('training/per_device_batch_size', default=-1),
    total_tokens: List[int] = field('training/tokens', default_factory=lambda: [1000000000, 100000000]),
    lr: float = field('optimization/lr', default=0.0003),
    warmup_pct: float = field('training/warmup_pct', default=0.01),
    decay_pct: float = field('training/decay_pct', default=0.1),
    validate: bool = field('training/validation', default=True),
    evaluate: bool = field('training/evaluate', default=True),
    datasets: List[List[Sampling]] = field('training/dataset', default_factory=lambda: [
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],
    ]),
    evaluations: List[str] = field('eval/evaluations', default_factory=lambda: []),
    block_size: int = field('architecture/block_size', default=512),
    report_interval: int = field('logging/report_interval', default=32),
    checkpoint_interval: int = field('logging/checkpoint_interval', default=1024),
    validation_interval: int = field('logging/validation_interval', default=512),
    validation_steps: int = field('training/validation_steps', default=2048),
    wandb: bool = field('logging/wandb', default=False),
) dataclass

Bases: BaseTrainerConfig

Config for two-stage KL-divergence trainer.

total_tokens is a two-element list: [stage1_tokens, stage2_tokens]. datasets is a two-element list of sampling lists, one per stage.
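The two-element total_tokens list implies a stage boundary at stage1_tokens. A sketch of how a trainer might derive the active stage from the token counter (the function name and exact boundary convention are assumptions):

```python
def current_stage(tokens_seen, total_tokens):
    # total_tokens = [stage1_tokens, stage2_tokens]; stage 2 begins once
    # the stage-1 token budget is exhausted.
    return 0 if tokens_seen < total_tokens[0] else 1
```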