lora
¶
LoRA (Low-Rank Adaptation) Trainer.
Two-phase trainer. Phase 1 trains all parameters on the pre-LoRA datasets. The trainer then freezes the base weights and injects low-rank adapters, and phase 2 trains only the LoRA parameters on the post-LoRA datasets.
Because this extends RestoreableJob, users can restore from a checkpoint and set pre_lora_datasets and pre_lora_tokens to empty lists to skip straight to LoRA fine-tuning from a pretrained checkpoint.
LoRAConfig(rank: int = field('optimization/lora/rank', default=16), alpha: float = field('optimization/lora/alpha', default=16.0), target_modules: List[str] = field('optimization/lora/target_modules', default_factory=(lambda: ['kernel'])))
dataclass
¶
Low-rank adaptation parameters.
LoRATrainerConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: int = field('training/tokens', default=1000000000), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('training/warmup_pct', default=0.01), decay_pct: float = field('training/decay_pct', default=0.1), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[Sampling] = field('training/dataset', default_factory=(lambda: [Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: [])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False), pre_lora_tokens: List[int] = field('training/pre_lora_tokens', default_factory=(lambda: [1000000000])), pre_lora_datasets: List[List[Sampling]] = field('training/pre_lora_dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])), post_lora_tokens: List[int] = field('training/post_lora_tokens', default_factory=(lambda: [100000000])), post_lora_datasets: List[List[Sampling]] = field('training/post_lora_dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])))
dataclass
¶
Bases: BaseTrainerConfig
Config for two-phase LoRA trainer.
Pre-LoRA phase trains full parameters, then LoRA adapters are injected and only they are trained in the post-LoRA phase.
LoRA hyperparameters (rank, alpha, target_modules) are read from the config namespace via configure(LoRAConfig) at init time; they live under optimization/lora/ in the YAML.
LoRATrainState
¶
Bases: TrainState
Train state for LoRA phase.
params holds the trainable LoRA parameters as
{"lora_A": pytree, "lora_B": pytree}. The optimizer in
tx acts on params, so only the adapters are updated.
base_params holds the frozen pretrained weights (bfloat16).
LoRATrainer(spec: ExecutionSpec)
¶
Bases: BaseTrainer[C, M], Generic[C, M]
Two-phase LoRA trainer.
Phase 1 (pre-LoRA): Full-parameter training on pre_lora_datasets.
Phase 2 (post-LoRA): Freeze base, inject LoRA adapters, train only
adapters on post_lora_datasets.
The transition happens automatically at the token boundary between pre-LoRA and post-LoRA phases.
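To make the token boundary concrete, here is a minimal sketch of how a boundary step could be derived from the config. It assumes that every optimizer step consumes batch_size * block_size tokens and that the boundary is the sum of pre_lora_tokens. Neither assumption is confirmed by the text above, and pre_lora_boundary_step is a hypothetical helper, not part of the library.

```python
def pre_lora_boundary_step(pre_lora_tokens, batch_size, block_size):
    """Hypothetical helper: step at which the pre-LoRA phase would end.

    Assumes each step consumes batch_size * block_size tokens and that the
    pre-LoRA token budgets are simply summed.
    """
    tokens_per_step = batch_size * block_size
    return sum(pre_lora_tokens) // tokens_per_step


# Defaults from LoRATrainerConfig: 1e9 pre-LoRA tokens, batch 512, block 512.
default_boundary = pre_lora_boundary_step([1_000_000_000], 512, 512)

# An empty pre-LoRA schedule puts the boundary at step 0, i.e. the
# trainer would go straight to the LoRA phase.
skip_boundary = pre_lora_boundary_step([], 512, 512)
```

Under these assumptions the default configuration switches phases after 3814 steps, and an empty pre_lora_tokens list gives a boundary of 0.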
restore_from_path(rel_path: str | Path) -> None
¶
Restore from checkpoint, handling LoRA state type mismatch.
If the checkpoint was saved during the LoRA phase, the state is
a LoRATrainState (params = {"lora_A", "lora_B"}, plus
base_params). But _init_state creates a plain TrainState
so the template tree won't match.
We peek at the checkpoint metadata to read the step count. If
it's past the pre-LoRA boundary we transition first (creating
a LoRATrainState template), then let the parent restore
overwrite it with the checkpoint contents.
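The ordering described above (peek at the step, transition if needed, then restore) can be sketched with a stand-in class. SketchTrainer and all of its methods are hypothetical, and the `>=` comparison against the boundary is an assumption; the text only says "past the pre-LoRA boundary".

```python
class SketchTrainer:
    """Hypothetical stand-in that only records the order of operations."""

    def __init__(self, boundary_step):
        self.boundary_step = boundary_step
        self.in_lora_phase = False
        self.calls = []

    def transition_to_lora(self):
        # In the real trainer this would build a LoRATrainState template.
        self.in_lora_phase = True
        self.calls.append("transition")

    def parent_restore(self, checkpoint_step):
        # In the real trainer this would overwrite the template state
        # with the checkpoint contents.
        self.calls.append(f"restore@{checkpoint_step}")

    def restore(self, checkpoint_step):
        # checkpoint_step stands in for the value read from metadata.
        # Assumption: a checkpoint at or after the boundary is LoRA-phase.
        if checkpoint_step >= self.boundary_step and not self.in_lora_phase:
            self.transition_to_lora()
        self.parent_restore(checkpoint_step)


early = SketchTrainer(boundary_step=100)
early.restore(checkpoint_step=40)   # still in the full-parameter phase

late = SketchTrainer(boundary_step=100)
late.restore(checkpoint_step=250)   # template must be a LoRA state first
```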
param_filter(params: PyTree[jax.Array], target_modules: List[str]) -> PyTree[bool]
¶
Create a boolean mask over the param tree.
Returns True for leaves whose path contains any of target_modules.
This is the filter that determines which parameters get LoRA adapters.
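One way such a mask could be built is with jax.tree_util.tree_map_with_path. This is a sketch under the assumption that "path contains" means a substring match on the stringified key path; param_filter_sketch and the example parameter tree are illustrative, not the library's code.

```python
import jax
import jax.numpy as jnp


def param_filter_sketch(params, target_modules):
    """Return a bool pytree: True where the leaf's key path mentions a target."""

    def is_target(path, _leaf):
        # keystr renders a path such as "['dense']['kernel']".
        path_str = jax.tree_util.keystr(path)
        return any(name in path_str for name in target_modules)

    return jax.tree_util.tree_map_with_path(is_target, params)


# Illustrative parameter tree (names are assumptions).
params = {
    "dense": {"kernel": jnp.ones((4, 3)), "bias": jnp.zeros((3,))},
    "norm": {"scale": jnp.ones((3,))},
}
mask = param_filter_sketch(params, ["kernel"])
```

With the default target_modules of ['kernel'], only the dense kernel is selected; biases and normalization scales stay untouched.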
inject_lora_params(params: PyTree[jax.Array], mask: PyTree[bool], rank: int, key: jax.Array) -> Tuple[PyTree[Optional[jax.Array]], PyTree[Optional[jax.Array]]]
¶
Create LoRA A/B matrices for each targeted parameter.
For a parameter of shape (in_features, out_features):
- A: (in_features, rank), initialized from normal(0, 1/rank)
- B: (rank, out_features), initialized to zeros
Returns two pytrees (lora_A, lora_B) with the same structure as params. Non-targeted leaves are zero arrays (same shape as the base param) so that the tree structure is compatible with jax.tree_util operations.
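The initialization described above might look roughly like the following. This sketch assumes that 1/rank is the standard deviation of the normal distribution (the text does not say whether it is a standard deviation or a variance) and that only 2-D targeted leaves receive adapters. inject_lora_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def inject_lora_sketch(params, mask, rank, key):
    """Build (lora_A, lora_B) pytrees mirroring the structure of params."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    mask_leaves = jax.tree_util.tree_leaves(mask)
    keys = jax.random.split(key, len(leaves))

    a_leaves, b_leaves = [], []
    for leaf, targeted, k in zip(leaves, mask_leaves, keys):
        if targeted and leaf.ndim == 2:
            in_features, out_features = leaf.shape
            # Assumption: 1/rank is used as the standard deviation.
            a = jax.random.normal(k, (in_features, rank), leaf.dtype) / rank
            # B starts at zero, so A @ B is zero at initialization.
            b = jnp.zeros((rank, out_features), leaf.dtype)
        else:
            # Placeholder zeros keep the tree structure identical.
            a = jnp.zeros_like(leaf)
            b = jnp.zeros_like(leaf)
        a_leaves.append(a)
        b_leaves.append(b)
    return treedef.unflatten(a_leaves), treedef.unflatten(b_leaves)


params = {"dense": {"kernel": jnp.ones((4, 3)), "bias": jnp.zeros((3,))}}
mask = {"dense": {"kernel": True, "bias": False}}
lora_A, lora_B = inject_lora_sketch(params, mask, rank=2, key=jax.random.PRNGKey(0))
```

Because B is all zeros, the product A @ B vanishes at initialization, so the adapted model starts out identical to the base model.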
merge_lora_params(base_params: PyTree[jax.Array], lora_A: PyTree[Any], lora_B: PyTree[Any], alpha: float, rank: int) -> PyTree[jax.Array]
¶
Merge LoRA into base: W_eff = W + (alpha/rank) * A @ B.
No stop_gradient on base_params is needed: in the training
loop value_and_grad differentiates only w.r.t. its argument
(the LoRA params dict), so base_params — accessed from the
closed-over state — is already treated as a constant by JAX.
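The following sketch illustrates both the merge formula and the gradient argument. merge_sketch is a hypothetical stand-in: it detects adapter pairs by shape, which is an assumption made only so that the example is self-contained.

```python
import jax
import jax.numpy as jnp


def merge_sketch(base_params, lora_A, lora_B, alpha, rank):
    """W_eff = W + (alpha / rank) * A @ B for leaves that carry adapters."""
    scale = alpha / rank

    def merge_leaf(w, a, b):
        # Assumption: a leaf has adapters iff the shapes line up as
        # (in, rank) and (rank, out); anything else is returned unchanged.
        has_adapter = (
            w.ndim == 2
            and a.shape == (w.shape[0], rank)
            and b.shape == (rank, w.shape[1])
        )
        return w + scale * (a @ b) if has_adapter else w

    return jax.tree_util.tree_map(merge_leaf, base_params, lora_A, lora_B)


base = {"kernel": jnp.ones((4, 3))}  # closed over, never differentiated
lora = {
    "lora_A": {"kernel": jnp.ones((4, 2))},
    "lora_B": {"kernel": jnp.zeros((2, 3))},
}


def loss_fn(lora_params):
    merged = merge_sketch(
        base, lora_params["lora_A"], lora_params["lora_B"], alpha=16.0, rank=2
    )
    return jnp.sum(merged["kernel"] ** 2)


# Gradients are taken only w.r.t. the argument, i.e. the LoRA dict.
loss, grads = jax.value_and_grad(loss_fn)(lora)
```

The gradient tree contains only lora_A and lora_B entries. The closed-over base weights contribute to the loss value but receive no gradient, which is why no stop_gradient is required.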
count_lora_params(lora_A: PyTree[Any], lora_B: PyTree[Any]) -> int
¶
Count total trainable LoRA parameters.
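A count like this can be sketched by summing leaf sizes. The example below assumes non-targeted leaves are None, which jax.tree_util.tree_leaves skips. If non-targeted leaves are placeholder zero arrays instead, they would need to be filtered out separately; how the actual function handles this is not stated here. count_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def count_sketch(lora_A, lora_B):
    """Sum the element counts of all array leaves in both adapter trees."""
    # tree_leaves drops None entries, so only real arrays are counted.
    leaves = jax.tree_util.tree_leaves((lora_A, lora_B))
    return sum(int(x.size) for x in leaves)


# Illustrative adapters for a (4, 3) kernel with rank 2; bias has none.
lora_A = {"kernel": jnp.zeros((4, 2)), "bias": None}
lora_B = {"kernel": jnp.zeros((2, 3)), "bias": None}
n_trainable = count_sketch(lora_A, lora_B)
```

Here the count is 4 * 2 + 2 * 3 = 14, compared with 12 parameters in the full kernel; the saving only appears when rank is small relative to the layer dimensions.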
transition_to_lora(trainer: Any) -> None
¶
Freeze base params and inject LoRA adapters on trainer.
Shared by LoRATrainer and BenchmarkLoRABaseTrainer.
Expects the trainer to have state, lora_config, model,
scheduler, mesh, and main_process() already set.
The new LoRATrainState is built through the same
eval_shape → logical_to_mesh_sharding → jit pipeline
that _init_state uses, so that LoRA params and optimizer
buffers are correctly sharded across devices.
After this call:
- trainer.state is a LoRATrainState whose params field
holds {"lora_A": ..., "lora_B": ...} (the only thing the
optimizer updates).
- trainer.state_sharding is updated to match the new state shape.
- trainer._in_lora_phase is True.
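The eval_shape, sharding, jit sequence mentioned above can be illustrated with plain JAX. This is a simplified sketch: it replaces logical_to_mesh_sharding with fully replicated NamedSharding objects, and the mesh axis name, init_adapters, and the shapes are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def init_adapters(key):
    """Hypothetical initializer returning the trainable LoRA dict."""
    return {
        "lora_A": jax.random.normal(key, (8, 2)) * 0.5,
        "lora_B": jnp.zeros((2, 4)),
    }


key = jax.random.PRNGKey(0)
mesh = Mesh(np.array(jax.devices()), ("data",))

# 1. Trace the initializer abstractly: shapes and dtypes, no allocation.
abstract = jax.eval_shape(init_adapters, key)

# 2. Assign a sharding to every leaf. Here everything is replicated; the
#    real pipeline derives these from logical axis annotations instead.
shardings = jax.tree_util.tree_map(
    lambda _: NamedSharding(mesh, PartitionSpec()), abstract
)

# 3. Materialize the state under jit so that each array is created
#    directly with its target sharding.
state = jax.jit(init_adapters, out_shardings=shardings)(key)
```

Creating the state inside jit with out_shardings avoids first building the full arrays on one device and then redistributing them, which matters once the optimizer buffers are included.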