
base

Topology

Bases: BaseModel

new(chip: Chip, shard_into: int | None = None) -> Topology classmethod

Create a Topology instance based on the current JAX device configuration.

Parameters:

Name        Type        Description                                              Default
chip        Chip        The chip type being used.                                required
shard_into  int | None  Number of shards to divide the devices into for tensor   None
                        parallelism. If None, defaults to the local device
                        count (shards evenly within each host). The
                        SPMD/data-parallel axis is determined automatically.
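The shard_into description implies a simple split of the global device pool. The tensor-parallel axis gets shard_into devices (by default, the per-host device count), and the remaining factor becomes the data-parallel axis. The following is a minimal sketch of that arithmetic under those assumptions. mesh_shape is a hypothetical helper and is not part of this module. The divisibility check is likewise an assumption about how an uneven split might be handled. In a JAX program, the two counts would come from jax.device_count() and jax.local_device_count().

```python
from __future__ import annotations


def mesh_shape(global_device_count: int, local_device_count: int,
               shard_into: int | None = None) -> tuple[int, int]:
    """Hypothetical helper: return (data_parallel, tensor_parallel) sizes.

    Assumes the tensor-parallel size defaults to the per-host device count
    and that the data-parallel size is whatever factor remains.
    """
    # Default: shard evenly within each host.
    tp = local_device_count if shard_into is None else shard_into
    if tp <= 0:
        raise ValueError("shard_into must be a positive integer")
    # Assumption: the global device pool must split evenly into shards.
    if global_device_count % tp != 0:
        raise ValueError(
            f"{global_device_count} devices cannot be split into shards of {tp}"
        )
    # The leftover factor becomes the SPMD/data-parallel axis.
    dp = global_device_count // tp
    return dp, tp
```

For example, two hosts with four devices each (8 devices total) give mesh_shape(8, 4) == (2, 4). In this case, tensor parallelism runs within each host and data parallelism runs across hosts. Passing shard_into=8 instead yields (1, 8), which places one tensor-parallel group across both hosts.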

JobSpec

Bases: BaseModel

User-provided specification for a job.

ExecutionSpec

Bases: JobSpec

Specification of the resources actually allocated to a job.
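ExecutionSpec subclasses JobSpec, so an allocated spec carries everything the user asked for and can be passed wherever a JobSpec is expected. The following is a minimal sketch of that relationship. It uses stdlib dataclasses instead of pydantic's BaseModel so that it stays self-contained. The fields (name, num_hosts, hosts) and the allocate helper are hypothetical and are not the real classes' fields.

```python
from __future__ import annotations

from dataclasses import asdict, dataclass


@dataclass
class JobSpec:
    # Hypothetical user-requested fields.
    name: str
    num_hosts: int


@dataclass
class ExecutionSpec(JobSpec):
    # Hypothetical field filled in once resources are actually allocated.
    hosts: list[str]


def allocate(spec: JobSpec, hosts: list[str]) -> ExecutionSpec:
    """Hypothetical helper: turn a requested spec into an allocated one."""
    if len(hosts) != spec.num_hosts:
        raise ValueError(f"requested {spec.num_hosts} hosts, got {len(hosts)}")
    # Copy the user-provided fields, then add the allocation details.
    return ExecutionSpec(**asdict(spec), hosts=list(hosts))
```

Because of the inheritance, isinstance(allocate(spec, hosts), JobSpec) holds. Code that only reads the user-facing fields therefore works with either class.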

local(root_dir: str, work_dir: str) -> HardwareResult

Detect hardware using JAX device information. Creates one ClusterMachine per host, with host index matching jax.process_index(). Assumes paths are identical across all hosts.
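The description of local() fixes three properties: there is one machine entry per host, entry i corresponds to the host where jax.process_index() == i, and the same root_dir and work_dir are used everywhere. The following is a minimal sketch of that mapping. The ClusterMachine fields shown here are assumptions, machines_per_host is a hypothetical helper, and wrapping the list in a HardwareResult is not modeled. In a JAX program, process_count would come from jax.process_count(), and the current host's entry would be machines[jax.process_index()].

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ClusterMachine:
    # Hypothetical fields; the real ClusterMachine may differ.
    host_index: int
    root_dir: str
    work_dir: str


def machines_per_host(process_count: int, root_dir: str,
                      work_dir: str) -> list[ClusterMachine]:
    """Hypothetical helper: build one ClusterMachine per host.

    List position i matches the host whose process index is i, and every
    host reuses the same paths (the documented assumption).
    """
    return [
        ClusterMachine(host_index=i, root_dir=root_dir, work_dir=work_dir)
        for i in range(process_count)
    ]
```

Keeping the list position equal to the process index means that each host can find its own entry with a plain index, without any lookup by hostname.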