
abcd

FadeConfig(
    overlap: float = field('training/fade/overlap', default=0.0),
    curve: str = field('training/fade/curve', default='linear'),
    steepness: float = field('training/fade/steepness', default=10.0),
    per_boundary_overlap: List[float] = field('training/fade/per_boundary_overlap', default_factory=list),
)

dataclass

Configuration for gradual dataset fade transitions.

Controls how datasets blend during boundary transitions instead of hard-switching.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `overlap` | `float` | Size of the blending region as a fraction of the smaller adjacent segment. `0` = hard switch (original behavior); `1` = maximum overlap (the fade spans the entire smaller segment). | `field('training/fade/overlap', default=0.0)` |
| `curve` | `str` | Shape of the blending function. One of `"linear"` (constant-rate crossfade), `"cosine"` (smooth S-curve: slow start/end, fast middle), or `"sigmoid"` (steep S-curve controlled by `steepness`). | `field('training/fade/curve', default='linear')` |
| `steepness` | `float` | Slope of the sigmoid curve at the midpoint. Only used when `curve="sigmoid"`. Higher values produce a sharper transition. | `field('training/fade/steepness', default=10.0)` |
| `per_boundary_overlap` | `List[float]` | Per-boundary overlap overrides. When non-empty, must have length `len(datasets) - 1`. Each entry replaces `overlap` for the corresponding boundary. | `field('training/fade/per_boundary_overlap', default_factory=list)` |
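The three curve shapes can be sketched as functions of the normalized position `t` within the fade window. This is an illustrative sketch, not the library's actual implementation; the function name and signature are assumptions.

```python
import numpy as np

def fade_weight(t: np.ndarray, curve: str = "linear", steepness: float = 10.0) -> np.ndarray:
    """Weight of the *incoming* dataset at normalized fade position t in [0, 1]."""
    if curve == "linear":
        return t                                       # constant-rate crossfade
    if curve == "cosine":
        return 0.5 * (1.0 - np.cos(np.pi * t))         # slow start/end, fast middle
    if curve == "sigmoid":
        # Logistic curve centred at t = 0.5; `steepness` sets the midpoint slope.
        return 1.0 / (1.0 + np.exp(-steepness * (t - 0.5)))
    raise ValueError(f"unknown curve: {curve}")
```

All three curves pass through weight 0.5 at the window midpoint; they differ only in how quickly the mix shifts near the edges.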

ABCDBaseTrainer(spec: ExecutionSpec)

Bases: BaseTrainer[C, M], Generic[C, M]

Standard continual learning with optional gradual fade between datasets.

When fade.overlap > 0, adjacent dataset segments overlap so that the model transitions gradually rather than hard-switching at boundaries. All strategies are created and their async batch workers started at init time; during the fade region batches are drawn from both the outgoing and incoming dataloaders and combined proportionally.
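The size of the overlap region follows directly from the `overlap` fraction and the two adjacent segment lengths. A minimal sketch, assuming token-denominated segments (the helper name is hypothetical, not the trainer's internals):

```python
def fade_window(tokens_a: int, tokens_b: int, overlap: float) -> int:
    """Number of tokens over which segments A and B blend at their boundary.

    overlap = 0 -> hard switch (empty window); overlap = 1 -> the fade spans
    the entire smaller of the two adjacent segments.
    """
    return int(overlap * min(tokens_a, tokens_b))
```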

batch(slice: str = 'train') -> PyTree[np.ndarray]

Return the next training or validation batch.

When fade.overlap == 0 this behaves identically to the original hard-switch logic. With overlap > 0 it draws rows from multiple dataloaders in proportion to their fade weights and shuffles them together.
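The proportional draw-and-shuffle step can be sketched as follows. This is an illustrative sketch (the function and its signature are assumptions), showing two equally shaped batches combined according to the incoming dataset's current fade weight:

```python
import numpy as np

def mixed_batch(out_batch: np.ndarray, in_batch: np.ndarray,
                w_in: float, rng: np.random.Generator) -> np.ndarray:
    """Combine rows of two same-shape batches; w_in is the incoming weight."""
    n = out_batch.shape[0]
    n_in = int(round(w_in * n))                 # rows taken from the incoming loader
    rows = np.concatenate([out_batch[: n - n_in], in_batch[:n_in]], axis=0)
    rng.shuffle(rows, axis=0)                   # interleave so every step sees both
    return rows
```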

ABCDKLConfig(
    batch_size: int = field('training/batch_size', default=512),
    per_device_batch_size: int = field('training/per_device_batch_size', default=-1),
    total_tokens: List[int] = field('training/tokens', default_factory=lambda: [1000000000, 100000000, 100000000, 100000000, 100000000]),
    lr: float = field('optimization/lr', default=0.0003),
    warmup_pct: float = field('optimization/warmup_pct', default=0.01),
    decay_pct: float = field('optimization/decay_pct', default=0.01),
    validate: bool = field('training/validation', default=True),
    evaluate: bool = field('training/evaluate', default=True),
    datasets: List[List[Sampling]] = field('training/dataset', default_factory=lambda: [
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],
        [Sampling(name='mnli', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='qqp', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='sst2', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='siqa', rate=1, style=DatasetStyle.PADDED)],
    ]),
    evaluations: List[str] = field('eval/evaluations', default_factory=lambda: ['mnli', 'qqp', 'sst2', 'siqa']),
    block_size: int = field('architecture/block_size', default=512),
    report_interval: int = field('logging/report_interval', default=32),
    checkpoint_interval: int = field('logging/checkpoint_interval', default=1024),
    validation_interval: int = field('logging/validation_interval', default=512),
    validation_steps: int = field('training/validation_steps', default=2048),
    wandb: bool = field('logging/wandb', default=False),
    constant_pct: float = field('optimization/constant_pct', default=0.3),
    betas: List[float] = field('optimization/kl/betas', default_factory=lambda: [0.0, 0.1, 0.1, 0.1, 0.1]),
    reference_update_stages: List[int] = field('optimization/kl/reference_update_stages', default_factory=lambda: [1]),
    skip_first_dataset_validation: bool = field('training/skip_first_dataset_validation', default=False),
    fade: FadeConfig = FadeConfig(),
)

dataclass

Bases: KLDivergenceTrainerConfig

Config for multi-stage ABCD training with per-stage KL penalties.

Extends :class:KLDivergenceTrainerConfig to support an arbitrary number of stages (not just two), per-stage KL penalty weights, and configurable stages at which the reference policy is updated.

Parameters

betas: Per-stage KL penalty weight. Length must equal len(total_tokens). A value of 0 disables the KL penalty for that stage (i.e. pure pretraining).

reference_update_stages: Stage indices at which the reference policy is updated (a snapshot of the current params is taken as the new reference). Typically this is the first stage that uses a KL penalty.
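The stage-alignment rule above can be written down concretely. This is a hedged sketch of a configuration matching the defaults, with the validity checks spelled out as plain assertions (the checks are illustrative, not the library's own validation):

```python
# One entry per stage, mirroring the default config: stage 0 is pure
# pretraining (beta = 0), stages 1-4 apply a KL penalty of 0.1.
total_tokens = [1_000_000_000, 100_000_000, 100_000_000, 100_000_000, 100_000_000]
betas = [0.0, 0.1, 0.1, 0.1, 0.1]
reference_update_stages = [1]   # snapshot the reference policy entering stage 1

# The documented constraints:
assert len(betas) == len(total_tokens), "one KL weight per stage"
assert all(0 <= s < len(total_tokens) for s in reference_update_stages)
```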

ABCDKLDivergenceTrainer(spec: ExecutionSpec)

Bases: KLDivergenceTrainer[CKL, M], Generic[CKL, M]

Multi-stage continual-learning trainer with per-stage KL penalties.

Combines :class:KLDivergenceTrainer's KL-penalised forward pass with :class:ABCDBaseTrainer's fade-based multi-dataset scheduling.

  • Any number of stages (≥ 2).
  • Per-stage beta values control KL penalty strength (0 = off).
  • reference_update_stages specifies at which stage boundaries the reference policy snapshot is refreshed.
  • Optional gradual fade between adjacent datasets via :class:FadeConfig.
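The per-stage penalty amounts to adding a beta-weighted KL term to the usual cross-entropy. A minimal numpy sketch of that loss shape, assuming log-probabilities as inputs (names and signature are illustrative, not the trainer's forward pass):

```python
import numpy as np

def kl_penalised_loss(logp: np.ndarray, logp_ref: np.ndarray,
                      targets: np.ndarray, beta: float) -> float:
    """Cross-entropy plus beta * KL(current || reference); beta = 0 is plain CE."""
    ce = -np.mean(logp[np.arange(len(targets)), targets])
    kl = np.mean(np.sum(np.exp(logp) * (logp - logp_ref), axis=-1))
    return float(ce + beta * kl)
```

With `beta = 0` (the first stage in the default config) the KL term drops out entirely, recovering standard pretraining.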

batch(slice: str = 'train') -> PyTree[np.ndarray]

Return the next batch using fade-weighted multi-dataset sampling.