lora

LoRA (Low-Rank Adaptation) Trainer.

Two-phase trainer: phase 1 trains all parameters on the pre-LoRA datasets; the base weights are then frozen and low-rank adapters are injected, and phase 2 trains only the LoRA parameters on the post-LoRA datasets.

Since this extends RestoreableJob, users can restore from a checkpoint and set pre_lora_datasets and pre_lora_tokens to empty lists to skip straight to LoRA fine-tuning, starting from a pretrained checkpoint.

LoRAConfig(rank: int = field('optimization/lora/rank', default=16), alpha: float = field('optimization/lora/alpha', default=16.0), target_modules: List[str] = field('optimization/lora/target_modules', default_factory=(lambda: ['kernel']))) dataclass

Low-rank adaptation parameters.

LoRATrainerConfig(batch_size: int = field('training/batch_size', default=512), per_device_batch_size: int = field('training/per_device_batch_size', default=(-1)), total_tokens: int = field('training/tokens', default=1000000000), lr: float = field('optimization/lr', default=0.0003), warmup_pct: float = field('training/warmup_pct', default=0.01), decay_pct: float = field('training/decay_pct', default=0.1), validate: bool = field('training/validation', default=True), evaluate: bool = field('training/evaluate', default=True), datasets: List[Sampling] = field('training/dataset', default_factory=(lambda: [Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))])), evaluations: List[str] = field('eval/evaluations', default_factory=(lambda: [])), block_size: int = field('architecture/block_size', default=512), report_interval: int = field('logging/report_interval', default=32), checkpoint_interval: int = field('logging/checkpoint_interval', default=1024), validation_interval: int = field('logging/validation_interval', default=512), validation_steps: int = field('training/validation_steps', default=2048), wandb: bool = field('logging/wandb', default=False), pre_lora_tokens: List[int] = field('training/pre_lora_tokens', default_factory=(lambda: [1000000000])), pre_lora_datasets: List[List[Sampling]] = field('training/pre_lora_dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]])), post_lora_tokens: List[int] = field('training/post_lora_tokens', default_factory=(lambda: [100000000])), post_lora_datasets: List[List[Sampling]] = field('training/post_lora_dataset', default_factory=(lambda: [[Sampling(name='fineweb', rate=1, style=(DatasetStyle.PMD))]]))) dataclass

Bases: BaseTrainerConfig

Config for two-phase LoRA trainer.

The pre-LoRA phase trains all parameters; LoRA adapters are then injected, and only the adapters are trained in the post-LoRA phase.

LoRA hyperparameters (rank, alpha, target_modules) are read from the config namespace via configure(LoRAConfig) at init time; in the YAML they live under optimization/lora/.
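The configure(LoRAConfig) mechanism itself is not shown here. The following is a minimal sketch that assumes the config is parsed from YAML into nested dicts. It shows how a slash-separated field path such as optimization/lora/rank could be resolved. CONFIG and lookup are hypothetical names, and the values are illustrative only.

```python
from typing import Any, Dict

# Hypothetical parsed YAML, nested the way the slash-separated field paths suggest.
CONFIG: Dict[str, Any] = {
    "optimization": {
        "lr": 3e-4,
        "lora": {"rank": 8, "alpha": 16.0, "target_modules": ["kernel"]},
    },
}

def lookup(cfg: Dict[str, Any], path: str, default: Any = None) -> Any:
    """Hypothetical helper: resolve a path like 'optimization/lora/rank'."""
    node: Any = cfg
    for part in path.split("/"):
        if not isinstance(node, dict) or part not in node:
            return default  # missing key: fall back to the field default
        node = node[part]
    return node

# Defaults mirror LoRAConfig (rank=16, alpha=16.0, target_modules=['kernel']).
rank = lookup(CONFIG, "optimization/lora/rank", 16)
alpha = lookup(CONFIG, "optimization/lora/alpha", 16.0)
targets = lookup(CONFIG, "optimization/lora/target_modules", ["kernel"])
```

Keys missing from the YAML fall back to the defaults declared on the dataclass fields.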

LoRATrainState

Bases: TrainState

Train state for LoRA phase.

params holds the trainable LoRA parameters as {"lora_A": pytree, "lora_B": pytree}. The optimizer in tx acts on params, so only the adapters are updated.

base_params holds the frozen pretrained weights (bfloat16).
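The following is a minimal sketch of this state layout. It makes no assumptions about the real TrainState class: LoRAStateSketch and sgd_update are hypothetical, and plain SGD stands in for whatever optimizer tx actually holds. It shows that an update applied to params leaves the bfloat16 base_params untouched.

```python
from typing import Any, Dict, NamedTuple

import jax
import jax.numpy as jnp

class LoRAStateSketch(NamedTuple):
    """Hypothetical stand-in for LoRATrainState; NamedTuples are JAX pytrees."""
    step: int
    params: Dict[str, Any]       # {"lora_A": pytree, "lora_B": pytree}, the only trainable part
    base_params: Dict[str, Any]  # frozen pretrained weights, stored in bfloat16

def sgd_update(state: LoRAStateSketch, grads: Dict[str, Any], lr: float) -> LoRAStateSketch:
    # Plain SGD stands in for the real optimizer: only `params` changes,
    # and base_params is carried over unchanged.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, state.params, grads)
    return state._replace(step=state.step + 1, params=new_params)

rank = 2
base = {"dense": {"kernel": jnp.ones((4, 3))}}
state = LoRAStateSketch(
    step=0,
    params={
        "lora_A": {"dense": {"kernel": jnp.full((4, rank), 0.1)}},
        "lora_B": {"dense": {"kernel": jnp.zeros((rank, 3))}},
    },
    base_params=jax.tree_util.tree_map(lambda w: w.astype(jnp.bfloat16), base),
)
grads = jax.tree_util.tree_map(jnp.ones_like, state.params)  # dummy gradients
state = sgd_update(state, grads, lr=0.5)
```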

LoRATrainer(spec: ExecutionSpec)

Bases: BaseTrainer[C, M], Generic[C, M]

Two-phase LoRA trainer.

- Phase 1 (pre-LoRA): full-parameter training on pre_lora_datasets.
- Phase 2 (post-LoRA): freeze the base weights, inject LoRA adapters, and train only the adapters on post_lora_datasets.

The transition happens automatically at the token boundary between pre-LoRA and post-LoRA phases.
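The exact rounding rule for this boundary is not documented. The sketch below assumes two things:

- Each optimizer step consumes batch_size * block_size tokens.
- The boundary is the floor of the total pre-LoRA token budget divided by that number.

Under those assumptions, the transition step could be computed as follows. All helper names are hypothetical.

```python
from typing import List

def tokens_per_step(batch_size: int, block_size: int) -> int:
    # Assumption: every optimizer step consumes batch_size sequences of block_size tokens.
    return batch_size * block_size

def lora_boundary_step(pre_lora_tokens: List[int], batch_size: int, block_size: int) -> int:
    """Hypothetical: index of the first step that belongs to the post-LoRA phase."""
    return sum(pre_lora_tokens) // tokens_per_step(batch_size, block_size)

def in_lora_phase(step: int, boundary: int) -> bool:
    return step >= boundary

# Defaults from LoRATrainerConfig: batch_size=512, block_size=512, pre_lora_tokens=[1_000_000_000].
boundary = lora_boundary_step([1_000_000_000], batch_size=512, block_size=512)
# An empty pre_lora_tokens list puts the boundary at step 0 (straight to LoRA).
skip_boundary = lora_boundary_step([], batch_size=512, block_size=512)
```

With the defaults, each step covers 262,144 tokens, so the first LoRA step would be step 3,814 (the pre-LoRA phase covers about 999.8M tokens). An empty pre_lora_tokens list puts the boundary at step 0, which matches the skip-to-LoRA workflow. A trainer that rounds up instead would shift the boundary by one step.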

restore_from_path(rel_path: str | Path) -> None

Restore from checkpoint, handling LoRA state type mismatch.

If the checkpoint was saved during the LoRA phase, its state is a LoRATrainState (params = {"lora_A", "lora_B"}, plus base_params). However, _init_state creates a plain TrainState, so the restore template's tree structure won't match the checkpoint's.

We peek at the checkpoint metadata to read the step count. If that step is past the pre-LoRA boundary, we transition first (creating a LoRATrainState template) and then let the parent restore overwrite the template with the checkpoint contents.
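The restore flow described above can be sketched as plain control flow. Everything here is hypothetical. The injected callables stand in for the real pieces:

- read_step stands in for the checkpoint-metadata read.
- parent_restore stands in for the parent class's restore.
- transition stands in for transition_to_lora.

Treating a step equal to the boundary as "past" it is also our assumption.

```python
from types import SimpleNamespace
from typing import Any, Callable, List

def restore_with_lora_check(
    trainer: Any,
    read_step: Callable[[str], int],
    parent_restore: Callable[[Any, str], None],
    transition: Callable[[Any], None],
    path: str,
) -> None:
    """Hypothetical restore flow for checkpoints that may hold LoRA state."""
    step = read_step(path)  # metadata only; no arrays are loaded yet
    if step >= trainer.lora_boundary and not trainer.in_lora_phase:
        # Build a LoRATrainState-shaped template first so the trees match.
        transition(trainer)
    parent_restore(trainer, path)  # overwrite the template with checkpoint contents

def run(step: int) -> List[str]:
    """Drive the flow with fake callables and record what gets called."""
    log: List[str] = []
    trainer = SimpleNamespace(lora_boundary=100, in_lora_phase=False)

    def transition(t: Any) -> None:
        log.append("transition")
        t.in_lora_phase = True

    def restore(t: Any, path: str) -> None:
        log.append("restore")

    restore_with_lora_check(trainer, lambda _path: step, restore, transition, f"ckpt/{step}")
    return log

late = run(250)  # saved during the LoRA phase
early = run(50)  # saved during the pre-LoRA phase
```

The ordering is the point: the template must already have the LoRA shape when the parent restore runs.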

param_filter(params: PyTree[jax.Array], target_modules: List[str]) -> PyTree[bool]

Create a boolean mask over the param tree.

Returns True for leaves whose path contains any of target_modules. This is the filter that determines which parameters get LoRA adapters.
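The following is a minimal JAX sketch of such a mask. It assumes that "path contains" means a substring match on the string form of the key path produced by jax.tree_util.keystr. The real param_filter may match path components differently.

```python
from typing import Any, List

import jax
import jax.numpy as jnp

def param_filter(params: Any, target_modules: List[str]) -> Any:
    """Sketch: True for leaves whose key path mentions any target module name."""
    def is_target(path, _leaf) -> bool:
        path_str = jax.tree_util.keystr(path)  # e.g. "['attn']['kernel']"
        return any(name in path_str for name in target_modules)
    return jax.tree_util.tree_map_with_path(is_target, params)

params = {
    "attn": {"kernel": jnp.zeros((8, 8)), "bias": jnp.zeros((8,))},
    "mlp": {"kernel": jnp.zeros((8, 32)), "scale": jnp.zeros((32,))},
}
mask = param_filter(params, ["kernel"])
```

Substring matching would also select a leaf named, say, kernel_scale. Comparing exact path components would avoid that.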

inject_lora_params(params: PyTree[jax.Array], mask: PyTree[bool], rank: int, key: jax.Array) -> Tuple[PyTree[Optional[jax.Array]], PyTree[Optional[jax.Array]]]

Create LoRA A/B matrices for each targeted parameter.

For a parameter of shape (in_features, out_features):
- A: (in_features, rank), initialized from normal(0, 1/rank)
- B: (rank, out_features), initialized to zeros

Returns two pytrees (lora_A, lora_B) with the same structure as params. Non-targeted leaves are zero arrays with the same shape as the base param, so the trees stay structurally compatible with jax.tree_util operations.
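The following sketches the injection under two stated assumptions:

- The 1/rank in normal(0, 1/rank) is read as the standard deviation.
- Only 2-D targeted leaves get real adapters.

One PRNG key is split per leaf, so each A matrix is drawn independently.

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp

def inject_lora_params(params: Any, mask: Any, rank: int, key: jax.Array) -> Tuple[Any, Any]:
    """Sketch: build (lora_A, lora_B) trees with the same structure as `params`."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    mask_leaves = treedef.flatten_up_to(mask)  # bools aligned with `leaves`
    keys = jax.random.split(key, len(leaves))
    a_leaves, b_leaves = [], []
    for w, targeted, k in zip(leaves, mask_leaves, keys):
        if targeted and w.ndim == 2:
            in_f, out_f = w.shape
            # Assumption: 1/rank is the standard deviation of A's entries.
            a_leaves.append(jax.random.normal(k, (in_f, rank), w.dtype) / rank)
            b_leaves.append(jnp.zeros((rank, out_f), w.dtype))  # B = 0, so W_eff == W at init
        else:
            # Non-targeted: zero placeholders shaped like the base param keep the trees aligned.
            a_leaves.append(jnp.zeros_like(w))
            b_leaves.append(jnp.zeros_like(w))
    return treedef.unflatten(a_leaves), treedef.unflatten(b_leaves)

params = {"dense": {"kernel": jnp.ones((6, 4)), "bias": jnp.ones((4,))}}
mask = {"dense": {"kernel": True, "bias": False}}
lora_A, lora_B = inject_lora_params(params, mask, rank=2, key=jax.random.PRNGKey(0))
```

Because B starts at zero, the product A @ B is zero at initialization. The adapted model therefore initially computes exactly what the frozen base model computes.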

merge_lora_params(base_params: PyTree[jax.Array], lora_A: PyTree[Any], lora_B: PyTree[Any], alpha: float, rank: int) -> PyTree[jax.Array]

Merge LoRA into base: W_eff = W + (alpha/rank) * A @ B.

No stop_gradient on base_params is needed: in the training loop, value_and_grad differentiates only with respect to its argument (the LoRA params dict), so base_params, which is accessed from the closed-over state, is already treated as a constant by JAX.
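The merge and the constant-base behavior can be checked together in a small sketch. Deciding which leaves are adapted by comparing shapes is our assumption, since the documented signature takes no mask. The loss differentiates only the LoRA dict, while base is closed over.

```python
from typing import Any

import jax
import jax.numpy as jnp

def merge_lora_params(base_params: Any, lora_A: Any, lora_B: Any, alpha: float, rank: int) -> Any:
    """Sketch of W_eff = W + (alpha / rank) * A @ B, applied leaf by leaf."""
    scale = alpha / rank

    def merge(w, a, b):
        # Assumption: a leaf is adapted when A is (in, rank) and B is (rank, out).
        # Non-targeted placeholders are all zeros, so an accidental shape match adds 0.
        if w.ndim == 2 and a.shape == (w.shape[0], rank) and b.shape == (rank, w.shape[1]):
            return w + (scale * (a @ b)).astype(w.dtype)
        return w

    return jax.tree_util.tree_map(merge, base_params, lora_A, lora_B)

base = {"kernel": jnp.eye(3), "bias": jnp.zeros((3,))}
lora = {
    "lora_A": {"kernel": jnp.ones((3, 1)), "bias": jnp.zeros((3,))},
    "lora_B": {"kernel": jnp.zeros((1, 3)), "bias": jnp.zeros((3,))},
}
x = jnp.ones((2, 3))

def loss_fn(lora_params):
    # `base` is closed over, not an argument, so JAX treats it as a constant.
    p = merge_lora_params(base, lora_params["lora_A"], lora_params["lora_B"], alpha=2.0, rank=1)
    return jnp.sum(x @ p["kernel"] + p["bias"])

loss, grads = jax.value_and_grad(loss_fn)(lora)
```

At initialization B is zero, so the gradient reaching A is also zero and only B gets a nonzero update on the first step. A starts contributing once B has moved away from zero.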

count_lora_params(lora_A: PyTree[Any], lora_B: PyTree[Any]) -> int

Count total trainable LoRA parameters.
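For a single adapted matrix, the count is simple arithmetic: A contributes in_features * rank entries and B contributes rank * out_features. The sketch below works this out for a hypothetical 768 x 768 kernel at the default rank of 16. How count_lora_params treats the zero placeholders of non-targeted leaves is not documented, so this covers adapted leaves only.

```python
from typing import Tuple

def lora_param_count(shape: Tuple[int, int], rank: int) -> int:
    """Trainable adapter entries for one (in_features, out_features) matrix: A plus B."""
    in_f, out_f = shape
    return in_f * rank + rank * out_f

# Hypothetical 768 x 768 kernel with the default rank of 16.
lora = lora_param_count((768, 768), rank=16)
full = 768 * 768
fraction = lora / full
```

That is 24,576 trainable entries against 589,824 in the full matrix, i.e. one twenty-fourth (about 4.2%).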

transition_to_lora(trainer: Any) -> None

Freeze base params and inject LoRA adapters on trainer.

Shared by LoRATrainer and BenchmarkLoRABaseTrainer. Expects the trainer to have state, lora_config, model, scheduler, mesh, and main_process() already set.

The new LoRATrainState is built through the same eval_shape → logical_to_mesh_sharding → jit pipeline that _init_state uses, so that LoRA params and optimizer buffers are correctly sharded across devices.

After this call:
- trainer.state is a LoRATrainState whose params field holds {"lora_A": ..., "lora_B": ...} (the only thing the optimizer updates).
- trainer.state_sharding is updated to match the new state shape.
- trainer._in_lora_phase is True.
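The following is a single-host sketch of the eval_shape → sharding → jit pattern. It is not the trainer's code. init_lora_state is hypothetical. logical_to_mesh_sharding, which maps logical axis names to mesh axes, is replaced by a plain replicated NamedSharding on a 1-D mesh so the example runs on one device.

```python
from typing import Any, Dict

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

RANK = 2  # illustrative adapter rank

def init_lora_state(key: jax.Array, base_params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical initializer for the post-transition state."""
    in_f, out_f = base_params["dense"]["kernel"].shape
    return {
        "step": jnp.zeros((), jnp.int32),
        "params": {
            "lora_A": {"kernel": jax.random.normal(key, (in_f, RANK)) / RANK},
            "lora_B": {"kernel": jnp.zeros((RANK, out_f))},
        },
        # Frozen base weights, cast to bfloat16.
        "base_params": jax.tree_util.tree_map(lambda w: w.astype(jnp.bfloat16), base_params),
    }

base_params = {"dense": {"kernel": jnp.ones((8, 4)), "bias": jnp.zeros((4,))}}
key = jax.random.PRNGKey(0)

# 1. eval_shape: trace the initializer for shapes and dtypes without allocating buffers.
abstract_state = jax.eval_shape(init_lora_state, key, base_params)

# 2. One sharding per leaf. Everything is replicated here; a real trainer would derive
#    per-leaf PartitionSpecs from logical axis names.
mesh = Mesh(np.array(jax.devices()), ("data",))
state_sharding = jax.tree_util.tree_map(
    lambda _: NamedSharding(mesh, PartitionSpec()), abstract_state
)

# 3. jit with out_shardings so each buffer is materialized directly in its target layout.
state = jax.jit(init_lora_state, out_shardings=state_sharding)(key, base_params)
```

In the trainer, the resulting state and state_sharding would presumably be assigned to trainer.state and trainer.state_sharding, matching the post-conditions listed above.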