muon

scale_by_muon(momentum: float = 0.95, ns_steps: int = 5, beta2: float = 0.95) -> base.GradientTransformation

Core Muon gradient transformation.

Applies three sequential operations to each update:
  1. Nesterov momentum (EMA of gradients with look-ahead blend)
  2. Polar Express orthogonalization for 2-D+ parameters
  3. NorMuon per-row/col adaptive variance reduction

Parameters with fewer than 2 dimensions receive only the momentum step.
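
The three steps above can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the classic Muon Newton-Schulz quintic iteration stands in for Polar Express (whose coefficients are not given here), the NorMuon step is reduced to a per-row second-moment normalization without bias correction or norm rescaling, and the names `newton_schulz` and `muon_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def newton_schulz(g, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2-D matrix (stand-in for Polar Express)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients from the original Muon
    x = g / (jnp.linalg.norm(g) + eps)  # Frobenius norm keeps the spectral norm <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # work on the wide orientation so x @ x.T is the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


def muon_step(grad, mu, nu, momentum=0.95, ns_steps=5, beta2=0.95, eps=1e-8):
    """One update for a single parameter; returns (update, new_mu, new_nu)."""
    # 1. Nesterov momentum: EMA of gradients, then a look-ahead blend with the gradient.
    mu = momentum * mu + (1.0 - momentum) * grad
    update = momentum * mu + (1.0 - momentum) * grad
    if grad.ndim < 2:
        return update, mu, nu  # vectors and scalars receive only the momentum step
    # 2. Orthogonalization (assumed stand-in, see lead-in).
    update = newton_schulz(update, ns_steps)
    # 3. Per-row variance reduction: EMA of the row-wise mean square, then normalize.
    row_ms = jnp.mean(update**2, axis=1, keepdims=True)
    nu = beta2 * nu + (1.0 - beta2) * row_ms
    update = update / (jnp.sqrt(nu) + eps)
    return update, mu, nu
```

For a weight of shape `(m, n)`, the sketch keeps a first-moment buffer of the same shape and a second-moment buffer of shape `(m, 1)`, which is what makes the second moment "factored" in this simplified form.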

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| momentum | float | EMA coefficient for the first-moment buffer. | 0.95 |
| ns_steps | int | Number of Polar Express iterations (1–5; 5 recommended). | 5 |
| beta2 | float | EMA coefficient for the factored second-moment buffer. | 0.95 |

Returns:

| Type | Description |
| --- | --- |
| GradientTransformation | An optax.GradientTransformation. |

muon(lr: optax._src.base.Schedule | float, cfg: MuonConfig, param_labels: Callable[[Any], Any] | None = None) -> optax.GradientTransformation

Muon + AdamW mixed optimizer, mirroring the adamw API.

Applies the Muon update (momentum → Polar Express orthogonalization → NorMuon variance reduction) to matrix-shaped parameters, and standard AdamW to embedding, unembedding, and scalar parameters. Each group is scaled by its own LR multiplier relative to the base schedule.
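
One way to read "scaled by its own LR multiplier" is that each group sees the base schedule multiplied by a constant. The sketch below shows that idea in plain Python; `scale_schedule`, the warmup schedule, and the multiplier values are hypothetical and not taken from MuonConfig. It uses Python numbers for clarity, so a schedule meant to run under `jax.jit` would need `jnp` operations instead of `min`.

```python
from typing import Callable, Union

Schedule = Callable[[int], float]


def scale_schedule(lr: Union[Schedule, float], multiplier: float) -> Schedule:
    """Return a schedule equal to `multiplier` times the base learning rate."""
    if callable(lr):
        return lambda step: multiplier * lr(step)
    return lambda step: multiplier * lr  # constant base LR


def base(step):
    # Hypothetical base schedule: linear warmup to 0.01 over 100 steps.
    return 0.01 * min(1.0, (step + 1) / 100)


matrix_lr = scale_schedule(base, 2.0)   # e.g. lr * matrix multiplier
adamw_lr = scale_schedule(1e-3, 0.5)    # constant base LR with another multiplier
```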

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lr | Schedule \| float | Base LR schedule (step → float) or constant float. Matrix params receive lr * cfg.matrix_lr_multiplier. | required |
| cfg | MuonConfig | MuonConfig instance holding all hyperparameters. | required |
| param_labels | Callable[[Any], Any] \| None | Optional callable params → labels_pytree. If None, _label_params is used (path-keyword heuristic). Pass a custom callable for architectures with non-standard naming. | None |
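
A custom `param_labels` callable could look roughly like the sketch below, which walks the parameter pytree and assigns a label from each leaf's path and rank. The label strings `"muon"` and `"adamw"`, the keyword `"embed"`, and the example parameter names are assumptions for illustration; the labels this optimizer actually expects are not documented here and should be checked against `_label_params`.

```python
import jax
import jax.numpy as jnp


def label_params(params):
    """Map each leaf to a group label based on its path and rank (hypothetical labels)."""

    def label(path, leaf):
        name = jax.tree_util.keystr(path).lower()
        # Embeddings, unembeddings, and anything below rank 2 go to the AdamW group.
        if "embed" in name or leaf.ndim < 2:
            return "adamw"
        return "muon"

    return jax.tree_util.tree_map_with_path(label, params)


# Example parameter tree with made-up names.
params = {
    "tok_embed": {"table": jnp.zeros((100, 16))},
    "block_0": {"attn_w": jnp.zeros((16, 16)), "norm_scale": jnp.zeros((16,))},
    "unembed": {"w": jnp.zeros((16, 100))},
}
labels = label_params(params)
```

The returned pytree has the same structure as `params`, with a string in place of each array.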

Returns:

| Type | Description |
| --- | --- |
| GradientTransformation | An optax.GradientTransformation. |