
kl_divergence

KL Divergence Trainer

Two-stage trainer: stage 1 is standard pretraining, and stage 2 adds a KL penalty, weighted by a configurable beta, against a frozen snapshot of the stage-1 model (the reference policy).

Stage switching uses the same token-counting approach as the ABCD continual-learning trainers: the trainer moves to stage 2 once the stage-1 token budget has been consumed.

KLConfig(beta: float = field('optimization/kl/beta', default=0.1)) dataclass

KL divergence penalty configuration.

KLDivergenceTrainerConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: List[int] = field('training/tokens', default_factory=(lambda: [1000000000, 100000000])), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('training/warmup_pct', default=0.01), decay_pct: float = field('training/decay_pct', default=0.1), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[List[Sampling]] = field('training/dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))], [Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: [])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False)) dataclass

Bases: BaseTrainerConfig

Config for two-stage KL-divergence trainer.

total_tokens is a two-element list, [stage1_tokens, stage2_tokens], giving the token budget for each stage. datasets is likewise a two-element list of Sampling lists, one per stage.
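The per-stage budgets can be illustrated with a small sketch of token-based stage switching. This is not the trainer's code. The helper names (check_two_stage, stage_for, steps_for) are hypothetical, and the sketch assumes each optimizer step consumes batch_size sequences of block_size tokens, with step counts rounded up.

```python
from typing import Optional, Sequence


def check_two_stage(total_tokens: Sequence[int], datasets: Sequence[Sequence[object]]) -> None:
    """Hypothetical validation: one token budget and one sampling list per stage."""
    if len(total_tokens) != 2 or len(datasets) != 2:
        raise ValueError("expected exactly two stages for total_tokens and datasets")


def stage_for(tokens_seen: int, total_tokens: Sequence[int]) -> Optional[int]:
    """Return the 1-based stage for a running token count, or None once all budgets are spent."""
    boundary = 0
    for index, budget in enumerate(total_tokens):
        boundary += budget  # cumulative end of this stage
        if tokens_seen < boundary:
            return index + 1
    return None


def steps_for(budget: int, batch_size: int, block_size: int) -> int:
    """Assumed step count for a stage: ceil(budget / (batch_size * block_size))."""
    tokens_per_step = batch_size * block_size
    return -(-budget // tokens_per_step)  # integer ceiling division


# With the defaults (batch_size=512, block_size=512, total_tokens=[1e9, 1e8]):
budgets = [1_000_000_000, 100_000_000]
print(stage_for(0, budgets), stage_for(1_000_000_000, budgets))  # stage 1, then stage 2
print(steps_for(budgets[0], 512, 512), steps_for(budgets[1], 512, 512))
```

Under these assumptions each step covers 262,144 tokens, so the default budgets correspond to roughly 3,815 stage-1 steps and 382 stage-2 steps. The real trainer may round or count tokens differently.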

KLDivergenceTrainState

Bases: TrainState

Train state carrying a frozen reference-policy snapshot and KL weight.
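The idea of a train state that carries a frozen reference snapshot can be sketched with a plain dataclass. SketchKLTrainState and snapshot_reference are hypothetical names, parameters are modeled as a nested dict, and the assumption that the KL weight is 0 until the stage boundary is ours, not the source's.

```python
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class SketchKLTrainState:
    """Hypothetical stand-in for KLDivergenceTrainState."""

    params: Dict[str, Any]
    reference_params: Optional[Dict[str, Any]] = None  # None during stage 1
    kl_weight: float = 0.0  # assumed inactive until stage 2

    def snapshot_reference(self, beta: float) -> None:
        """At the stage boundary, freeze a deep copy of the current params."""
        # Deep copy so later updates to params cannot leak into the reference.
        self.reference_params = copy.deepcopy(self.params)
        self.kl_weight = beta


state = SketchKLTrainState(params={"w": [1.0, 2.0]})
state.snapshot_reference(beta=0.1)
state.params["w"][0] = 5.0  # stage-2 update to the policy
print(state.reference_params["w"])  # the reference keeps its snapshot values
```

The deep copy is the essential step: a shallow copy would share the nested lists, so the "frozen" reference would silently track the policy.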

KLDivergenceTrainer(spec: ExecutionSpec)

Bases: BaseTrainer[C, M], Generic[C, M]

Two-stage trainer with KL-divergence penalty.

  • Stage 1 – standard language-model pretraining (cross-entropy only).
  • Stage 2 – pretraining loss plus beta * KL(policy || reference) where the reference policy is a frozen snapshot taken at the stage boundary.
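The per-stage objective described in the list can be sketched as a function of precomputed scalars. stage_loss is a hypothetical helper; it takes the cross-entropy loss and a KL penalty value as inputs and does not compute them.

```python
def stage_loss(stage: int, ce_loss: float, kl_penalty: float, beta: float = 0.1) -> float:
    """Hypothetical per-stage objective: CE only in stage 1, CE + beta * KL in stage 2."""
    if stage == 1:
        return ce_loss  # the KL term is ignored before the reference exists
    if stage == 2:
        return ce_loss + beta * kl_penalty
    raise ValueError(f"unknown stage: {stage}")


print(stage_loss(1, ce_loss=2.0, kl_penalty=0.5))  # cross-entropy only
print(stage_loss(2, ce_loss=2.0, kl_penalty=0.5))  # 2.0 + 0.1 * 0.5
```

The default beta of 0.1 mirrors the KLConfig default above.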

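The KL penalty is approximated as the difference in per-token negative log-likelihood (NLL) between the current model and the reference model on the same batch: kl_penalty = policy_loss - sg(reference_loss), where sg denotes stop-gradient, so no gradient flows through the reference model's loss.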
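The penalty formula can be made concrete with a dependency-free sketch that computes the per-token NLL from logits for both models and takes the difference. The function names are hypothetical. Averaging over tokens is an assumption, since the source does not state the reduction. Plain floats carry no gradients, so sg() is only marked in a comment. In an autodiff framework it would correspond to something like jax.lax.stop_gradient(reference_loss) in JAX or reference_loss.detach() in PyTorch.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one token's vocabulary logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def mean_nll(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-likelihood of the target tokens (assumed reduction: mean)."""
    total = 0.0
    for row, target in zip(logits, targets):
        total -= log_softmax(row)[target]
    return total / len(targets)


def kl_penalty(policy_logits, reference_logits, targets) -> float:
    """kl_penalty = policy_loss - sg(reference_loss), evaluated on the same batch."""
    policy_loss = mean_nll(policy_logits, targets)
    reference_loss = mean_nll(reference_logits, targets)  # sg(): treated as a constant
    return policy_loss - reference_loss


# One token, vocabulary of 2, target id 0.
policy = [[0.0, 0.0]]               # p(target) = 1/2
reference = [[math.log(3.0), 0.0]]  # p(target) = 3/4
print(kl_penalty(policy, reference, [0]))  # log(2) - log(4/3) = log(1.5)
```

Identical policy and reference logits give a penalty of zero. The penalty is positive when the policy assigns the batch lower likelihood than the reference did.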