training

BaseTrainer(spec: ExecutionSpec)
Bases: RestoreableJob[C], Generic[C, M]
Generic pretrainer for GPT-style models.
Build a basic trainer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| spec | ExecutionSpec | Execution specification | required |
Raises:

| Type | Description |
|---|---|
| AssertionError | If topology is not provided in spec |
done: bool
property
Check whether training is done.
sharded_init(model: Any, key: jax.Array, *args: Any, mesh: jax.sharding.Mesh, **kwargs: Any) -> Tuple[Any, Any]
staticmethod
Initialize model parameters sharded across the mesh from the start.
Uses eval_shape to compute the sharding spec without materializing the parameters, then JITs the init function with out_shardings so arrays land on the correct devices directly.
Returns (variables, var_sharding).
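The pattern described above can be sketched in plain JAX. This is an illustrative sketch, not the trainer's actual implementation: `init_fn`, the toy parameter tree, and the fully replicated single-axis mesh are all assumptions (a real trainer would partition large arrays along mesh axes).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def init_fn(key):
    # Toy "model init": returns a pytree of parameters.
    k1, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(k1, (4, 4)),
        "b": jnp.zeros((4,)),
    }

def sharded_init(key, mesh):
    # 1. Compute shapes/dtypes WITHOUT materializing any arrays.
    abstract = jax.eval_shape(init_fn, key)
    # 2. Build a sharding for every leaf (here: fully replicated,
    #    i.e. an empty PartitionSpec).
    shardings = jax.tree_util.tree_map(
        lambda _: NamedSharding(mesh, PartitionSpec()), abstract
    )
    # 3. JIT the init with out_shardings so parameters are created
    #    directly on the right devices, never gathered on one host.
    variables = jax.jit(init_fn, out_shardings=shardings)(key)
    return variables, shardings

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("data",))
variables, var_sharding = sharded_init(jax.random.PRNGKey(0), mesh)
```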
evaluator() -> Optional[Evaluator[M]]
Define which evaluator to use, if any.
optimizer() -> str | optax.GradientTransformation
classmethod
Return either a named optimizer from the optimizer library or a custom optax GradientTransformation.
schedule() -> Optional[str | optax._src.base.Schedule]
classmethod
Return a learning-rate schedule, a schedule name from the library, or None to use a constant learning rate.
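A minimal pure-Python sketch of one plausible schedule shape, assuming the `warmup_pct`/`decay_pct` semantics of the config below (linear warmup over the first `warmup_pct` of steps, constant in the middle, linear decay to zero over the final `decay_pct`) — the exact schedule the library provides may differ:

```python
def make_schedule(peak_lr, total_steps, warmup_pct=0.01, decay_pct=0.1):
    """Linear warmup, constant middle, linear decay to zero at the end."""
    warmup_steps = max(int(total_steps * warmup_pct), 1)
    decay_start = int(total_steps * (1.0 - decay_pct))

    def schedule(step):
        if step < warmup_steps:
            return peak_lr * (step + 1) / warmup_steps
        if step >= decay_start:
            frac = (step - decay_start) / max(total_steps - decay_start, 1)
            return peak_lr * (1.0 - frac)
        return peak_lr

    return schedule

sched = make_schedule(peak_lr=3e-4, total_steps=1000)
```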
batch(slice: str = 'train') -> PyTree[np.ndarray]
Get the next batch from the dataset strategy.
train_step(state: train_state.TrainState, batch: PyTree[jax.Array], key: jax.Array, accumulate_steps: int) -> Tuple[train_state.TrainState, jax.Array, Any]
classmethod
Compute gradients over S micro-batches and apply one optimizer step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | TrainState | Current training state | required |
| batch | PyTree[Array] | (x, y, padding_mask), each with shape (S, B, T); S = accumulation steps, B = batch size, T = sequence length | required |
| key | Array | PRNG key for dropout | required |
| accumulate_steps | int | Number of micro-batches (S) | required |
Returns:

| Type | Description |
|---|---|
| Tuple[TrainState, Array, Any] | (updated_state, loss, meta), where meta is the last micro-batch's metadata |
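The accumulate-then-step behavior can be illustrated with a NumPy sketch. The linear model, MSE loss, and plain SGD update are assumptions for illustration only — the trainer's jitted implementation uses its own model and optimizer:

```python
import numpy as np

def loss_and_grad(w, x, y):
    # Mean-squared-error loss of a linear model and its gradient.
    err = x @ w - y                      # (B,)
    loss = np.mean(err ** 2)
    grad = 2.0 * x.T @ err / len(y)      # (D,)
    return loss, grad

def train_step(w, batch, lr=0.1):
    # batch = (x, y) with a leading accumulation axis S:
    # x: (S, B, D), y: (S, B)
    x, y = batch
    S = x.shape[0]
    grads = np.zeros_like(w)
    losses = 0.0
    for s in range(S):                   # one micro-batch at a time
        loss, g = loss_and_grad(w, x[s], y[s])
        grads += g
        losses += loss
    # Average over micro-batches, then apply ONE optimizer step.
    w = w - lr * (grads / S)
    return w, losses / S

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0])
x = rng.normal(size=(4, 8, 2))           # S=4, B=8, D=2
y = x @ true_w
w = np.zeros(2)
w, loss0 = train_step(w, (x, y))
w, loss1 = train_step(w, (x, y))
```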
val_step(state: train_state.TrainState, batch: PyTree[jax.Array]) -> Tuple[jax.Array, jax.Array, Any]
classmethod
Compute validation loss over S micro-batches.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | TrainState | Current training state | required |
| batch | PyTree[Array] | (x, y, padding_mask), each with shape (S, B, T); S = accumulation steps, B = batch size, T = sequence length | required |
Returns:

| Type | Description |
|---|---|
| Tuple[Array, Array, Any] | (loss_sum, token_count, last_meta) |
save(suffix: Path) -> None
Final save at the end of training.
load(suffix: Path) -> None
Load from a checkpoint, if one is available.
run() -> None
Main entry point to run training; called on all nodes.
KLDivergenceTrainer(spec: ExecutionSpec)
Bases: BaseTrainer[C, M], Generic[C, M]
Two-stage trainer with a KL-divergence penalty.
- Stage 1 – standard language-model pretraining (cross-entropy only).
- Stage 2 – pretraining loss plus beta * KL(policy || reference), where the reference policy is a frozen snapshot taken at the stage boundary.
The KL penalty is approximated as the difference in per-token NLL between the current model and the reference model on the same batch: kl_penalty = policy_loss - sg(reference_loss).
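The approximation above can be sketched numerically with NumPy. The logits, `beta` value, and helper names are illustrative assumptions; `sg(.)` (stop-gradient) is represented by the fact that the reference loss is a plain constant with respect to the policy parameters:

```python
import numpy as np

def per_token_nll(logits, targets):
    # Negative log-likelihood of each target token under the logits.
    # logits: (T, V), targets: (T,)
    logits = logits - logits.max(axis=-1, keepdims=True)   # stable softmax
    logprobs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -logprobs[np.arange(len(targets)), targets]     # (T,)

rng = np.random.default_rng(0)
T, V = 6, 10
targets = rng.integers(0, V, size=T)
policy_logits = rng.normal(size=(T, V))
reference_logits = rng.normal(size=(T, V))

policy_loss = per_token_nll(policy_logits, targets).mean()
# The reference model is a frozen snapshot, so its loss carries no
# gradient back to the policy parameters (the sg(.) in the formula).
reference_loss = per_token_nll(reference_logits, targets).mean()

kl_penalty = policy_loss - reference_loss
beta = 0.1
total_loss = policy_loss + beta * kl_penalty
```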
KLDivergenceTrainerConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: List[int] = field('training/tokens', default_factory=(lambda: [1000000000, 100000000])), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('training/warmup_pct', default=0.01), decay_pct: float = field('training/decay_pct', default=0.1), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[List[Sampling]] = field('training/dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))], [Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: [])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False))
dataclass
Bases: BaseTrainerConfig
Config for two-stage KL-divergence trainer.
total_tokens is a two-element list: [stage1_tokens, stage2_tokens].
datasets is a two-element list of sampling lists, one per stage.