
topology

topology

Hardware information and topology representation for distributed JAX setups. Defines a Topology class that encapsulates device and process information, as well as JAX Mesh configuration.
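The device and process facts such a class records can be sketched with a plain dataclass. This is a hypothetical illustration, not this module's implementation: the real Topology is a BaseModel, and the name HostTopology and its fields are assumptions. The from_jax constructor reads the counts from JAX's public runtime queries (jax.device_count, jax.local_device_count, jax.process_count, jax.process_index). JAX is imported lazily so that the validation logic can be exercised without it.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class HostTopology:
    """Hypothetical stand-in for the device/process facts a Topology records."""

    device_count: int        # devices across all processes (global)
    local_device_count: int  # devices attached to this process
    process_count: int       # number of JAX processes (typically one per host)
    process_index: int       # rank of this process, in [0, process_count)

    def __post_init__(self) -> None:
        # Assumption: a uniform cluster where every process owns the same
        # number of devices.
        if self.device_count != self.local_device_count * self.process_count:
            raise ValueError("expected the same number of devices on every process")
        if not 0 <= self.process_index < self.process_count:
            raise ValueError("process_index out of range")

    @classmethod
    def from_jax(cls) -> HostTopology:
        import jax  # imported lazily; only this constructor needs JAX

        return cls(
            device_count=jax.device_count(),
            local_device_count=jax.local_device_count(),
            process_count=jax.process_count(),
            process_index=jax.process_index(),
        )
```

For example, on the last of 4 hosts with 8 accelerators each, the recorded values would be HostTopology(device_count=32, local_device_count=8, process_count=4, process_index=3).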

Topology

Bases: BaseModel

new(chip: Chip, shard_into: int | None = None) -> Topology (classmethod)

Create a Topology instance based on the current JAX device configuration.

Parameters:

chip (Chip, required)
    The chip type being used.

shard_into (int | None, default: None)
    Number of shards to divide the devices into for tensor parallelism. If None, this defaults to the local device count, so devices are sharded evenly within each host. The SPMD/data-parallel axis is determined automatically.
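One plausible reading of how shard_into and the automatically determined data-parallel axis fit together is sketched below. It assumes the data-parallel size is simply the global device count divided by shard_into, and that the two mesh axes are named "data" and "model". The functions resolve_mesh_shape and build_mesh and both axis names are hypothetical, and Topology.new may derive its mesh differently.

```python
from __future__ import annotations


def resolve_mesh_shape(
    device_count: int, local_device_count: int, shard_into: int | None = None
) -> tuple[int, int]:
    """Return a hypothetical (data, model) mesh shape for the given device counts."""
    # Mirrors the documented default: shard evenly within each host.
    model = local_device_count if shard_into is None else shard_into
    if model <= 0 or device_count % model != 0:
        raise ValueError(f"cannot split {device_count} devices into shards of {model}")
    # Whatever is left over becomes the SPMD / data-parallel axis.
    return device_count // model, model


def build_mesh(shard_into: int | None = None):
    """Build a 2-D jax.sharding.Mesh from the current runtime (requires JAX)."""
    import jax
    import numpy as np
    from jax.sharding import Mesh

    data, model = resolve_mesh_shape(
        jax.device_count(), jax.local_device_count(), shard_into
    )
    # Naive row-major layout of the global device list; for placement that
    # follows the physical interconnect, jax.experimental.mesh_utils.create_device_mesh
    # is the usual alternative.
    devices = np.array(jax.devices()).reshape(data, model)
    return Mesh(devices, axis_names=("data", "model"))
```

For example, 4 hosts with 8 devices each give a (4, 8) mesh by default, while shard_into=4 gives (8, 4): fewer devices per tensor-parallel group and more data-parallel replicas.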