base
Inference job for running inference on trained models.
Provides the abstract base class InferenceJob, which can be created from a trainer or loaded from a checkpoint.
InferenceJob(spec: ExecutionSpec)
Bases: CheckpointedJob[C], Generic[C, M]
Abstract base class for inference jobs. It must be subclassed with a custom forward().
Subclasses define the MODEL class attribute and the forward() method. Create instances via from_trainer() or from_checkpoint() rather than calling __init__ directly.
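Example subclass:

```python
# InferenceJob and GPT are assumed to be imported from the project's own modules.
class GPTInference(InferenceJob):
    MODEL = GPT

    @staticmethod
    def forward(state, params, batch, key=None, deterministic=False,
                mutable=None, extra_variables=None):
        # Custom forward implementation. Return (logits, loss, meta), or
        # ((logits, loss, meta), mutated_variables) when mutable is given.
        ...
```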
Attributes (set by from_trainer/from_checkpoint):
- state: TrainState with params
- mesh: JAX device mesh
- state_sharding: NamedSharding
- replicas, local_replicas, per_device_batch_size, block_size
- key: PRNG key
Direct initialization is not supported; use from_trainer() or from_checkpoint().
done: bool
property
InferenceJob doesn't track completion state.
forward(state: train_state.TrainState, params: Any, batch: Tuple[jax.Array, Optional[jax.Array], jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, mutable: Optional[list[str] | tuple[str, ...]] = None, extra_variables: Optional[dict[str, Any]] = None) -> Any
staticmethod
Forward pass with optional mutable variable collections (e.g. KV cache).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mutable | Optional[list[str] \| tuple[str, ...]] | List of mutable variable collections (e.g. ['cache']). When provided, returns ((logits, loss, meta), mutated_variables). | None |
| extra_variables | Optional[dict[str, Any]] | Additional variable collections to pass alongside params (e.g. {'cache': cache_state} for decode steps). | None |
Returns:

| Type | Description |
|---|---|
| Any | (logits, loss, meta) when mutable is None. |
| Any | ((logits, loss, meta), mutated_variables) when mutable is provided. |
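The two return shapes support a prefill-then-decode pattern for KV caching: request the cache collection once for the prompt, then pass it back through extra_variables on each decode step. The sketch below only mimics the documented calling convention. toy_forward is a hypothetical plain-Python stand-in, not a real model. The "cache" is just a list of tokens seen so far. Treating batch[0] as token ids is an assumption; the other two batch elements are placeholders.

```python
from typing import Any, Optional, Sequence


def toy_forward(state: Any, params: Any, batch: tuple, key: Any = None,
                deterministic: bool = False,
                mutable: Optional[Sequence[str]] = None,
                extra_variables: Optional[dict] = None) -> Any:
    """Hypothetical stand-in that follows the documented forward() return contract."""
    tokens = batch[0]  # assumption: the first batch element holds token ids
    cache = (extra_variables or {}).get("cache", [])
    new_cache = list(cache) + list(tokens)  # toy "KV cache": every token seen so far
    logits = [float(len(new_cache))]        # placeholder logits
    out = (logits, None, {})                # (logits, loss, meta)
    if mutable is None:
        return out
    return out, {"cache": new_cache}        # ((logits, loss, meta), mutated_variables)


# Prefill: run the prompt once and request the mutated 'cache' collection.
(logits, loss, meta), mutated = toy_forward(
    None, {}, ([1, 2, 3], None, [0, 0, 0]), mutable=["cache"])
cache = mutated["cache"]

# Decode step: feed only the newest token and pass the cache back in.
(logits, loss, meta), mutated = toy_forward(
    None, {}, ([4], None, [0]), mutable=["cache"],
    extra_variables={"cache": cache})
cache = mutated["cache"]
```

Without mutable, the same call returns the bare (logits, loss, meta) triple, so callers must know which shape to unpack.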
from_trainer(trainer: BaseTrainer[Any, Any]) -> Self
classmethod
Create an InferenceJob that shares the trainer's state.
The InferenceJob references the trainer's state rather than copying it, so changes to trainer.state are reflected in the InferenceJob.
from_checkpoint(suffix: str | Path, spec: ExecutionSpec) -> Tuple[Self, Any]
classmethod
Load an InferenceJob from a checkpoint using the CheckpointedJob infrastructure.
Uses cls.MODEL for model initialization and sharding, and calls get_tree_and_metadata() to restore the checkpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| suffix | str \| Path | Checkpoint suffix | required |
| spec | ExecutionSpec | ExecutionSpec with topology | required |
Returns:

| Type | Description |
|---|---|
| Tuple[Self, Any] | (job, config) tuple |
pad(seqs: List[List[int]], pad_token: int = 0, pad_to: Optional[int] = None) -> Tuple[jax.Array, jax.Array]
staticmethod
Left-pad sequences to uniform length.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| seqs | List[List[int]] | List of token id lists | required |
| pad_token | int | Token id used for padding | 0 |
| pad_to | Optional[int] | Minimum length to pad to; if None, pads to the longest sequence | None |
Returns:

| Name | Type | Description |
|---|---|---|
| padded | Array | (batch_size, max_len) jnp array |
| mask | Array | (batch_size, max_len) jnp bool array, True for real tokens |
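Left padding keeps each sequence's real tokens flush against the right edge, so the last position of every row is a real token. This is presumably why it suits autoregressive generation, where new tokens are appended on the right. The sketch below is a pure-Python approximation of the documented behaviour that returns nested lists instead of jax arrays. The name left_pad is hypothetical. It treats pad_to as a minimum, as the table states, so longer sequences are never truncated.

```python
from typing import List, Optional, Tuple


def left_pad(seqs: List[List[int]], pad_token: int = 0,
             pad_to: Optional[int] = None) -> Tuple[List[List[int]], List[List[bool]]]:
    """Hypothetical pure-Python sketch of left padding plus a validity mask."""
    max_len = max((len(s) for s in seqs), default=0)
    # pad_to is a minimum length: widen to it, but never truncate.
    width = max(max_len, pad_to) if pad_to is not None else max_len
    padded, mask = [], []
    for s in seqs:
        n_pad = width - len(s)
        padded.append([pad_token] * n_pad + list(s))      # pad tokens go on the left
        mask.append([False] * n_pad + [True] * len(s))    # True marks real tokens
    return padded, mask
```

For example, left_pad([[5, 6, 7], [8]]) yields rows [5, 6, 7] and [0, 0, 8], with masks [True, True, True] and [False, False, True].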
rollout(inputs: List[Union[str, ChatTemplate]], encoding: Tokenizer, max_new_tokens: Optional[int] = None, temperature: float = 0.0, top_p: float = 1.0, chunk_size: int = 200) -> List[Union[str, ChatTemplate]]
Autoregressive rollout of the language model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | List[Union[str, ChatTemplate]] | List of raw strings or ChatTemplates to complete. | required |
| encoding | Tokenizer | Tokenizer for encoding/decoding. | required |
| max_new_tokens | Optional[int] | Maximum number of new tokens to generate. Defaults to block_size. | None |
| temperature | float | Sampling temperature (0.0 for greedy). | 0.0 |
| top_p | float | Nucleus sampling threshold. | 1.0 |
| chunk_size | int | Number of batches per JIT chunk. | 200 |
Returns:

| Type | Description |
|---|---|
| List[Union[str, ChatTemplate]] | List of completed strings or ChatTemplates matching input types. |
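The temperature and top_p parameters conventionally control per-step token selection. temperature == 0.0 means greedy argmax. Otherwise, logits are divided by the temperature and passed through a softmax. Nucleus (top-p) sampling then restricts the draw to the smallest set of most likely tokens whose cumulative probability reaches top_p. The sketch below shows that common interpretation for a single row of logits in pure Python. It is an assumption about what these parameters mean, not rollout's actual implementation, which presumably operates on batched jax arrays under JIT. The function name select_token is hypothetical.

```python
import math
import random
from typing import List, Optional


def select_token(logits: List[float], temperature: float = 0.0,
                 top_p: float = 1.0, rng: Optional[random.Random] = None) -> int:
    """Hypothetical sketch: greedy, or temperature + nucleus sampling for one row."""
    if temperature == 0.0:
        # Greedy decoding: pick the highest-scoring token id.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    # Temperature-scaled softmax; subtract the max for numerical stability.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Nucleus filter: keep the most likely tokens until their mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    # Sample among the kept tokens; choices() renormalises the weights itself.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

With a very small top_p, only the single most likely token survives the filter, so sampling collapses to greedy decoding.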