utils

Small utilities useful during training.

estimate_per_device_batch_size(chip_memory: int, total_params_millions: float, shards: int, block_size: int, vram_calib_factor: float) -> int

Estimate max per-device batch size based on VRAM.

Memory breakdown (per param, bf16 training with AdamW):

- Params: 2 bytes (bf16)
- Gradients: 4 bytes (fp32 for accumulation)
- Master weights: 4 bytes (fp32 copy for optimizer)
- Optimizer m: 4 bytes (fp32 first moment)
- Optimizer v: 4 bytes (fp32 second moment)

Total: ~18 bytes/param (all sharded by tensor parallelism)

Activations: scales with batch * seq * sqrt(params)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `chip_memory` | `int` | Bytes of VRAM per device | required |
| `total_params_millions` | `float` | Model parameters in millions | required |
| `shards` | `int` | Tensor parallel shards | required |
| `block_size` | `int` | Sequence length | required |
| `vram_calib_factor` | `float` | User-tuned calibration factor for activation memory | required |

Returns:

| Type | Description |
|------|-------------|
| `int` | Estimated batch size (at least 1) |
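The estimate described above can be sketched as follows. This is a hedged reconstruction from the documented memory breakdown, not the library's actual implementation; the exact formula combining the activation term with `vram_calib_factor` is an assumption.

```python
import math

# bf16 params + fp32 grads + fp32 master weights + AdamW m and v = ~18 bytes/param
BYTES_PER_PARAM = 2 + 4 + 4 + 4 + 4


def estimate_per_device_batch_size(chip_memory: int, total_params_millions: float,
                                   shards: int, block_size: int,
                                   vram_calib_factor: float) -> int:
    """Sketch of a VRAM-based batch-size estimate (assumed formula)."""
    params = total_params_millions * 1e6
    # Fixed cost: parameter and optimizer state, sharded across tensor-parallel devices.
    fixed_bytes = params * BYTES_PER_PARAM / shards
    # Per-sample activation cost: scales with seq_len * sqrt(params),
    # tuned by the user-supplied calibration factor.
    per_sample_bytes = vram_calib_factor * block_size * math.sqrt(params)
    batch = int((chip_memory - fixed_bytes) // per_sample_bytes)
    return max(batch, 1)  # never return less than 1
```

For example, an 80 GB device with a 1000M-parameter unsharded model leaves roughly 62 GB for activations, which the per-sample term then divides into a batch size.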

find_accumulation_steps(batch_size: int, per_device_batch_size: int, topology: Topology) -> tuple[int, int]

Finds the largest feasible per-device batch size and the corresponding number of gradient accumulation steps.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `batch_size` | `int` | Global batch size | required |
| `per_device_batch_size` | `int` | Maximum per-device batch size | required |
| `topology` | `Topology` | Topology object containing replica information | required |

Returns:

| Type | Description |
|------|-------------|
| `tuple[int, int]` | Per-device batch size and number of gradient accumulation steps |
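The search this function performs can be sketched as below. This is an illustrative reconstruction, not the library's code: the `Topology` object is replaced by a bare `num_replicas` count, and the downward search for a divisor is an assumed strategy.

```python
def find_accumulation_steps(batch_size: int, per_device_batch_size: int,
                            num_replicas: int) -> tuple[int, int]:
    """Sketch: largest per-device batch size that evenly divides each
    replica's share of the global batch, plus the accumulation steps
    needed to make up the remainder (assumed strategy)."""
    assert batch_size % num_replicas == 0, "global batch must divide across replicas"
    per_replica = batch_size // num_replicas
    # Start from the VRAM-limited maximum and shrink until it divides evenly.
    pdbs = min(per_device_batch_size, per_replica)
    while per_replica % pdbs != 0:
        pdbs -= 1
    return pdbs, per_replica // pdbs
```

For example, a global batch of 128 over 4 replicas with a per-device cap of 16 gives each replica 32 samples per step, satisfied as 16 samples per micro-batch with 2 accumulation steps.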