muon
scale_by_muon(momentum: float = 0.95, ns_steps: int = 5, beta2: float = 0.95) -> base.GradientTransformation
Core Muon gradient transformation.
Applies three sequential operations to each update:

- Nesterov momentum (an EMA of gradients with a look-ahead blend)
- Polar Express orthogonalization for parameters with two or more dimensions
- NorMuon per-row/column adaptive variance reduction

Parameters with fewer than two dimensions receive only the momentum step.
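The orthogonalization step and the dimension check can be sketched in JAX as follows. This is a stand-in, not the library's code: the coefficients of Polar Express are not given on this page, so the sketch swaps in a fixed-coefficient quintic Newton-Schulz iteration (coefficients 3.4445, -4.7750, 2.0315 from the widely used Muon reference code), which also pushes every singular value toward 1. The helper names `orthogonalize_ns` and `muon_direction` are hypothetical, and flattening leading dimensions of >2-D parameters into rows is an assumption.

```python
import jax
import jax.numpy as jnp


def orthogonalize_ns(g: jax.Array, steps: int = 5, eps: float = 1e-7) -> jax.Array:
    """Approximately replace g by the nearest semi-orthogonal matrix.

    Stand-in for Polar Express: quintic Newton-Schulz with fixed coefficients,
    applying f(s) = a*s + b*s**3 + c*s**5 to every singular value s.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (jnp.linalg.norm(g) + eps)  # Frobenius-normalize so all singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # work in the wide orientation so the Gram matrix is the smaller one
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


def muon_direction(g: jax.Array, steps: int = 5) -> jax.Array:
    """Hypothetical per-leaf dispatch applied to the momentum output."""
    if g.ndim < 2:
        return g  # scalars and vectors keep the plain momentum update
    # Assumption: leading dims are flattened into rows for >2-D parameters.
    mat = g.reshape(-1, g.shape[-1])
    return orthogonalize_ns(mat, steps).reshape(g.shape)
```

With five steps the fixed-coefficient iteration leaves singular values roughly in the range 0.7 to 1.1 rather than exactly at 1, which is the usual trade-off of this stand-in.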
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `momentum` | `float` | EMA coefficient for the first-moment buffer. | `0.95` |
| `ns_steps` | `int` | Number of Polar Express iterations (1–5; 5 recommended). | `5` |
| `beta2` | `float` | EMA coefficient for the factored second-moment buffer. | `0.95` |
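The role of `beta2` and the factored second-moment buffer can be illustrated with a row-wise sketch. The exact NorMuon rule is not spelled out on this page, so this is an assumption-laden illustration: the buffer stores one scalar per row (the "factored" part), it is updated as an EMA of each row's mean square, and the rescaled update is renormalized to the Frobenius norm of its input (that last step is an assumption). The name `normuon_rows` is hypothetical; a column variant would reduce over axis 0 instead.

```python
import jax
import jax.numpy as jnp


def normuon_rows(o: jax.Array, v: jax.Array, beta2: float = 0.95, eps: float = 1e-8):
    """Row-wise adaptive rescaling of an orthogonalized update (sketch).

    o: orthogonalized update, shape (rows, cols).
    v: per-row second-moment buffer, shape (rows,).
    """
    row_ms = jnp.mean(o * o, axis=-1)          # mean square of each row
    v = beta2 * v + (1.0 - beta2) * row_ms      # EMA of the per-row statistics
    o_hat = o / (jnp.sqrt(v)[:, None] + eps)    # equalize per-row magnitudes
    # Assumption: keep the overall Frobenius norm of the incoming update.
    o_hat = o_hat * (jnp.linalg.norm(o) / (jnp.linalg.norm(o_hat) + eps))
    return o_hat, v
```

Storing `rows` scalars instead of a full `rows x cols` buffer is what keeps the memory cost of this step small compared with an element-wise second moment as in Adam.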
Returns:
| Type | Description |
|---|---|
| `GradientTransformation` | An `optax.GradientTransformation`. |
muon(lr: optax._src.base.Schedule | float, cfg: MuonConfig, param_labels: Callable[[Any], Any] | None = None) -> optax.GradientTransformation
Muon + AdamW mixed optimizer, mirroring the `adamw` API.
Applies the Muon update (momentum → Polar Express orthogonalization → NorMuon variance reduction) to matrix-shaped parameters, and standard AdamW to embedding, unembedding, and scalar parameters. Each group is scaled by its own LR multiplier relative to the base schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `lr` | `Schedule \| float` | Base LR schedule (…) | required |
| `cfg` | `MuonConfig` | … | required |
| `param_labels` | `Callable[[Any], Any] \| None` | Optional callable … | `None` |
Returns:
| Type | Description |
|---|---|
| `GradientTransformation` | An `optax.GradientTransformation`. |