
base


Inference job for running inference on trained models.

Provides the abstract base class InferenceJob, which can be created from a trainer or loaded from a checkpoint.

InferenceJob(spec: ExecutionSpec)

Bases: RestoreableJob[C], Generic[C, M]

Abstract base for inference jobs. Must be subclassed with custom forward().

Subclasses define a MODEL class attribute and a forward() method. Users create instances via from_trainer() or from_checkpoint() rather than calling __init__ directly.

Example subclass:

class GPTInference(InferenceJob):
    MODEL = GPT

    @staticmethod
    def forward(state, params, batch, key, deterministic):
        # Custom forward implementation
        ...

Attributes (set by from_trainer/from_checkpoint):
  • state: TrainState with params
  • mesh: JAX device mesh
  • state_sharding: NamedSharding
  • replicas, local_replicas, per_device_batch_size, block_size
  • key: PRNG key

done: bool property

InferenceJob doesn't track completion state.

log(values: Dict[str, Any]) -> None

Log metric values through the attached plotter (if any).

Mirrors BaseTrainer.log so eval components can surface side metrics without knowing whether they were instantiated from a trainer or a bare checkpoint. No-op when plotter is None.

The step is taken from state.step, the optimizer-step counter that is incremented once per state.apply_gradients call. This matches BaseTrainer.log, which uses global_step_counter_ // accumulate_steps: one global-step bump (accumulate_steps micro-batches) corresponds to exactly one apply_gradients call, so the two counters are always equal during training. Reading state.step forces a device-to-host sync, but evals already run after a rollout barrier, so the cost is negligible.
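To see why the two counters agree, here is a small simulation. It assumes, as a simplification, that global_step_counter_ is bumped once per micro-batch and that apply_gradients runs once every accumulate_steps micro-batches; `simulate_counters` is a hypothetical helper, not part of the library.

```python
from typing import Tuple


def simulate_counters(num_micro_batches: int, accumulate_steps: int) -> Tuple[int, int]:
    """Hypothetical helper: compare both step counters after N micro-batches.

    Assumes global_step_counter_ counts micro-batches and that one
    apply_gradients call (one state.step increment) happens per full
    accumulation window.
    """
    global_step_counter_ = 0
    state_step = 0  # stands in for state.step
    for _ in range(num_micro_batches):
        global_step_counter_ += 1
        if global_step_counter_ % accumulate_steps == 0:
            state_step += 1  # one apply_gradients call per window
    return global_step_counter_ // accumulate_steps, state_step


# Equal even mid-window (10 is not a multiple of 4), thanks to floor division.
print(simulate_counters(10, 4))  # (2, 2)
```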

forward(state: train_state.TrainState, params: Any, batch: Tuple[jax.Array, Optional[jax.Array], jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, mutable: Optional[list[str] | tuple[str, ...]] = None, extra_variables: Optional[dict[str, Any]] = None, cache_max_len: Optional[int] = None) -> Any staticmethod

Forward pass with optional mutable variable collections (e.g. KV cache).

Parameters:

  • mutable (Optional[list[str] | tuple[str, ...]], default None): Mutable variable collections (e.g. ['cache']). When provided, returns ((logits, loss, meta), mutated_variables).
  • extra_variables (Optional[dict[str, Any]], default None): Additional variable collections to pass alongside params (e.g. {'cache': cache_state} for decode steps).
  • cache_max_len (Optional[int], default None): Forwarded to the model's __call__ so attention layers size their KV cache to the actual decode need rather than the model's full block_size. Only forwarded when not None, so models that don't accept it are unaffected.

Returns:

  • Any: (logits, loss, meta) when mutable is None.
  • Any: ((logits, loss, meta), mutated_variables) when mutable is provided.
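The return-shape rules and the conditional forwarding of cache_max_len can be illustrated with a framework-free sketch. `forward_sketch` and `toy_apply` are hypothetical; `toy_apply` imitates the convention of Flax's `Module.apply`, which returns `(outputs, mutated_variables)` when `mutable` is given. The sketch omits `state`, `key`, and `deterministic` and is not the library's implementation.

```python
from typing import Any, Callable, Dict, Optional, Sequence


def forward_sketch(
    apply_fn: Callable[..., Any],
    params: Any,
    inputs: Any,
    mutable: Optional[Sequence[str]] = None,
    extra_variables: Optional[Dict[str, Any]] = None,
    cache_max_len: Optional[int] = None,
) -> Any:
    """Hypothetical dispatch mirroring the documented forward() contract."""
    # Pass params alongside any extra collections, e.g. {'cache': cache_state}.
    variables = {"params": params, **(extra_variables or {})}
    kwargs: Dict[str, Any] = {}
    if cache_max_len is not None:
        # Only forwarded when set, so models without this argument still work.
        kwargs["cache_max_len"] = cache_max_len
    if mutable is not None:
        # Flax-style: returns (outputs, mutated_variables).
        return apply_fn(variables, inputs, mutable=list(mutable), **kwargs)
    return apply_fn(variables, inputs, **kwargs)


def toy_apply(variables, inputs, mutable=None):
    """Toy model with no cache_max_len argument; outputs are (logits, loss, meta)."""
    logits = [x * variables["params"]["w"] for x in inputs]
    outputs = (logits, 0.0, {})
    if mutable is None:
        return outputs
    # Pretend the KV cache grows by one entry per input token.
    seen = variables.get("cache", {"len": 0})["len"]
    return outputs, {"cache": {"len": seen + len(inputs)}}


print(forward_sketch(toy_apply, {"w": 2}, [1, 2]))
# ([2, 4], 0.0, {})
print(forward_sketch(toy_apply, {"w": 2}, [1, 2], mutable=["cache"],
                     extra_variables={"cache": {"len": 3}}))
# (([2, 4], 0.0, {}), {'cache': {'len': 5}})
```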

from_trainer(trainer: BaseTrainer[Any, Any]) -> Self classmethod

Create an InferenceJob that shares the trainer's state.

The InferenceJob references (rather than copies) the trainer's state, so changes to trainer.state are reflected in the InferenceJob.
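JAX train states are immutable, and Flax's TrainState.apply_gradients returns a new instance, so if a trainer rebinds trainer.state each step, the sharing described here holds only if the job sees whatever trainer.state currently is. The page does not say how InferenceJob achieves this; the sketch below uses hypothetical toy classes to contrast storing the state object once with reading through to the trainer on every access.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Immutable stand-in for a TrainState (hypothetical)."""
    step: int

    def apply_gradients(self) -> "ToyState":
        return replace(self, step=self.step + 1)  # returns a new object


class ToyTrainer:
    """Hypothetical trainer that rebinds self.state every step."""

    def __init__(self) -> None:
        self.state = ToyState(step=0)

    def train_step(self) -> None:
        self.state = self.state.apply_gradients()


class SnapshotJob:
    """Stores the state object once, so it goes stale after the next step."""

    def __init__(self, trainer: ToyTrainer) -> None:
        self.state = trainer.state


class SharedJob:
    """Reads through to the trainer's current state on every access."""

    def __init__(self, trainer: ToyTrainer) -> None:
        self._trainer = trainer

    @property
    def state(self) -> ToyState:
        return self._trainer.state


trainer = ToyTrainer()
snapshot, shared = SnapshotJob(trainer), SharedJob(trainer)
trainer.train_step()
print(snapshot.state.step, shared.state.step)  # 0 1
```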

restore_from_path(rel_path: str | Path) -> None

Restore inference state from rel_path under checkpoints_dir.

Must be called within a configuration(cfg) context (as done by RestoreableJob.from_checkpoint_path). Initializes model, mesh, sharding, and loads checkpoint.

pad(seqs: List[List[int]], pad_token: int = 0, pad_to: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray] staticmethod

Left-pad sequences to uniform length.

Parameters:

  • seqs (List[List[int]], required): List of token id lists.
  • pad_token (int, default 0): Token to use for padding.
  • pad_to (Optional[int], default None): Minimum length to pad to; when None, pads to the longest sequence.

Returns:

  • padded (ndarray): (batch_size, max_len) host int32 array.
  • mask (ndarray): (batch_size, max_len) host bool array, True for real tokens.
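A NumPy sketch of this left-padding contract, assuming (from "Minimum length to pad to") that pad_to widens the output but never truncates longer sequences. `left_pad` is a hypothetical stand-in, not the library's pad().

```python
from typing import List, Optional, Tuple

import numpy as np


def left_pad(
    seqs: List[List[int]], pad_token: int = 0, pad_to: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical re-implementation of the documented pad() contract."""
    longest = max((len(s) for s in seqs), default=0)
    # pad_to is a minimum width: longer sequences are never truncated.
    width = longest if pad_to is None else max(longest, pad_to)
    padded = np.full((len(seqs), width), pad_token, dtype=np.int32)
    mask = np.zeros((len(seqs), width), dtype=bool)
    for i, s in enumerate(seqs):
        start = width - len(s)  # real tokens are right-aligned (left padding)
        padded[i, start:] = s
        mask[i, start:] = True
    return padded, mask


padded, mask = left_pad([[5, 6, 7], [8]], pad_to=4)
print(padded.tolist())  # [[0, 5, 6, 7], [0, 0, 0, 8]]
print(mask.tolist())    # [[False, True, True, True], [False, False, False, True]]
```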

rollout(inputs: List[Union[str, ChatTemplate, jax.Array, List[int]]], encoding: Optional[Tokenizer] = None, max_new_tokens: Optional[int] = None, max_prompt_length: Optional[int] = None, temperature: float = 0.0, top_p: float = 1.0, chunk_size: int = 200, return_type: Literal['decoded', 'indices', 'output_decoded', 'output_indices', 'raw_indices'] = 'decoded') -> Union[List[Union[str, ChatTemplate]], List[str], List[List[int]]]

Autoregressive rollout of the language model.

Parameters:

  • inputs (List[Union[str, ChatTemplate, Array, List[int]]], required): Raw strings, ChatTemplates, or pre-tokenized token ids (1D jax arrays or lists of ints).
  • encoding (Optional[Tokenizer], default None): Tokenizer for encoding/decoding. Required when any input is a string or ChatTemplate, or when return_type is one of the decoded variants. May be omitted when all inputs are pre-tokenized arrays AND return_type is an indices variant.
  • max_new_tokens (Optional[int], default None): Maximum number of new tokens to generate. Defaults to block_size - max_prompt_length, or block_size // 2 if max_prompt_length is also unset.
  • max_prompt_length (Optional[int], default None): Length to which prompts are padded. Defaults to block_size - max_new_tokens. Inputs longer than this raise an error. Holding this fixed across calls keeps the JIT trace stable; varying it triggers recompiles.
  • temperature (float, default 0.0): Sampling temperature (0.0 for greedy).
  • top_p (float, default 1.0): Nucleus sampling threshold.
  • chunk_size (int, default 200): Number of batches per JIT chunk.
  • return_type (Literal['decoded', 'indices', 'output_decoded', 'output_indices', 'raw_indices'], default 'decoded'): Output format, one of:
      • "decoded": full prompt + generated tokens, decoded, with left padding stripped.
      • "indices": full prompt + generated tokens as ids, with left padding stripped.
      • "output_decoded": generated portion only, decoded.
      • "output_indices": generated portion only, as ids.
      • "raw_indices": full fixed-shape rows with left padding preserved. Shape is logically (N, max_prompt_length + max_new_tokens).
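The interplay of the two length defaults, and how the indices-style return types relate to a raw_indices row, can be sketched in plain Python. Both helpers are hypothetical; the sketch assumes the left-padded prompt occupies the first max_prompt_length columns of each raw row, consistent with the logical (N, max_prompt_length + max_new_tokens) shape, and it does not model any EOS handling.

```python
from typing import List, Optional, Sequence, Tuple


def resolve_lengths(
    block_size: int,
    max_new_tokens: Optional[int] = None,
    max_prompt_length: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical helper applying the documented rollout() defaults."""
    if max_new_tokens is None:
        max_new_tokens = (
            block_size // 2 if max_prompt_length is None else block_size - max_prompt_length
        )
    if max_prompt_length is None:
        max_prompt_length = block_size - max_new_tokens
    return max_new_tokens, max_prompt_length


def split_raw_row(
    raw_row: Sequence[int], prompt_mask: Sequence[bool], max_prompt_length: int
) -> Tuple[List[int], List[int]]:
    """Derive "indices"-style and "output_indices"-style lists from one raw row.

    Assumes the prompt region is left-padded (mask False on the left).
    """
    pad_len = len(prompt_mask) - sum(bool(m) for m in prompt_mask)
    indices = list(raw_row[pad_len:])                   # left padding stripped
    output_indices = list(raw_row[max_prompt_length:])  # generated portion only
    return indices, output_indices


new, prompt = resolve_lengths(block_size=8)
print(new, prompt)  # 4 4
row = [0, 0, 5, 6, 7, 8, 9, 10]  # width = prompt + new = 8
print(split_raw_row(row, [False, False, True, True], prompt))
# ([5, 6, 7, 8, 9, 10], [7, 8, 9, 10])
```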
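temperature and top_p are described only briefly above. For reference, the sketch below shows the standard technique those names usually denote, a temperature-scaled softmax followed by nucleus (top-p) filtering, for a single decoding step in NumPy. It is an assumption that rollout() follows this textbook scheme; its actual JAX implementation is not shown here, and `sample_next_token` is a hypothetical helper.

```python
import numpy as np


def sample_next_token(
    logits: np.ndarray, temperature: float, top_p: float, rng: np.random.Generator
) -> int:
    """Standard temperature + nucleus (top-p) sampling for one step."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        return int(np.argmax(logits))  # greedy decoding
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(-probs)  # token ids by descending probability
    sorted_probs = probs[order]
    mass_before = np.cumsum(sorted_probs) - sorted_probs
    # Keep the smallest prefix whose cumulative mass reaches top_p.
    keep = mass_before < top_p
    keep[0] = True  # always keep the most likely token, even if top_p == 0
    kept = np.where(keep, sorted_probs, 0.0)
    kept /= kept.sum()
    return int(order[rng.choice(len(kept), p=kept)])


rng = np.random.default_rng(0)
print(sample_next_token(np.array([1.0, 3.0, 2.0]), 0.0, 1.0, rng))  # 1 (greedy)
```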