utils
Small utilities useful during training.
estimate_per_device_batch_size(chip_memory: int, total_params_millions: float, shards: int, block_size: int, vram_calib_factor: float) -> int
Estimate max per-device batch size based on VRAM.
Memory breakdown (per param, bf16 training with AdamW):

- Params: 2 bytes (bf16)
- Gradients: 4 bytes (fp32 for accumulation)
- Master weights: 4 bytes (fp32 copy for optimizer)
- Optimizer m: 4 bytes (fp32 first moment)
- Optimizer v: 4 bytes (fp32 second moment)

Total: ~18 bytes/param (all sharded by tensor parallelism)
Activations: scales with batch * seq * sqrt(params)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `chip_memory` | `int` | Bytes of VRAM per device | required |
| `total_params_millions` | `float` | Model parameters in millions | required |
| `shards` | `int` | Tensor parallel shards | required |
| `block_size` | `int` | Sequence length | required |
| `vram_calib_factor` | `float` | User-tuned calibration for activation memory | required |

Returns:

| Type | Description |
|---|---|
| `int` | Estimated batch size (at least 1) |
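The estimate described above can be sketched as follows. This is a hypothetical implementation, not the library's actual code; it assumes the ~18 bytes/param static cost and the `batch * seq * sqrt(params)` activation scaling stated in the docstring, with `vram_calib_factor` absorbing the remaining constants:

```python
import math


def estimate_per_device_batch_size(
    chip_memory: int,
    total_params_millions: float,
    shards: int,
    block_size: int,
    vram_calib_factor: float,
) -> int:
    """Sketch: max per-device batch size that fits in VRAM."""
    params = total_params_millions * 1e6
    # Static cost: ~18 bytes/param (2 bf16 params + 4 grads + 4 master
    # weights + 4 optimizer m + 4 optimizer v), sharded across TP shards.
    static_bytes = 18 * params / shards
    free_bytes = chip_memory - static_bytes
    # Activation cost per sample: scales with seq_len * sqrt(params),
    # scaled by the user-tuned calibration factor.
    act_bytes_per_sample = vram_calib_factor * block_size * math.sqrt(params)
    batch = int(free_bytes // act_bytes_per_sample)
    return max(batch, 1)
```

Note the floor at 1: even when the static memory alone exceeds `chip_memory`, the function returns a usable (if optimistic) batch size rather than zero or a negative value.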
find_accumulation_steps(batch_size: int, per_device_batch_size: int, topology: Topology) -> tuple[int, int]
Finds the largest per-device batch size and the corresponding number of gradient accumulation steps.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch_size` | `int` | Global batch size | required |
| `per_device_batch_size` | `int` | Maximum per-device batch size | required |
| `topology` | `Topology` | Topology object containing replica information | required |

Returns:

| Type | Description |
|---|---|
| `tuple[int, int]` | Per-device batch size and number of gradient accumulation steps |
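A minimal sketch of this search, assuming `Topology` exposes a replica count (modeled here by a hypothetical `num_replicas` attribute; the real `Topology` class may differ): split the global batch across data-parallel replicas, then walk down from the VRAM-limited maximum until a per-device batch size divides the per-replica batch evenly.

```python
from dataclasses import dataclass


@dataclass
class Topology:
    # Hypothetical stand-in for the real Topology object.
    num_replicas: int


def find_accumulation_steps(
    batch_size: int,
    per_device_batch_size: int,
    topology: Topology,
) -> tuple[int, int]:
    """Sketch: largest per-device batch that evenly divides the per-replica batch."""
    per_replica = batch_size // topology.num_replicas
    # Walk down from the VRAM-limited maximum until it divides evenly.
    for candidate in range(min(per_device_batch_size, per_replica), 0, -1):
        if per_replica % candidate == 0:
            return candidate, per_replica // candidate
    # Degenerate fallback (e.g. per_replica == 0).
    return 1, per_replica
```

For example, with a global batch of 64 over 4 replicas (16 per replica) and a VRAM limit of 6 per device, the largest even divisor is 4, giving 4 accumulation steps per optimizer update.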