abcd
¶
FadeConfig(overlap: float = field('training/fade/overlap', default=0.0), curve: str = field('training/fade/curve', default='linear'), steepness: float = field('training/fade/steepness', default=10.0), per_boundary_overlap: List[float] = field('training/fade/per_boundary_overlap', default_factory=list))
dataclass
¶
Configuration for gradual dataset fade transitions.
Controls how datasets blend during boundary transitions instead of hard-switching.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `overlap` | `float` | Size of the blending region as a fraction of the smaller adjacent segment. 0 = hard switch (original behavior), 1 = maximum overlap (fade spans the entire smaller segment). | `field('training/fade/overlap', default=0.0)` |
| `curve` | `str` | Shape of the blending function. One of: `"linear"` (constant-rate crossfade), `"cosine"` (smooth S-curve: slow start/end, fast middle), `"sigmoid"` (steep S-curve controlled by `steepness`). | `field('training/fade/curve', default='linear')` |
| `steepness` | `float` | Slope of the sigmoid curve at the midpoint. Only used when `curve == "sigmoid"`. | `field('training/fade/steepness', default=10.0)` |
| `per_boundary_overlap` | `List[float]` | Per-boundary overlap overrides. When non-empty, must have one entry per dataset boundary. | `field('training/fade/per_boundary_overlap', default_factory=list)` |
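The three curve options can be sketched as weight functions over the fade progress. This is an illustrative stand-in (the function name `fade_weight` and the exact formulas are assumptions, not the library's internals); it shows the qualitative shapes the table describes:

```python
import math

def fade_weight(t: float, curve: str = "linear", steepness: float = 10.0) -> float:
    """Weight of the incoming dataset at fade progress t in [0, 1].

    t = 0 -> batch comes entirely from the outgoing dataset,
    t = 1 -> entirely from the incoming dataset.
    """
    if curve == "linear":
        # Constant-rate crossfade.
        return t
    if curve == "cosine":
        # Smooth S-curve: slow at the ends, fast in the middle.
        return 0.5 * (1.0 - math.cos(math.pi * t))
    if curve == "sigmoid":
        # Steep S-curve; `steepness` sets the slope at the midpoint.
        return 1.0 / (1.0 + math.exp(-steepness * (t - 0.5)))
    raise ValueError(f"unknown curve: {curve!r}")
```

All three curves pass through 0.5 at the midpoint; they differ only in how quickly the blend moves away from the endpoints.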
ABCDBaseTrainer(spec: ExecutionSpec)
¶
Bases: BaseTrainer[C, M], Generic[C, M]
Standard continual learning with optional gradual fade between datasets.
When `fade.overlap > 0`, adjacent dataset segments overlap so that the model transitions gradually rather than hard-switching at boundaries. All strategies are created, and their async batch workers started, at init time; during the fade region, batches are drawn from both the outgoing and incoming dataloaders and combined proportionally.
batch(slice: str = 'train') -> PyTree[np.ndarray]
¶
Return the next training or validation batch.
When `fade.overlap == 0`, this behaves identically to the original hard-switch logic. With `overlap > 0`, it draws rows from multiple dataloaders in proportion to their fade weights and shuffles them together.
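For the two-loader case, the proportional draw-and-shuffle step can be sketched as follows (the helper `mix_batches` is hypothetical and assumes same-shaped NumPy batches; the trainer's actual implementation may differ):

```python
import numpy as np

def mix_batches(outgoing: np.ndarray, incoming: np.ndarray,
                w_in: float, rng: np.random.Generator) -> np.ndarray:
    """Combine rows from two same-shaped batches in proportion to w_in.

    w_in is the fade weight of the incoming dataset, in [0, 1].
    """
    n = outgoing.shape[0]
    n_in = int(round(w_in * n))          # rows taken from the incoming loader
    mixed = np.concatenate([outgoing[: n - n_in], incoming[:n_in]], axis=0)
    rng.shuffle(mixed, axis=0)           # interleave the two sources
    return mixed
```

The shuffle matters: without it, every batch would present all outgoing rows before all incoming rows, which biases any position-sensitive processing downstream.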
ABCDKLConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: List[int] = field('training/tokens', default_factory=(lambda: [1000000000, 100000000, 100000000, 100000000, 100000000])), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('optimization/warmup_pct', default=0.01), decay_pct: float = field('optimization/decay_pct', default=0.01), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[List[Sampling]] = field('training/dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))], [Sampling(name='mnli', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='qqp', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='sst2', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='siqa', rate=1, style=(DatasetStyle.PADDED))]])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: ['mnli', 'qqp', 'sst2', 'siqa'])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False), constant_pct: float = field('optimization/constant_pct', default=0.3), betas: List[float] = field('optimization/kl/betas', default_factory=(lambda: [0.0, 0.1, 0.1, 0.1, 0.1])), reference_update_stages: List[int] = field('optimization/kl/reference_update_stages', default_factory=(lambda: [1])), skip_first_dataset_validation: bool = field('training/skip_first_dataset_validation', default=False), fade: FadeConfig = FadeConfig())
dataclass
¶
Bases: KLDivergenceTrainerConfig
Config for multi-stage ABCD training with per-stage KL penalties.
Extends `KLDivergenceTrainerConfig` to support an arbitrary number of stages (not just two), per-stage KL penalty weights, and configurable stages at which the reference policy is updated.
Parameters¶
`betas`:
Per-stage KL penalty weight. Length must equal `len(total_tokens)`. A value of 0 disables the KL penalty for that stage (i.e. pure pretraining).
`reference_update_stages`:
Stage indices at which the reference policy is updated (a snapshot of the current params is taken as the new reference). Typically the first stage that uses a KL penalty.
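The length invariant between the stage lists can be made concrete with a small sketch. `StageSchedule` below is an illustrative stand-in for the relevant fields of `ABCDKLConfig`, not the real class, and the `validate` method is an assumption about what a consistency check would enforce:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class StageSchedule:
    """Illustrative stand-in for the per-stage KL settings in ABCDKLConfig."""
    total_tokens: List[int] = field(
        default_factory=lambda: [10**9] + [10**8] * 4)
    betas: List[float] = field(
        default_factory=lambda: [0.0, 0.1, 0.1, 0.1, 0.1])
    reference_update_stages: List[int] = field(default_factory=lambda: [1])

    def validate(self) -> None:
        # One KL weight per stage; beta == 0 means pure pretraining.
        assert len(self.betas) == len(self.total_tokens)
        # Reference snapshots can only happen at existing stage indices.
        assert all(0 <= s < len(self.total_tokens)
                   for s in self.reference_update_stages)
```

With the defaults, stage 0 is pure pretraining (`beta == 0.0`), and the reference policy is snapshotted at the start of stage 1, the first stage with a non-zero KL weight.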
ABCDKLDivergenceTrainer(spec: ExecutionSpec)
¶
Bases: KLDivergenceTrainer[CKL, M], Generic[CKL, M]
Multi-stage continual-learning trainer with per-stage KL penalties.
Combines `KLDivergenceTrainer`'s KL-penalised forward pass with `ABCDBaseTrainer`'s fade-based multi-dataset scheduling.
- Any number of stages (≥ 2).
- Per-stage `beta` values control KL penalty strength (0 = off).
- `reference_update_stages` specifies at which stage boundaries the reference policy snapshot is refreshed.
- Optional gradual fade between adjacent datasets via `FadeConfig`.
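One plausible form of the per-stage objective is a cross-entropy term plus a beta-weighted KL penalty against the frozen reference. The exact KL direction and reduction used by the trainer are assumptions here; this sketch only illustrates how `beta = 0` turns the penalty off:

```python
import numpy as np

def stage_loss(ce_loss: float, logp_current: np.ndarray,
               logp_reference: np.ndarray, beta: float) -> float:
    """Cross-entropy plus beta * KL(current || reference) over one distribution.

    logp_* are log-probabilities over the same support. With beta == 0 this
    reduces to plain cross-entropy (pure pretraining for that stage).
    """
    p = np.exp(logp_current)
    kl = float(np.sum(p * (logp_current - logp_reference)))
    return ce_loss + beta * kl
```

When the current and reference distributions coincide, the KL term vanishes and the loss equals the cross-entropy alone, which is exactly the behaviour of the `beta == 0` pretraining stages.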
batch(slice: str = 'train') -> PyTree[np.ndarray]
¶
Return the next batch using fade-weighted multi-dataset sampling.