cosine_rewarm

Cosine rewarming schedule (Gupta et al., 2023).

Within each segment the learning rate cosine-decays so that it reaches min_lr at the next dataset boundary. From that boundary it warms back up linearly to rewarm_lr over warmup_steps steps, then cosine-decays to min_lr again by the following boundary.

The first segment uses max_lr as its peak (the initial warmup target). Subsequent segments use rewarm_lr, which is typically lower than max_lr.

Boundary positions are derived from the per-stage token budgets in stage_tokens together with batch_size and block_size.
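As a rough illustration of how boundary positions could follow from the token budgets, the sketch below converts each stage's tokens into optimizer steps and accumulates them. The helper name stage_boundaries is hypothetical, and the use of floor division is an assumption; the rounding the library actually applies is not stated here.

```python
from itertools import accumulate


def stage_boundaries(stage_tokens, batch_size, block_size):
    """Hypothetical helper: cumulative step index at which each stage ends.

    Assumes one optimizer step consumes batch_size * block_size tokens and
    that partial steps are dropped (floor division). Both are assumptions.
    """
    tokens_per_step = batch_size * block_size
    steps_per_stage = [tokens // tokens_per_step for tokens in stage_tokens]
    # Running sum turns per-stage lengths into absolute boundary positions.
    return list(accumulate(steps_per_stage))


# Two stages of 1,000,000 and 500,000 tokens at 8 * 125 = 1000 tokens per step.
boundaries = stage_boundaries([1_000_000, 500_000], batch_size=8, block_size=125)
print(boundaries)
```

With these numbers the first stage spans 1000 steps and the second 500, so the last boundary would coincide with a total_steps of 1500.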

cosine_rewarm(total_steps: int, cfg: CosineRewarmConfig) -> optax._src.base.Schedule

Build a piecewise cosine-rewarm schedule across all stages.

Stage 0: warmup(min_lr → max_lr) + cosine(max_lr → min_lr)
Stage i>0: warmup(min_lr → rewarm_lr) + cosine(rewarm_lr → min_lr)

Boundaries are computed from stage_tokens / (batch_size * block_size).
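To make the piecewise shape concrete, here is a dependency-free sketch of the learning rate as a function of the step, following the stage description above. It is not the library's implementation: the function name cosine_rewarm_lr and its flat argument list are hypothetical (the real entry point takes a CosineRewarmConfig), and it assumes every stage is longer than warmup_steps and that steps past the last boundary stay at min_lr.

```python
import math


def cosine_rewarm_lr(step, boundaries, max_lr, rewarm_lr, min_lr, warmup_steps):
    """Hypothetical pure-Python sketch of the schedule value at `step`.

    `boundaries` holds the cumulative step index at which each stage ends.
    """
    start = 0
    last = len(boundaries) - 1
    for i, end in enumerate(boundaries):
        if step < end or i == last:
            # Stage 0 peaks at max_lr; later stages peak at rewarm_lr.
            peak = max_lr if i == 0 else rewarm_lr
            local = step - start
            if local < warmup_steps:
                # Linear warmup from min_lr up to the stage peak.
                return min_lr + (peak - min_lr) * local / warmup_steps
            # Cosine decay over the remainder of the stage (assumed non-empty).
            decay_len = max(end - start - warmup_steps, 1)
            progress = min((local - warmup_steps) / decay_len, 1.0)
            return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * progress))
        start = end


# Illustrative values only: two stages ending at steps 1000 and 1500.
args = dict(boundaries=[1000, 1500], max_lr=1e-3, rewarm_lr=5e-4,
            min_lr=1e-5, warmup_steps=100)
for s in (0, 100, 999, 1000, 1100, 1500):
    print(s, cosine_rewarm_lr(s, **args))
```

In this sketch the rate sits at min_lr at steps 0 and 1000 (the start of each warmup), reaches max_lr at step 100 and rewarm_lr at step 1100, and is back at min_lr by step 1500.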