inference

InferenceJob(spec: ExecutionSpec)

Bases: CheckpointedJob[C], Generic[C, M]

Abstract base for inference jobs. Must be subclassed with custom forward().

Subclasses define a MODEL class attribute and a forward() method. Users create instances via from_trainer() or from_checkpoint(), not __init__ directly.

Example subclass:

    class GPTInference(InferenceJob):
        MODEL = GPT

        @staticmethod
        def forward(state, params, batch, key, deterministic):
            # Custom forward implementation
            ...

Attributes (set by from_trainer/from_checkpoint):

- state: TrainState with params
- mesh: JAX device mesh
- state_sharding: NamedSharding
- replicas, local_replicas, per_device_batch_size, block_size
- key: PRNG key

Direct __init__ is not supported; use from_trainer() or from_checkpoint().

done: bool property

InferenceJob doesn't track completion state.

forward(state: train_state.TrainState, params: Any, batch: Tuple[jax.Array, Optional[jax.Array], jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, mutable: Optional[list[str] | tuple[str, ...]] = None, extra_variables: Optional[dict[str, Any]] = None) -> Any staticmethod

Forward pass with optional mutable variable collections (e.g. KV cache).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `mutable` | `Optional[list[str] \| tuple[str, ...]]` | Mutable variable collections (e.g. `['cache']`). When provided, returns `((logits, loss, meta), mutated_variables)`. | `None` |
| `extra_variables` | `Optional[dict[str, Any]]` | Additional variable collections to pass alongside params (e.g. `{'cache': cache_state}` for decode steps). | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Any` | `(logits, loss, meta)` when `mutable` is `None`. |
| `Any` | `((logits, loss, meta), mutated_variables)` when `mutable` is provided. |
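The return-shape contract above can be illustrated with a toy stand-in. This is not the real implementation — `toy_forward` and its fake logits are made up purely to show the two return shapes:

```python
import numpy as np

def toy_forward(params, batch, mutable=None, extra_variables=None):
    """Toy stand-in mimicking InferenceJob.forward's return contract."""
    tokens, targets, mask = batch
    # Fake "logits": one row per token position, vocab size 4.
    logits = np.zeros((len(tokens), 4))
    loss, meta = 0.0, {}
    if mutable is not None:
        # Pretend the listed collections (e.g. a KV cache) were updated.
        mutated = {name: (extra_variables or {}).get(name) for name in mutable}
        return (logits, loss, meta), mutated
    return logits, loss, meta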

from_trainer(trainer: BaseTrainer[Any, Any]) -> Self classmethod

Create InferenceJob sharing trainer's state.

The InferenceJob references (not copies) trainer's state, so changes to trainer.state are reflected in the InferenceJob.
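The shared-reference behavior can be seen with plain Python objects. `Trainer` and `Job` here are made-up stand-ins, not the real classes — the point is only that both names bind the same state object:

```python
class Trainer:
    def __init__(self):
        self.state = {"step": 0}

class Job:
    @classmethod
    def from_trainer(cls, trainer):
        job = cls.__new__(cls)          # no direct __init__, as in InferenceJob
        job.state = trainer.state       # a reference, not a copy
        return job

trainer = Trainer()
job = Job.from_trainer(trainer)
trainer.state["step"] = 100             # mutate via the trainer...
# ...and the job sees it, since both attributes point at the same dict.
```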

from_checkpoint(suffix: str | Path, spec: ExecutionSpec) -> Tuple[Self, Any] classmethod

Load InferenceJob from checkpoint using CheckpointedJob infrastructure.

Uses cls.MODEL for model initialization and sharding. Calls get_tree_and_metadata() for checkpoint restoration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `suffix` | `str \| Path` | Checkpoint suffix | required |
| `spec` | `ExecutionSpec` | `ExecutionSpec` with topology | required |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[Self, Any]` | `(job, config)` tuple |

pad(seqs: List[List[int]], pad_token: int = 0, pad_to: Optional[int] = None) -> Tuple[jax.Array, jax.Array] staticmethod

Left-pad sequences to uniform length.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `seqs` | `List[List[int]]` | List of token id lists | required |
| `pad_token` | `int` | Token to use for padding | `0` |
| `pad_to` | `Optional[int]` | Minimum length to pad to; defaults to the max sequence length | `None` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `padded` | `Array` | `(batch_size, max_len)` jnp array |
| `mask` | `Array` | `(batch_size, max_len)` jnp bool array; True for real tokens |
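The left-padding behavior described above can be sketched with `numpy` standing in for `jax.numpy`. A minimal reimplementation for illustration, not the library's code:

```python
import numpy as np

def left_pad(seqs, pad_token=0, pad_to=None):
    """Left-pad token lists to a uniform length; return (padded, mask)."""
    max_len = max(len(s) for s in seqs)
    if pad_to is not None:
        max_len = max(max_len, pad_to)  # pad_to is a *minimum* length
    padded = np.full((len(seqs), max_len), pad_token, dtype=np.int32)
    mask = np.zeros((len(seqs), max_len), dtype=bool)
    for i, s in enumerate(seqs):
        if s:                           # leave empty rows all-padding
            padded[i, -len(s):] = s     # right-align the real tokens
            mask[i, -len(s):] = True
    return padded, mask

padded, mask = left_pad([[5, 6, 7], [8]])
# padded → [[5, 6, 7], [0, 0, 8]]; mask is True on the real (right-aligned) tokens.
```

Left-padding (rather than right-padding) keeps the most recent token in the last column for every row, which is convenient for autoregressive decoding.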

rollout(inputs: List[Union[str, ChatTemplate]], encoding: Tokenizer, max_new_tokens: Optional[int] = None, temperature: float = 0.0, top_p: float = 1.0, chunk_size: int = 200) -> List[Union[str, ChatTemplate]]

Autoregressive rollout of the language model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `inputs` | `List[Union[str, ChatTemplate]]` | Raw strings or ChatTemplates to complete. | required |
| `encoding` | `Tokenizer` | Tokenizer for encoding/decoding. | required |
| `max_new_tokens` | `Optional[int]` | Maximum number of new tokens to generate; defaults to `block_size`. | `None` |
| `temperature` | `float` | Sampling temperature (`0.0` for greedy). | `0.0` |
| `top_p` | `float` | Nucleus sampling threshold. | `1.0` |
| `chunk_size` | `int` | Number of batches per JIT chunk. | `200` |

Returns:

| Type | Description |
|------|-------------|
| `List[Union[str, ChatTemplate]]` | Completed strings or ChatTemplates, matching the input types. |
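The `temperature` and `top_p` knobs follow standard temperature scaling and nucleus (top-p) sampling. A minimal `numpy` sketch of a single sampling step, illustrative only and not rollout's actual implementation:

```python
import numpy as np

def sample_token(logits, temperature=0.0, top_p=1.0, rng=None):
    """Pick one token id from a 1-D logits vector."""
    if temperature == 0.0:
        return int(np.argmax(logits))          # greedy decoding
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))    # stable softmax
    probs /= probs.sum()
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]        # most to least likely
        cumulative = np.cumsum(probs[order])
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1
        keep = order[:cutoff]                  # smallest nucleus covering top_p
        nucleus = np.zeros_like(probs)
        nucleus[keep] = probs[keep]
        probs = nucleus / nucleus.sum()        # renormalize over the nucleus
    rng = rng if rng is not None else np.random.default_rng(0)
    return int(rng.choice(len(probs), p=probs))
```

With `temperature=0.0` (rollout's default) this reduces to greedy argmax decoding; raising the temperature flattens the distribution, and lowering `top_p` truncates its low-probability tail before sampling.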