benchmark

Benchmark continual learning trainers.

Provides trainers for the benchmark paper with configurable:

- Architecture (Transformer/Mamba/Hybrid/MoE via separate job registrations)
- Schedule (WSD, cosine rewarm, WSD+reset)
- Optimization (Full or LoRA via separate job registrations)

Non-LoRA jobs: continual/train/benchmark{,_mamba,_hybrid,_moe}

LoRA jobs: continual/train/benchmark{,_mamba,_hybrid,_moe}_lora

BenchmarkConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: List[int] = field('training/tokens', default_factory=(lambda: [1000000000, 100000000, 100000000, 100000000, 100000000])), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('optimization/warmup_pct', default=0.01), decay_pct: float = field('optimization/decay_pct', default=0.01), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[List[Sampling]] = field('training/dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))], [Sampling(name='mnli', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='qqp', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='sst2', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='siqa', rate=1, style=(DatasetStyle.PADDED))]])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: ['mnli', 'qqp', 'sst2', 'siqa'])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False), constant_pct: float = field('optimization/constant_pct', default=0.3), skip_first_dataset_validation: bool = field('training/skip_first_dataset_validation', default=False), fade: FadeConfig = FadeConfig(), schedule_type: str = field('optimization/schedule', default='wsd'), reset_optimizer_at_boundaries: bool = field('optimization/reset_optimizer', default=False)) dataclass

Bases: ABCDConfig

Config for non-LoRA benchmark runs.
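For a rough sense of scale, the default token budgets can be converted into optimizer steps. The sketch below assumes that each step consumes `batch_size` sequences of `block_size` tokens, which this page does not state explicitly; `steps_per_stage` is a hypothetical helper for illustration, not part of the library.

```python
from typing import List

# Default per-stage token budgets from BenchmarkConfig.total_tokens.
DEFAULT_TOKENS = [1_000_000_000, 100_000_000, 100_000_000, 100_000_000, 100_000_000]


def steps_per_stage(
    total_tokens: List[int], batch_size: int = 512, block_size: int = 512
) -> List[int]:
    """Approximate optimizer steps per stage (hypothetical helper).

    Assumes tokens per step = batch_size * block_size and drops any
    partial final step.
    """
    tokens_per_step = batch_size * block_size  # 512 * 512 = 262144 by default
    return [tokens // tokens_per_step for tokens in total_tokens]


print(steps_per_stage(DEFAULT_TOKENS))  # [3814, 381, 381, 381, 381]
```

Under that assumption, the default `validation_interval` of 512 and `checkpoint_interval` of 1024 are both longer than a 381-step stage, which may be worth checking when changing the budgets.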

BenchmarkLoRAConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: List[int] = field('training/tokens', default_factory=(lambda: [1000000000, 100000000, 100000000, 100000000, 100000000])), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('optimization/warmup_pct', default=0.01), decay_pct: float = field('optimization/decay_pct', default=0.01), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[List[Sampling]] = field('training/dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))], [Sampling(name='mnli', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='qqp', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='sst2', rate=1, style=(DatasetStyle.PADDED))], [Sampling(name='siqa', rate=1, style=(DatasetStyle.PADDED))]])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: ['mnli', 'qqp', 'sst2', 'siqa'])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False), pre_lora_tokens: List[int] = field('training/pre_lora_tokens', default_factory=(lambda: [1000000000])), pre_lora_datasets: List[List[Sampling]] = field('training/pre_lora_dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])), post_lora_tokens: List[int] = field('training/post_lora_tokens', default_factory=(lambda: [100000000])), post_lora_datasets: List[List[Sampling]] = field('training/post_lora_dataset', 
default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])), constant_pct: float = field('optimization/constant_pct', default=0.3), skip_first_dataset_validation: bool = field('training/skip_first_dataset_validation', default=False), fade: FadeConfig = FadeConfig(), schedule_type: str = field('optimization/schedule', default='wsd'), reset_optimizer_at_boundaries: bool = field('optimization/reset_optimizer', default=False)) dataclass

Bases: BenchmarkConfig, LoRATrainerConfig

Config for LoRA benchmark runs.

Inherits BenchmarkConfig (ABCD multi-stage + schedule/reset) and LoRATrainerConfig (pre/post LoRA token budgets + datasets).

The pre-LoRA phase uses ABCDConfig's total_tokens/datasets for multi-stage full-param training. The post-LoRA phase uses post_lora_tokens/post_lora_datasets from LoRATrainerConfig for adapter-only training.

LoRA hyperparameters (rank, alpha, target_modules) live under optimization/lora/ and are read via configure(LoRAConfig).

BenchmarkBaseTrainer(spec: ExecutionSpec)

Bases: ABCDBaseTrainer[BC, M], Generic[BC, M]

Benchmark trainer with schedule-variant boundary handling.

Extends ABCDBaseTrainer with:

- Cosine rewarm: rebuilds schedule at each boundary
- WSD+Reset: resets optimizer state at each boundary
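The difference between the schedule variants can be illustrated with two small learning-rate functions. This is a minimal sketch under stated assumptions, not the trainer's implementation: the exact curve shapes, the role of `constant_pct`, and the function names `wsd_lr` and `cosine_rewarm_lr` are all hypothetical. It assumes WSD means linear warmup, a constant plateau, and a linear decay to zero, and that cosine rewarm restarts warmup and cosine decay at every stage boundary.

```python
import math
from typing import List


def wsd_lr(step: int, total_steps: int, peak_lr: float = 3e-4,
           warmup_pct: float = 0.01, decay_pct: float = 0.01) -> float:
    """Warmup-stable-decay over the whole run (assumed shape)."""
    warmup_steps = max(1, int(total_steps * warmup_pct))
    decay_steps = max(1, int(total_steps * decay_pct))
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps  # linear warmup
    if step < decay_start:
        return peak_lr  # constant plateau
    return peak_lr * max(0.0, (total_steps - step) / decay_steps)  # linear decay


def cosine_rewarm_lr(step: int, stage_steps: List[int], peak_lr: float = 3e-4,
                     warmup_pct: float = 0.01) -> float:
    """Cosine schedule rebuilt for each stage (assumed shape)."""
    start = 0
    for n in stage_steps:
        if step < start + n:
            local = step - start  # step index within the current stage
            warmup = max(1, int(n * warmup_pct))
            if local < warmup:
                return peak_lr * (local + 1) / warmup  # rewarm after a boundary
            progress = (local - warmup) / max(1, n - warmup)
            return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))
        start += n
    return 0.0  # past the end of the last stage


stages = [1000, 100]
print(wsd_lr(500, sum(stages)))        # on the plateau
print(cosine_rewarm_lr(999, stages))   # end of stage 1, near zero
print(cosine_rewarm_lr(1000, stages))  # start of stage 2, warmed back up
```

The WSD+Reset variant would keep a WSD-style learning rate but clear optimizer state (for example, moment estimates) at each boundary; that step depends on the optimizer in use and is not shown here.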

BenchmarkLoRABaseTrainer(spec: ExecutionSpec)

Bases: BenchmarkBaseTrainer[BLC, M], Generic[BLC, M]

Benchmark trainer with LoRA support.

Three-level data loading:

1. Pre-LoRA ABCD stages (full params): uses ABCDConfig datasets/total_tokens
2. LoRA injection at boundary
3. Post-LoRA ABCD stages (LoRA params only): uses post_lora_datasets/post_lora_tokens
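The three levels can be pictured as a flat training plan. The sketch below is illustrative only: `Phase` and `build_plan` are hypothetical names, and the real trainer may organize its data loaders differently. It uses the default datasets and token budgets listed in `BenchmarkLoRAConfig`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Phase:
    name: str            # "pre_lora", "inject_lora", or "post_lora"
    datasets: List[str]  # dataset names sampled during this phase
    tokens: int          # token budget for this phase
    trainable: str       # "full" or "lora"


def build_plan(datasets: List[List[str]], total_tokens: List[int],
               post_lora_datasets: List[List[str]],
               post_lora_tokens: List[int]) -> List[Phase]:
    """Lay out the pre-LoRA stages, the injection point, and the post-LoRA stages."""
    plan = [Phase("pre_lora", names, tokens, "full")
            for names, tokens in zip(datasets, total_tokens)]
    # Injection consumes no tokens; it marks where adapters are added.
    plan.append(Phase("inject_lora", [], 0, "lora"))
    plan.extend(Phase("post_lora", names, tokens, "lora")
                for names, tokens in zip(post_lora_datasets, post_lora_tokens))
    return plan


plan = build_plan(
    datasets=[["fineweb"], ["mnli"], ["qqp"], ["sst2"], ["siqa"]],
    total_tokens=[1_000_000_000] + [100_000_000] * 4,
    post_lora_datasets=[["fineweb"]],
    post_lora_tokens=[100_000_000],
)
for phase in plan:
    print(phase.name, phase.datasets, phase.tokens, phase.trainable)
```

With the defaults this yields five full-parameter stages, one injection marker, and one adapter-only stage.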