experiments
Evaluator(spec: ExecutionSpec)
Bases: InferenceJob[EvaluatorConfig, M], Generic[M]
InferenceJob that runs evaluations and saves results.
Created from a trainer or checkpoint, holds a list of evaluations, and runs them when run() is called.
Example
evaluator = Evaluator.from_trainer(trainer, evaluations, encoding)
evaluator()  # Runs evaluations and saves results
done: bool
property
Check if evaluation results already exist.
from_trainer(trainer: BaseTrainer[Any, Any], config: Optional[Any] = None) -> Evaluator[M]
classmethod
Create Evaluator from trainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | BaseTrainer[Any, Any] | BaseTrainer instance to get inference state from | required |
| config | Optional[Any] | Optional config object whose | None |
Returns:
| Type | Description |
|---|---|
| Evaluator[M] | Evaluator instance ready to run evaluations |
from_checkpoint(suffix: str | Path, spec: ExecutionSpec, runtime_cfg: Any | None = None, resume: bool = False) -> Tuple[Evaluator[M], Any]
classmethod
Create Evaluator from checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| suffix | str \| Path | Checkpoint suffix | required |
| spec | ExecutionSpec | ExecutionSpec with topology | required |
| runtime_cfg | Any \| None | Optional runtime config overlay | None |
| resume | bool | If | False |
Returns:
| Type | Description |
|---|---|
| Tuple[Evaluator[M], Any] | (evaluator, config) tuple |
evaluate(reduce: str = 'mean', return_intermediates: bool = False, **kwargs: Any) -> Any
Run all evaluations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reduce | str | Passed through to each evaluation. "mean"/"sum" → float per evaluation; "none" → np.ndarray of per-sample scores. | 'mean' |
| return_intermediates | bool | When True, also return the per-evaluation list of (x, mask) rollouts (one inner list per evaluation). | False |
| **kwargs | Any | Forwarded to each evaluation's call (e.g. temperature, top_p, chunk_size). | {} |
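The reduce modes described above can be illustrated with a small standalone helper. This is only a sketch: `reduce_scores` is a hypothetical name, not part of this API, and it assumes each evaluation produces a 1-D array of per-sample scores before reduction.

```python
import numpy as np


def reduce_scores(scores, reduce="mean"):
    """Collapse per-sample scores according to the `reduce` mode.

    Hypothetical helper that mirrors the documented modes; the library's
    own reduction code may differ.
    """
    scores = np.asarray(scores, dtype=float)
    if reduce == "mean":
        return float(scores.mean())   # one float per evaluation
    if reduce == "sum":
        return float(scores.sum())    # one float per evaluation
    if reduce == "none":
        return scores                 # per-sample scores, unreduced
    raise ValueError(f"unknown reduce mode: {reduce!r}")


per_sample = [1.0, 0.0, 1.0, 1.0]
mean_score = reduce_scores(per_sample, "mean")   # 0.75
total_score = reduce_scores(per_sample, "sum")   # 3.0
raw_scores = reduce_scores(per_sample, "none")   # array of 4 scores
```

Keeping "none" as a separate mode lets callers compute their own statistics (for example confidence intervals) from the per-sample array.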
run() -> None
Run all evaluations and save results to disk.
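How `done`, `evaluate()`, and `run()` relate can be sketched with a toy stand-in. Everything here is an assumption made for illustration: `ToyEvaluator`, the JSON results file, and the caller-side `done` check are hypothetical, and the source does not say where results are written or whether `run()` consults `done` itself.

```python
import json
import tempfile
from pathlib import Path


class ToyEvaluator:
    """Hypothetical stand-in; not the library's Evaluator."""

    def __init__(self, results_path, evaluations):
        self.results_path = Path(results_path)
        self.evaluations = evaluations  # name -> zero-argument callable

    @property
    def done(self):
        # Assumed meaning of "results already exist": the file is on disk.
        return self.results_path.exists()

    def evaluate(self):
        # Run every evaluation and collect the scores in memory.
        return {name: fn() for name, fn in self.evaluations.items()}

    def run(self):
        # Run every evaluation, then persist the scores.
        self.results_path.write_text(json.dumps(self.evaluate()))


with tempfile.TemporaryDirectory() as tmp:
    ev = ToyEvaluator(Path(tmp) / "results.json", {"accuracy": lambda: 0.5})
    before = ev.done
    if not ev.done:  # skip work when results are already saved
        ev.run()
    after = ev.done
    saved = json.loads(ev.results_path.read_text())
```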
GPT
Bases: Module
embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any
Compute token and positional embeddings given inputs.
decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any
Compute decoded residual channels given embeddings.
unembed(x: jax.Array) -> Any
Compute output distribution.
loss(logits: jax.Array, targets: jax.Array) -> jax.Array
Compute cross-entropy loss given logits and targets.
__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | Array | Input token indices of shape (B, T). | required |
| targets | Optional[Array] | Target token indices of shape (B, T). Use -1 to ignore positions. | None |
| padding_mask | Optional[Array] | Boolean tensor of shape (B, T). True for valid tokens, False for padding tokens. | None |
| deterministic | bool | If False, applies dropout. | False |
Returns:
| Name | Type | Description |
|---|---|---|
| logits | Array | Output logits of shape (B, T, vocab_size). |
| loss | Optional[Array] | Cross-entropy loss if targets provided, else None. |
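The `(logits, loss)` contract and the -1 ignore convention can be sketched in plain NumPy. This is not the model's implementation: `cross_entropy` and `forward` are hypothetical names, NumPy stands in for JAX, and averaging over the non-ignored positions is an assumption, since the source does not state how the loss is normalized.

```python
import numpy as np


def cross_entropy(logits, targets, ignore_index=-1):
    """Mean cross-entropy over positions whose target is not ignore_index."""
    logits = np.asarray(logits, dtype=float)   # shape (B, T, V)
    targets = np.asarray(targets)              # shape (B, T)
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    valid = targets != ignore_index
    # Ignored slots get a placeholder index; they are zeroed out below.
    safe_targets = np.where(valid, targets, 0)
    picked = np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    # Assumption: normalize by the number of non-ignored positions.
    return float(-(picked * valid).sum() / max(valid.sum(), 1))


def forward(logits, targets=None):
    """Return (logits, loss), with loss None when no targets are given."""
    loss = None if targets is None else cross_entropy(logits, targets)
    return logits, loss


logits = np.zeros((1, 3, 4))        # uniform distribution over 4 tokens
targets = np.array([[2, -1, 0]])    # middle position is ignored
_, loss = forward(logits, targets)  # log(4), about 1.386
_, no_loss = forward(logits)        # None
```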
Module
Bases: Module
plot(intermediates: Any) -> Dict[str, Any]
staticmethod
Convert intermediates into figures.
sharding() -> List[Tuple[str, Optional[Any]]]
Return the sharding configuration for this module.
Returns:
| Type | Description |
|---|---|
| List[Tuple[str, Optional[Any]]] | A list of (str, Optional[Axis]) pairs mapping axes to sharding dimensions. |
components() -> List[Type[Any]]
classmethod
Return the types of constituent parts of this module.
Returns:
| Type | Description |
|---|---|
| List[Type[Any]] | A list of types representing the constituent parts. |
gather() -> List[Type[Any]]
classmethod
Depth-first search of all constituent parts of this module.
Returns:
| Type | Description |
|---|---|
| List[Type[Any]] | A list of all constituent part types. |
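The relationship between `components()` and `gather()` can be sketched as a depth-first walk over class-level declarations. The classes below are hypothetical, and two details are assumptions rather than documented behavior: the starting class is left out of the result, and repeated types are listed once.

```python
from typing import Any, List, Type


class Part:
    """Hypothetical base class; not the library's Module."""

    @classmethod
    def components(cls) -> List[Type[Any]]:
        return []  # leaf parts declare no constituents

    @classmethod
    def gather(cls) -> List[Type[Any]]:
        found: List[Type[Any]] = []
        for part in cls.components():
            if part not in found:
                found.append(part)
            # Descend into this part before visiting its siblings.
            for sub in part.gather():
                if sub not in found:
                    found.append(sub)
        return found


class Attention(Part):
    pass


class MLP(Part):
    pass


class Block(Part):
    @classmethod
    def components(cls):
        return [Attention, MLP]


class Model(Part):
    @classmethod
    def components(cls):
        return [Block]


order = Model.gather()  # [Block, Attention, MLP]
```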
Evaluate(spec: ExecutionSpec)
BackboneEvaluate(spec: ExecutionSpec)
configure(cls: Any, **overrides: Any) -> Any
Hydrates a dataclass from the current Configurate context.
Must be called within a Configurate context manager.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cls | Any | Dataclass type to instantiate. | required |
| **overrides | Any | Per-instance overrides applied after config-derived values. | {} |
Returns:
| Type | Description |
|---|---|
| Any | Instantiated dataclass with values from the current config. |
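The hydrate-then-override order described above can be sketched with a context variable. `toy_context` and `toy_configure` are hypothetical stand-ins for the Configurate context manager and `configure`. The flat dict config and the field-name matching are assumptions; the real lookup rules are not documented here.

```python
import contextvars
from contextlib import contextmanager
from dataclasses import dataclass, fields

# No default value: the variable is unset outside a context.
_current = contextvars.ContextVar("current_config")


@contextmanager
def toy_context(values):
    """Hypothetical stand-in for the Configurate context manager."""
    token = _current.set(values)
    try:
        yield
    finally:
        _current.reset(token)


def toy_configure(cls, **overrides):
    """Hypothetical stand-in for configure(); assumes a flat dict config."""
    try:
        values = _current.get()
    except LookupError:
        raise RuntimeError("toy_configure must be called inside toy_context") from None
    names = {f.name for f in fields(cls)}
    kwargs = {k: v for k, v in values.items() if k in names}
    kwargs.update(overrides)  # overrides win over config-derived values
    return cls(**kwargs)


@dataclass
class SamplerConfig:
    temperature: float = 1.0
    top_p: float = 1.0


with toy_context({"temperature": 0.7, "unrelated": 1}):
    cfg = toy_configure(SamplerConfig, top_p=0.9)
    overridden = toy_configure(SamplerConfig, temperature=0.2)
```

Here `cfg` takes `temperature` from the context and `top_p` from the override, while `overridden` shows an override replacing a config-derived value.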
job(key: str) -> Callable[[T], T]
Register a job class under the given key.