benchmark
Benchmark continual learning trainers.
Provides trainers for the benchmark paper with configurable:
- Architecture (Transformer/Mamba/Hybrid/MoE via separate job registrations)
- Schedule (WSD, cosine rewarm, WSD+reset)
- Optimization (full or LoRA via separate job registrations)
Non-LoRA jobs: continual/train/benchmark{,_mamba,_hybrid,_moe}
LoRA jobs: continual/train/benchmark{,_mamba,_hybrid,_moe}_lora
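The brace patterns follow shell-style brace expansion, so each pattern names four registrations. A small sketch that enumerates all eight, assuming the empty suffix is the Transformer variant; the helper benchmark_job_names is hypothetical and not part of this module:

```python
from itertools import product

# Hypothetical helper that expands the brace patterns listed above.
PREFIX = "continual/train/benchmark"
# "" is presumably the Transformer variant; the others name Mamba, Hybrid, MoE.
ARCH_SUFFIXES = ["", "_mamba", "_hybrid", "_moe"]
# "" = full-parameter training, "_lora" = LoRA training.
OPT_SUFFIXES = ["", "_lora"]


def benchmark_job_names() -> list[str]:
    """Return the eight job names implied by the two brace patterns."""
    return [PREFIX + arch + opt for opt, arch in product(OPT_SUFFIXES, ARCH_SUFFIXES)]


if __name__ == "__main__":
    for name in benchmark_job_names():
        print(name)
```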
BenchmarkConfig(
    batch_size: int = field('training/batch_size', default=512),
    per_device_batch_size: int = field('training/per_device_batch_size', default=-1),
    total_tokens: List[int] = field('training/tokens', default_factory=lambda: [1000000000, 100000000, 100000000, 100000000, 100000000]),
    lr: float = field('optimization/lr', default=0.0003),
    warmup_pct: float = field('optimization/warmup_pct', default=0.01),
    decay_pct: float = field('optimization/decay_pct', default=0.01),
    validate: bool = field('training/validation', default=True),
    evaluate: bool = field('training/evaluate', default=True),
    datasets: List[List[Sampling]] = field('training/dataset', default_factory=lambda: [
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],
        [Sampling(name='mnli', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='qqp', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='sst2', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='siqa', rate=1, style=DatasetStyle.PADDED)],
    ]),
    evaluations: List[str] = field('eval/evaluations', default_factory=lambda: ['mnli', 'qqp', 'sst2', 'siqa']),
    block_size: int = field('architecture/block_size', default=512),
    report_interval: int = field('logging/report_interval', default=32),
    checkpoint_interval: int = field('logging/checkpoint_interval', default=1024),
    validation_interval: int = field('logging/validation_interval', default=512),
    validation_steps: int = field('training/validation_steps', default=2048),
    wandb: bool = field('logging/wandb', default=False),
    constant_pct: float = field('optimization/constant_pct', default=0.3),
    skip_first_dataset_validation: bool = field('training/skip_first_dataset_validation', default=False),
    fade: FadeConfig = FadeConfig(),
    schedule_type: str = field('optimization/schedule', default='wsd'),
    reset_optimizer_at_boundaries: bool = field('optimization/reset_optimizer', default=False),
)
dataclass
Bases: ABCDConfig
Config for non-LoRA benchmark runs.
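As a rough guide to how the default token budgets translate into optimizer steps, the sketch below assumes each step consumes batch_size * block_size tokens and rounds each stage up to a whole step. Both the formula and the rounding are assumptions, and the helpers steps_per_stage and stage_boundaries are hypothetical:

```python
from typing import List

# Defaults copied from the BenchmarkConfig signature above.
BATCH_SIZE = 512
BLOCK_SIZE = 512
TOTAL_TOKENS = [1_000_000_000, 100_000_000, 100_000_000, 100_000_000, 100_000_000]


def steps_per_stage(total_tokens: List[int], batch_size: int, block_size: int) -> List[int]:
    """Assumed: one optimizer step consumes batch_size * block_size tokens."""
    tokens_per_step = batch_size * block_size
    # Integer ceiling division so a partial final step still counts.
    return [-(-tokens // tokens_per_step) for tokens in total_tokens]


def stage_boundaries(steps: List[int]) -> List[int]:
    """Global step at which each stage starts."""
    starts, step = [], 0
    for n in steps:
        starts.append(step)
        step += n
    return starts


if __name__ == "__main__":
    steps = steps_per_stage(TOTAL_TOKENS, BATCH_SIZE, BLOCK_SIZE)
    print(steps, stage_boundaries(steps))
```

Under these assumptions the defaults give 262,144 tokens per step: 3,815 steps for the 1B-token fineweb stage and 382 steps for each 100M-token task stage.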
BenchmarkLoRAConfig(
    batch_size: int = field('training/batch_size', default=512),
    per_device_batch_size: int = field('training/per_device_batch_size', default=-1),
    total_tokens: List[int] = field('training/tokens', default_factory=lambda: [1000000000, 100000000, 100000000, 100000000, 100000000]),
    lr: float = field('optimization/lr', default=0.0003),
    warmup_pct: float = field('optimization/warmup_pct', default=0.01),
    decay_pct: float = field('optimization/decay_pct', default=0.01),
    validate: bool = field('training/validation', default=True),
    evaluate: bool = field('training/evaluate', default=True),
    datasets: List[List[Sampling]] = field('training/dataset', default_factory=lambda: [
        [Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)],
        [Sampling(name='mnli', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='qqp', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='sst2', rate=1, style=DatasetStyle.PADDED)],
        [Sampling(name='siqa', rate=1, style=DatasetStyle.PADDED)],
    ]),
    evaluations: List[str] = field('eval/evaluations', default_factory=lambda: ['mnli', 'qqp', 'sst2', 'siqa']),
    block_size: int = field('architecture/block_size', default=512),
    report_interval: int = field('logging/report_interval', default=32),
    checkpoint_interval: int = field('logging/checkpoint_interval', default=1024),
    validation_interval: int = field('logging/validation_interval', default=512),
    validation_steps: int = field('training/validation_steps', default=2048),
    wandb: bool = field('logging/wandb', default=False),
    pre_lora_tokens: List[int] = field('training/pre_lora_tokens', default_factory=lambda: [1000000000]),
    pre_lora_datasets: List[List[Sampling]] = field('training/pre_lora_dataset', default_factory=lambda: [[Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)]]),
    post_lora_tokens: List[int] = field('training/post_lora_tokens', default_factory=lambda: [100000000]),
    post_lora_datasets: List[List[Sampling]] = field('training/post_lora_dataset', default_factory=lambda: [[Sampling(name='fineweb', rate=1, style=DatasetStyle.PMD)]]),
    constant_pct: float = field('optimization/constant_pct', default=0.3),
    skip_first_dataset_validation: bool = field('training/skip_first_dataset_validation', default=False),
    fade: FadeConfig = FadeConfig(),
    schedule_type: str = field('optimization/schedule', default='wsd'),
    reset_optimizer_at_boundaries: bool = field('optimization/reset_optimizer', default=False),
)
dataclass
Bases: BenchmarkConfig, LoRATrainerConfig
Config for LoRA benchmark runs.
Inherits from BenchmarkConfig (ABCD multi-stage training plus schedule/reset options) and LoRATrainerConfig (pre- and post-LoRA token budgets and datasets).
The pre-LoRA phase uses ABCDConfig's total_tokens/datasets for multi-stage full-parameter training. The post-LoRA phase uses post_lora_tokens/post_lora_datasets from LoRATrainerConfig for adapter-only training.
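The split between the two phases can be pictured as a flat stage plan. This is a minimal sketch assuming each Sampling mixture is reduced to a list of dataset names; the Stage class and plan_stages function are hypothetical illustrations, not the trainer's API:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stage:
    phase: str           # "pre_lora" (full parameters) or "post_lora" (adapters only)
    datasets: List[str]  # dataset names; a stand-in for a List[Sampling] mixture
    tokens: int


def plan_stages(
    total_tokens: List[int],
    datasets: List[List[str]],
    post_lora_tokens: List[int],
    post_lora_datasets: List[List[str]],
) -> List[Stage]:
    """Hypothetical flat plan: ABCD stages first, then post-LoRA stages."""
    if len(total_tokens) != len(datasets) or len(post_lora_tokens) != len(post_lora_datasets):
        raise ValueError("each token budget needs exactly one dataset mixture")
    pre = [Stage("pre_lora", d, t) for t, d in zip(total_tokens, datasets)]
    post = [Stage("post_lora", d, t) for t, d in zip(post_lora_tokens, post_lora_datasets)]
    return pre + post


if __name__ == "__main__":
    plan = plan_stages(
        total_tokens=[1_000_000_000] + [100_000_000] * 4,
        datasets=[["fineweb"], ["mnli"], ["qqp"], ["sst2"], ["siqa"]],
        post_lora_tokens=[100_000_000],
        post_lora_datasets=[["fineweb"]],
    )
    for stage in plan:
        print(stage)
```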
LoRA hyperparameters (rank, alpha, target_modules) live under optimization/lora/ and are read via configure(LoRAConfig).
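If the adapters follow the standard LoRA formulation (an assumption; this section only names the hyperparameters), rank and alpha combine into a scale alpha / rank applied to the low-rank product B @ A, and target_modules selects which layers receive adapters. A pure-Python sketch, where the suffix-matching rule in is_target is a hypothetical choice:

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_delta(B: Matrix, A: Matrix, alpha: float, rank: int) -> Matrix:
    """Standard LoRA weight update (alpha / rank) * B @ A.

    B is d_out x rank and A is rank x d_in, so the update matches the frozen
    weight's shape while having rank at most `rank`.
    """
    scale = alpha / rank
    return [[scale * v for v in row] for row in matmul(B, A)]


def is_target(module_name: str, target_modules: List[str]) -> bool:
    """Hypothetical rule: a module is adapted if its name ends with a target entry."""
    return any(module_name == t or module_name.endswith("." + t) for t in target_modules)


if __name__ == "__main__":
    print(lora_delta([[1.0], [2.0]], [[3.0, 4.0]], alpha=2.0, rank=1))
    print(is_target("layers.0.attn.q_proj", ["q_proj", "v_proj"]))
```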
BenchmarkBaseTrainer(spec: ExecutionSpec)
Bases: ABCDBaseTrainer[BC, M], Generic[BC, M]
Benchmark trainer with schedule-variant boundary handling.
Extends ABCDBaseTrainer with:
- Cosine rewarm: rebuilds the schedule at each stage boundary
- WSD+Reset: resets optimizer state at each stage boundary
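Cosine rewarm can be read as restarting a warmup-plus-cosine schedule at every stage boundary, so the learning rate climbs back toward its peak when a new dataset starts. The sketch below assumes linear warmup over warmup_pct of each stage followed by cosine decay to zero; the trainer's exact shape (including how constant_pct enters) is not stated here, and cosine_rewarm_lr is a hypothetical helper:

```python
import math
from typing import List


def cosine_rewarm_lr(step: int, stage_steps: List[int], peak_lr: float, warmup_pct: float) -> float:
    """Learning rate at a global step when the schedule is rebuilt per stage.

    Assumed shape: linear warmup over warmup_pct of each stage, then cosine
    decay to zero over the rest of that stage.
    """
    start = 0
    for n in stage_steps:
        if step < start + n:
            local = step - start  # step within the current stage
            warmup = max(1, int(warmup_pct * n))
            if local < warmup:
                return peak_lr * (local + 1) / warmup
            progress = (local - warmup) / max(1, n - warmup)
            return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
        start += n  # move on to the next stage
    return 0.0  # past the final stage


if __name__ == "__main__":
    stages = [100, 100]
    for s in (0, 9, 99, 100, 109):
        print(s, cosine_rewarm_lr(s, stages, peak_lr=3e-4, warmup_pct=0.1))
```

WSD+Reset instead keeps the WSD learning-rate schedule and discards optimizer state (for Adam-style optimizers, the moment estimates) at the same boundaries, as controlled by reset_optimizer_at_boundaries.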
BenchmarkLoRABaseTrainer(spec: ExecutionSpec)
Bases: BenchmarkBaseTrainer[BLC, M], Generic[BLC, M]
Benchmark trainer with LoRA support.
Three-level data loading:
1. Pre-LoRA ABCD stages (full parameters): uses ABCDConfig datasets/tokens.
2. LoRA injection at the boundary.
3. Post-LoRA ABCD stages (LoRA parameters only): uses post_lora_datasets/post_lora_tokens.
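The three levels amount to a switch in which parameters are trainable. A minimal sketch that tracks trainable parameter names only; the function train_three_level and the inject_lora callback are hypothetical, and a real trainer would also run the optimizer loop for each stage:

```python
from typing import Callable, Dict, List, Tuple


def train_three_level(
    pre_stages: List[str],
    post_stages: List[str],
    base_params: List[str],
    inject_lora: Callable[[List[str]], List[str]],
) -> List[Tuple[str, List[str]]]:
    """Return (stage, trainable parameter names) for every stage, in order."""
    trainable: Dict[str, bool] = {name: True for name in base_params}
    log: List[Tuple[str, List[str]]] = []

    def active() -> List[str]:
        return sorted(name for name, on in trainable.items() if on)

    # 1. Pre-LoRA ABCD stages: all parameters are trainable.
    for stage in pre_stages:
        log.append((stage, active()))

    # 2. LoRA injection at the boundary: freeze base weights, add adapters.
    for name in base_params:
        trainable[name] = False
    for adapter in inject_lora(base_params):
        trainable[adapter] = True

    # 3. Post-LoRA ABCD stages: only adapter parameters are trainable.
    for stage in post_stages:
        log.append((stage, active()))
    return log


def add_q_adapters(names: List[str]) -> List[str]:
    """Hypothetical injector: adds A/B adapter matrices to every q_proj."""
    return [f"{n}.lora_{x}" for n in names if n.endswith("q_proj") for x in "AB"]


if __name__ == "__main__":
    for entry in train_three_level(["fineweb", "mnli"], ["fineweb"], ["attn.q_proj", "mlp.fc"], add_q_adapters):
        print(entry)
```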