inference
InferenceJob(spec: ExecutionSpec)
¶
Bases: RestoreableJob[C], Generic[C, M]
Abstract base for inference jobs; must be subclassed with a custom forward().
Subclasses define a MODEL class attribute and a forward() method. Users create instances via from_trainer() or from_checkpoint(), not by calling the constructor directly.
Example subclass:

```python
class GPTInference(InferenceJob):
    MODEL = GPT

    @staticmethod
    def forward(state, params, batch, key, deterministic):
        # Custom forward implementation
        ...
```
Attributes (set by from_trainer/from_checkpoint):
- state: TrainState with params
- mesh: JAX device mesh
- state_sharding: NamedSharding
- replicas, local_replicas, per_device_batch_size, block_size
- key: PRNG key
done: bool
property
InferenceJob doesn't track completion state.
log(values: Dict[str, Any]) -> None
Log metric values through the attached plotter (if any).
Mirrors BaseTrainer.log so eval components can surface side metrics
without knowing whether they were instantiated from a trainer or a bare
checkpoint. No-op when plotter is None.
Step is taken from state.step (the optimizer-step counter,
incremented once per state.apply_gradients call). This matches
BaseTrainer.log, which uses global_step_counter_ // accumulate_steps:
one global-step bump (= accumulate_steps micro-batches) corresponds
to exactly one apply_gradients call, so the two counters are always
equal during training. Reading state.step forces a device→host sync;
evals already run after a rollout barrier, so the cost is negligible.
forward(state: train_state.TrainState, params: Any, batch: Tuple[jax.Array, Optional[jax.Array], jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, mutable: Optional[list[str] | tuple[str, ...]] = None, extra_variables: Optional[dict[str, Any]] = None, cache_max_len: Optional[int] = None) -> Any
staticmethod
Forward pass with optional mutable variable collections (e.g. KV cache).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mutable | Optional[list[str] \| tuple[str, ...]] | List of mutable variable collections (e.g. ['cache']). When provided, returns ((logits, loss, meta), mutated_variables). | None |
| extra_variables | Optional[dict[str, Any]] | Additional variable collections to pass alongside params (e.g. {'cache': cache_state} for decode steps). | None |
| cache_max_len | Optional[int] | Forwarded to the model. | None |
Returns:

| Type | Description |
|---|---|
| Any | (logits, loss, meta) when mutable is None. |
| Any | ((logits, loss, meta), mutated_variables) when mutable is provided. |
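A minimal sketch of the two return shapes and of threading a KV cache through successive calls. toy_forward is a hypothetical stand-in that only mimics the documented contract; it assumes batch[0] holds the token ids and represents the cache as a plain dict with a length counter, which is not how the real model stores it.

```python
from typing import Any, Optional

import numpy as np


def toy_forward(
    params: dict[str, Any],
    batch: tuple[np.ndarray, Optional[np.ndarray], np.ndarray],
    mutable: Optional[list[str]] = None,
    extra_variables: Optional[dict[str, Any]] = None,
) -> Any:
    """Hypothetical stand-in mimicking forward()'s return contract."""
    tokens = batch[0]  # assumption: first element holds token ids, shape (B, L)
    logits = np.zeros(tokens.shape + (params["vocab_size"],), dtype=np.float32)
    loss, meta = None, {}
    if mutable is None:
        # Pure call: nothing is mutated, so no variables are returned.
        return logits, loss, meta
    # Mutable call: return the updated cache alongside the usual outputs.
    cache = (extra_variables or {}).get("cache", {"length": 0})
    new_cache = {"length": cache["length"] + tokens.shape[1]}
    return (logits, loss, meta), {"cache": new_cache}


params = {"vocab_size": 8}
prompt = np.array([[1, 2, 3]], dtype=np.int32)

# Teacher-forced call: plain (logits, loss, meta).
logits, loss, meta = toy_forward(params, (prompt, None, prompt))

# Prefill: request the cache collection back.
(_, _, _), mutated = toy_forward(params, (prompt, None, prompt), mutable=["cache"])

# Decode step: feed the mutated cache back in through extra_variables.
next_tok = np.array([[4]], dtype=np.int32)
(_, _, _), mutated = toy_forward(
    params, (next_tok, None, next_tok), mutable=["cache"], extra_variables=mutated
)
print(mutated["cache"]["length"])  # 4
```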
from_trainer(trainer: BaseTrainer[Any, Any]) -> Self
classmethod
Create an InferenceJob that shares the trainer's state.
The InferenceJob references (rather than copies) the trainer's state, so changes to trainer.state are reflected in the InferenceJob.
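One subtlety behind "references, not copies": Flax's TrainState is immutable, and apply_gradients returns a new instance, so trainer.state is rebound after every update. A job that captured the state object once would keep seeing old parameters. The sketch below shows one way to get the documented behavior, assuming the job reads through the trainer on each access; the Toy* and *Job classes are hypothetical, with a frozen dataclass standing in for TrainState.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical immutable state, standing in for a Flax TrainState."""
    step: int = 0

    def apply_gradients(self) -> "ToyState":
        # Like TrainState.apply_gradients: return a new object, never mutate.
        return replace(self, step=self.step + 1)


class ToyTrainer:
    def __init__(self) -> None:
        self.state = ToyState()

    def train_step(self) -> None:
        self.state = self.state.apply_gradients()  # rebinds trainer.state


class SnapshotJob:
    """Captures the state object once, so it goes stale after an update."""
    def __init__(self, trainer: ToyTrainer) -> None:
        self.state = trainer.state


class SharedJob:
    """Reads through the trainer on each access, so it stays current."""
    def __init__(self, trainer: ToyTrainer) -> None:
        self._trainer = trainer

    @property
    def state(self) -> ToyState:
        return self._trainer.state


trainer = ToyTrainer()
snapshot, shared = SnapshotJob(trainer), SharedJob(trainer)
trainer.train_step()
print(snapshot.state.step, shared.state.step)  # 0 1
```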
restore_from_path(rel_path: str | Path) -> None
Restore inference state from rel_path under checkpoints_dir.
Must be called within a configuration(cfg) context (as done by
RestoreableJob.from_checkpoint_path). Initializes model, mesh,
sharding, and loads checkpoint.
pad(seqs: List[List[int]], pad_token: int = 0, pad_to: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]
staticmethod
Left-pad sequences to uniform length.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| seqs | List[List[int]] | List of token id lists | required |
| pad_token | int | Token to use for padding (default 0) | 0 |
| pad_to | Optional[int] | Minimum length to pad to (default None: pad to the longest sequence) | None |
Returns:

| Name | Type | Description |
|---|---|---|
| padded | ndarray | (batch_size, max_len) host int32 array |
| mask | ndarray | (batch_size, max_len) host bool array, True for real tokens |
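A minimal NumPy sketch of the documented left-padding behavior, written as a free function named left_pad for illustration. It follows the parameter and return descriptions above but is not the library's own code.

```python
from typing import List, Optional, Tuple

import numpy as np


def left_pad(
    seqs: List[List[int]], pad_token: int = 0, pad_to: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Left-pad token id lists to a common length (illustrative sketch)."""
    # Target width: the longest sequence, raised to pad_to if that is larger.
    width = max((len(s) for s in seqs), default=0)
    if pad_to is not None:
        width = max(width, pad_to)

    padded = np.full((len(seqs), width), pad_token, dtype=np.int32)
    mask = np.zeros((len(seqs), width), dtype=bool)
    for i, seq in enumerate(seqs):
        if seq:
            # Real tokens occupy the rightmost len(seq) slots.
            padded[i, width - len(seq):] = seq
            mask[i, width - len(seq):] = True
    return padded, mask


padded, mask = left_pad([[5, 6, 7], [8]], pad_to=4)
print(padded.tolist())  # [[0, 5, 6, 7], [0, 0, 0, 8]]
print(mask.tolist())    # [[False, True, True, True], [False, False, False, True]]
```

Left padding puts every prompt's last real token in the final column, so a batched decode step can append new tokens at one shared position.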
rollout(inputs: List[Union[str, ChatTemplate, jax.Array, List[int]]], encoding: Optional[Tokenizer] = None, max_new_tokens: Optional[int] = None, max_prompt_length: Optional[int] = None, temperature: float = 0.0, top_p: float = 1.0, chunk_size: int = 200, return_type: Literal['decoded', 'indices', 'output_decoded', 'output_indices', 'raw_indices'] = 'decoded') -> Union[List[Union[str, ChatTemplate]], List[str], List[List[int]]]
Autoregressive rollout of the language model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | List[Union[str, ChatTemplate, Array, List[int]]] | List of raw strings, ChatTemplates, pre-tokenized 1D jax arrays of token ids, or lists of token ids. | required |
| encoding | Optional[Tokenizer] | Tokenizer for encoding/decoding. Required when any input is a string/ChatTemplate or when | None |
| max_new_tokens | Optional[int] | Maximum number of new tokens to generate. Defaults to | None |
| max_prompt_length | Optional[int] | Length to which prompts are padded. Defaults to | None |
| temperature | float | Sampling temperature (0.0 for greedy). | 0.0 |
| top_p | float | Nucleus sampling threshold. | 1.0 |
| chunk_size | int | Number of batches per JIT chunk. | 200 |
| return_type | Literal['decoded', 'indices', 'output_decoded', 'output_indices', 'raw_indices'] | Output format; one of the values listed below. | 'decoded' |

return_type values:
- "decoded": full prompt + generated tokens, decoded, with left padding stripped.
- "indices": full prompt + generated tokens as ids, with left padding stripped.
- "output_decoded": generated portion only, decoded.
- "output_indices": generated portion only, as ids.
- "raw_indices": full fixed-shape rows with left padding preserved. Shape is logically (N, max_prompt_length + max_new_tokens).
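How the index-based return_type views relate, sketched from a "raw_indices" result. This assumes prompts were left-padded to max_prompt_length (as pad does), that the prompt mask from that padding is available, and that every row carries exactly max_new_tokens generated ids with no stop-token truncation; views_from_raw is a hypothetical helper.

```python
from typing import List, Tuple

import numpy as np


def views_from_raw(
    raw: np.ndarray, prompt_mask: np.ndarray, max_prompt_length: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Derive 'indices' and 'output_indices' from 'raw_indices' (hypothetical)."""
    indices, output_indices = [], []
    for row, pmask in zip(raw, prompt_mask):
        n_pad = int((~pmask).sum())  # number of left-pad slots in this prompt
        # 'indices': prompt + generated tokens with left padding stripped.
        indices.append(row[n_pad:].tolist())
        # 'output_indices': only the columns after the padded prompt region.
        output_indices.append(row[max_prompt_length:].tolist())
    return indices, output_indices


# Two prompts padded to length 3, then 2 generated tokens each.
raw = np.array([[1, 2, 3, 7, 8], [0, 0, 4, 9, 9]], dtype=np.int32)
prompt_mask = np.array([[True, True, True], [False, False, True]])
indices, outputs = views_from_raw(raw, prompt_mask, max_prompt_length=3)
print(indices)  # [[1, 2, 3, 7, 8], [4, 9, 9]]
print(outputs)  # [[7, 8], [9, 9]]
```

Under these assumptions the decoded variants correspond to running the tokenizer's decode over the same two lists.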
TTTInferenceJob(spec: ExecutionSpec)
Bases: InferenceJob[Any, Any]
InferenceJob variant that also mutates the "fast_weights" collection.
All other behavior (sharding, rollout, autoregressive decoding, padding)
is inherited unchanged. The only override is forward, which auto-adds
"fast_weights" to mutable whenever the caller requests any mutation
(typically the KV cache during prefill/decode).
forward(state: Any, params: Any, batch: Tuple[Any, Optional[jax.Array], jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, mutable: Optional[list[str] | tuple[str, ...]] = None, extra_variables: Optional[dict[str, Any]] = None, cache_max_len: Optional[int] = None) -> Any
staticmethod
Forward with auto-paired cache + fast_weights mutation.
We always pair the two collections: there is no LaCT inference path
that wants KV-cache persistence but not fast-weight persistence (or
vice versa). When mutable is None (e.g. teacher-forced
perplexity eval), neither collection is mutated and the model's
_ttt branch takes the pure-functional path with W = W_0.
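The pairing described above amounts to a small transformation of the mutable argument before it reaches the model. A sketch, assuming mutable is either None or a list/tuple of collection names; pair_collections is a hypothetical helper name, not part of the documented API.

```python
from typing import Optional, Sequence


def pair_collections(mutable: Optional[Sequence[str]]) -> Optional[list[str]]:
    """Add "fast_weights" whenever a mutation is requested (illustrative)."""
    if mutable is None:
        # Pure-functional path: nothing is mutated, fast weights stay at W_0.
        return None
    cols = list(mutable)
    if "fast_weights" not in cols:
        cols.append("fast_weights")
    return cols


print(pair_collections(None))                        # None
print(pair_collections(["cache"]))                   # ['cache', 'fast_weights']
print(pair_collections(("cache", "fast_weights")))   # ['cache', 'fast_weights']
```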