
optimizers

scale_by_muon(momentum: float = 0.95, ns_steps: int = 5, beta2: float = 0.95) -> base.GradientTransformation

Core Muon gradient transformation.

Applies three sequential operations to each update:
  1. Nesterov momentum (an EMA of the gradients with a look-ahead blend)
  2. Polar Express orthogonalization for parameters with two or more dimensions
  3. NorMuon per-row or per-column adaptive variance reduction
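The first two steps can be sketched for a single 2-D gradient as follows. This is a minimal illustration under stated assumptions, not this module's implementation: the helper names are hypothetical, and the momentum update follows the convention of the original Muon reference code (buffer = momentum * buffer + (1 - momentum) * g, then update = (1 - momentum) * g + momentum * buffer). Because the Polar Express coefficients are not listed on this page, the sketch substitutes the fixed quintic Newton-Schulz coefficients (3.4445, -4.7750, 2.0315) from the original Muon implementation. Both iterations approximate the orthogonal polar factor U V^T of the update.

```python
import jax
import jax.numpy as jnp


def nesterov_momentum(grad, buf, momentum=0.95):
    """Hypothetical helper: EMA first moment plus a Nesterov-style look-ahead blend."""
    # EMA buffer: buf <- momentum * buf + (1 - momentum) * grad
    new_buf = momentum * buf + (1.0 - momentum) * grad
    # Look-ahead: blend the fresh gradient with the updated buffer
    update = (1.0 - momentum) * grad + momentum * new_buf
    return update, new_buf


def newton_schulz_orthogonalize(g, ns_steps=5, eps=1e-7):
    """Approximate the polar factor U V^T of a 2-D matrix.

    Stand-in for Polar Express: uses Muon's fixed quintic Newton-Schulz
    coefficients instead of per-iteration Polar Express coefficients.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    transposed = g.shape[0] > g.shape[1]
    x = g.T if transposed else g  # wide orientation keeps the Gram matrix x @ x.T small
    # Frobenius norm >= spectral norm, so all singular values become <= 1
    x = x / (jnp.linalg.norm(x) + eps)
    for _ in range(ns_steps):
        gram = x @ x.T
        # x <- a*x + b*(x x^T) x + c*(x x^T)^2 x, applied to every singular value
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


grad = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
update, buf = nesterov_momentum(grad, jnp.zeros_like(grad))
ortho = newton_schulz_orthogonalize(update)
print(jnp.linalg.svd(ortho, compute_uv=False))  # clustered near 1, not exactly 1
```

The fixed quintic coefficients trade exactness for speed: singular values land in a band around 1 rather than converging to exactly 1, which is the gap that per-iteration schemes such as Polar Express aim to close.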

Parameters with fewer than 2 dimensions receive only the momentum step.
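The rank-based dispatch can be sketched as below. This is a hypothetical illustration: `muon_direction` and `polar_factor` are not part of this module, the orthogonalization uses an exact SVD polar factor U V^T (the quantity Polar Express approximates) for clarity rather than speed, the NorMuon step is omitted, and reshaping parameters with more than two dimensions to `(shape[0], -1)` is an assumption, since the docstring does not say how such parameters are flattened.

```python
import jax
import jax.numpy as jnp


def polar_factor(m):
    """Exact orthogonal polar factor U V^T via a thin SVD (reference, not Polar Express)."""
    u, _, vt = jnp.linalg.svd(m, full_matrices=False)
    return u @ vt


def muon_direction(momentum_update):
    """Hypothetical dispatch: orthogonalize only parameters with ndim >= 2."""
    if momentum_update.ndim < 2:
        # Biases, gains, scalars: the momentum update passes through unchanged
        return momentum_update
    # Assumption: collapse all trailing dims (e.g. of a conv kernel) into one axis
    matrix = momentum_update.reshape(momentum_update.shape[0], -1)
    return polar_factor(matrix).reshape(momentum_update.shape)


bias_step = muon_direction(jnp.array([0.1, -0.2, 0.3]))  # 1-D: returned as-is
kernel = jax.random.normal(jax.random.PRNGKey(1), (3, 2, 2))
kernel_step = muon_direction(kernel)  # 3-D: orthogonalized as a 3x4 matrix
```

For the wide 3x4 matrix, the polar factor has orthonormal rows, so reshaping `kernel_step` back to `(3, 4)` gives a matrix whose product with its own transpose is the 3x3 identity.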

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| momentum | float | EMA coefficient for the first-moment buffer. | 0.95 |
| ns_steps | int | Number of Polar Express iterations (1–5; 5 recommended). | 5 |
| beta2 | float | EMA coefficient for the factored second-moment buffer. | 0.95 |
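One way to read the factored second-moment buffer controlled by `beta2` is sketched below. It is an assumption-laden illustration of a NorMuon-style step, not this module's code: the helper name is hypothetical, it always normalizes per row (the docstring says per row or per column, and the rule for choosing is not documented here), and the final rescaling to preserve the Frobenius norm of the orthogonalized update is an assumption. "Factored" is taken to mean one running scalar per row instead of one per matrix entry.

```python
import jax.numpy as jnp


def normuon_row_normalize(ortho, v_row, beta2=0.95, eps=1e-8):
    """Hypothetical NorMuon-style step for an orthogonalized update of shape (m, n).

    v_row is the factored second-moment buffer: shape (m,), one scalar per row
    rather than one per matrix entry.
    """
    # EMA of each row's mean squared value, weighted by beta2
    new_v = beta2 * v_row + (1.0 - beta2) * jnp.mean(ortho**2, axis=1)
    # Divide each row by its running RMS estimate
    normalized = ortho / (jnp.sqrt(new_v)[:, None] + eps)
    # Assumption: rescale so the overall Frobenius norm matches the input's
    scale = jnp.linalg.norm(ortho) / (jnp.linalg.norm(normalized) + eps)
    return normalized * scale, new_v


ortho = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]])
step, v = normuon_row_normalize(ortho, jnp.zeros(2))
row_rms = jnp.sqrt(jnp.mean(step**2, axis=1))
```

Starting from a zero buffer, both rows of `step` end up with the same RMS; on later steps, `beta2` controls how much history the per-row estimates retain.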

Returns:

| Type | Description |
|---|---|
| GradientTransformation | An optax.GradientTransformation. |