base
Topology
Bases: BaseModel
new(chip: Chip, shard_into: int | None = None) -> Topology (classmethod)
Create a Topology instance based on the current JAX device configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| chip | Chip | The chip type being used. | required |
| shard_into | int \| None | Number of shards to divide the devices into for tensor parallelism. If None, defaults to the local device count (shard evenly within each host). The SPMD/data-parallel axis is determined automatically. | None |
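To illustrate how `shard_into` might interact with the automatically chosen data-parallel axis, here is a minimal sketch. It assumes the tensor-parallel axis holds `shard_into` devices and the data-parallel axis absorbs the remaining factor of the global device count; the library's actual rule is not spelled out here. The function `mesh_shape` is hypothetical and not part of this API, and it takes plain integers in place of `jax.device_count()` and `jax.local_device_count()` so that it runs without JAX.

```python
from __future__ import annotations


def mesh_shape(total_devices: int, local_devices: int, shard_into: int | None = None) -> tuple[int, int]:
    """Hypothetical helper: return (data_parallel, tensor_parallel) axis sizes.

    Assumption: the tensor-parallel axis has `shard_into` devices (defaulting to
    the per-host device count, i.e. shard within each host) and the
    data-parallel axis takes whatever factor of the total remains.
    """
    tp = local_devices if shard_into is None else shard_into
    if tp <= 0:
        raise ValueError("shard_into must be positive")
    if total_devices % tp != 0:
        raise ValueError(f"{total_devices} devices cannot be split evenly into {tp} shards")
    return total_devices // tp, tp
```

For example, with two hosts of 8 devices each, `mesh_shape(16, 8)` gives `(2, 8)`: tensor parallelism stays inside each host and the two hosts form the data-parallel axis. Passing `shard_into=4` gives `(4, 4)`.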
JobSpec
Bases: BaseModel
User-provided specification for a job.
local(root_dir: str, work_dir: str) -> HardwareResult
Detect hardware using JAX device information. Creates one ClusterMachine per host, with host index matching jax.process_index(). Assumes paths are identical across all hosts.
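The per-host layout that `local` describes can be sketched as follows. `MachineSketch` and `local_sketch` are hypothetical stand-ins (the real `ClusterMachine` and `HardwareResult` fields are not documented here), and `process_count` is passed in explicitly where real code would presumably call `jax.process_count()`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class MachineSketch:
    """Hypothetical stand-in for ClusterMachine; only illustrative fields."""
    host_index: int
    root_dir: str
    work_dir: str


def local_sketch(root_dir: str, work_dir: str, process_count: int) -> list[MachineSketch]:
    # One entry per host: entry i corresponds to the host whose
    # jax.process_index() == i. Every host gets the same paths, mirroring
    # the documented assumption that paths are identical across hosts.
    return [
        MachineSketch(host_index=i, root_dir=root_dir, work_dir=work_dir)
        for i in range(process_count)
    ]
```

With `process_count=2`, this yields two entries whose `host_index` values are 0 and 1, both pointing at the same `root_dir` and `work_dir`.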