kl_divergence
KL Divergence Trainer
Two-stage trainer: stage 1 is standard pretraining, stage 2 enforces a customizable KL penalty against the stage-1 reference policy.
Stage switching follows the same token-counting approach used by the ABCD continual-learning trainers.
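The token-counting stage switch can be sketched as a simple budget check (a sketch only; the exact counting convention of the ABCD continual-learning trainers may differ):

```python
def current_stage(tokens_seen: int, total_tokens: list[int]) -> int:
    """Return the active stage index given tokens consumed so far.

    total_tokens = [stage1_tokens, stage2_tokens]; the trainer moves to
    stage 2 once the stage-1 token budget is exhausted.
    """
    return 0 if tokens_seen < total_tokens[0] else 1
```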
KLConfig(beta: float = field('optimization/kl/beta', default=0.1))
dataclass
KL divergence penalty configuration.
KLDivergenceTrainerConfig(
    batch_size: int = field('training/batch_size', default=512),
    per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)),
    total_tokens: List[int] = field('training/tokens', default_factory=(lambda: [1000000000, 100000000])),
    lr: float = field('optimization/lr', default=0.0003),
    warmup_pct: float = field('training/warmup_pct', default=0.01),
    decay_pct: float = field('training/decay_pct', default=0.1),
    validate: bool = field('training/validation', default=True),
    evaluate: bool = field('training/evaluate', default=True),
    datasets: List[List[Sampling]] = field('training/dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))], [Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])),
    evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: [])),
    block_size: int = field('architecture/block_size', default=512),
    report_interval: int = field('logging/report_interval', default=32),
    checkpoint_interval: int = field('logging/checkpoint_interval', default=1024),
    validation_interval: int = field('logging/validation_interval', default=512),
    validation_steps: int = field('training/validation_steps', default=2048),
    wandb: bool = field('logging/wandb', default=False)
)
dataclass
Bases: BaseTrainerConfig
Config for two-stage KL-divergence trainer.
total_tokens is a two-element list: [stage1_tokens, stage2_tokens].
datasets is a two-element list of sampling lists, one per stage.
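A hypothetical instantiation showing the per-stage structure of these two fields (field names follow the signature above; imports of `KLDivergenceTrainerConfig`, `Sampling`, and `DatasetStyle` depend on your package layout):

```python
# Config fragment (not self-contained): one token budget and one
# sampling list per stage, stage 1 first, stage 2 second.
cfg = KLDivergenceTrainerConfig(
    total_tokens=[1_000_000_000, 100_000_000],  # [stage1_tokens, stage2_tokens]
    datasets=[
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],  # stage 1
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],  # stage 2
    ],
)
```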
KLDivergenceTrainState
Bases: TrainState
Train state carrying a frozen reference-policy snapshot and KL weight.
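The reference snapshot is taken at the stage boundary and then held fixed. A minimal sketch of that snapshotting step, assuming parameters are stored in a plain container (the real train state keeps this inside `KLDivergenceTrainState`):

```python
import copy

def take_reference_snapshot(params):
    """Deep-copy the current policy parameters at the stage-1/stage-2
    boundary so later updates to the policy cannot mutate the frozen
    KL reference."""
    return copy.deepcopy(params)
```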
KLDivergenceTrainer(spec: ExecutionSpec)
Bases: BaseTrainer[C, M], Generic[C, M]
Two-stage trainer with KL-divergence penalty.
- Stage 1 – standard language-model pretraining (cross-entropy only).
- Stage 2 – pretraining loss plus beta * KL(policy || reference), where the reference policy is a frozen snapshot taken at the stage boundary.
The KL penalty is approximated as the difference in per-token NLL between the current model and the reference model on the same batch: kl_penalty = policy_loss - sg(reference_loss), where sg denotes stop-gradient.
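The stage-2 objective can be sketched as follows; `stop_gradient` here is a plain-float placeholder for the framework's actual stop-gradient (e.g. `jax.lax.stop_gradient` or `Tensor.detach()`):

```python
def stop_gradient(x):
    """Placeholder for the framework stop-gradient; an identity on floats.

    In training code this prevents gradients from flowing through the
    frozen reference-loss term."""
    return x

def stage2_loss(policy_loss: float, reference_loss: float, beta: float = 0.1) -> float:
    """Stage-2 loss: pretraining loss plus beta-weighted KL penalty,
    with the penalty approximated as the per-token NLL gap between the
    current policy and the frozen reference on the same batch."""
    kl_penalty = policy_loss - stop_gradient(reference_loss)
    return policy_loss + beta * kl_penalty
```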