
experiments

ExecutionSpec

Bases: JobSpec

The specification actually allocated to a job.

Evaluator(spec: ExecutionSpec)

Bases: InferenceJob[EvaluatorConfig, M], Generic[M]

InferenceJob that runs evaluations and saves results.

Created from a trainer or checkpoint, holds a list of evaluations, and runs them when run() is called.

Example

evaluator = Evaluator.from_trainer(trainer)
evaluator.run()  # Runs evaluations and saves results

done: bool property

Check if evaluation results already exist.

from_trainer(trainer: BaseTrainer[Any, Any], config: Optional[Any] = None) -> Evaluator[M] classmethod

Create Evaluator from trainer.

Parameters:

- trainer (BaseTrainer[Any, Any], required): BaseTrainer instance to get inference state from.
- config (Optional[Any], default None): Optional config object whose .components field names the evaluations to run. If None, hydrates EvaluatorConfig from the global config. Pass an RLEvaluatorConfig to build a separate evaluator for RL rollouts.

Returns:

Evaluator[M]: Evaluator instance ready to run evaluations.

from_checkpoint(suffix: str | Path, spec: ExecutionSpec, runtime_cfg: Any | None = None, resume: bool = False) -> Tuple[Evaluator[M], Any] classmethod

Create Evaluator from checkpoint.

Parameters:

- suffix (str | Path, required): Checkpoint suffix.
- spec (ExecutionSpec, required): ExecutionSpec carrying the topology.
- runtime_cfg (Any | None, default None): Optional runtime config overlay.
- resume (bool, default False): If True, restore spec identity (including the wandb id) from the checkpoint for idempotent job resumption.

Returns:

Tuple[Evaluator[M], Any]: An (evaluator, config) tuple.

evaluate(reduce: str = 'mean', return_intermediates: bool = False, **kwargs: Any) -> Any

Run all evaluations.

Parameters:

- reduce (str, default 'mean'): Passed through to each evaluation. "mean" or "sum" yields a float per evaluation; "none" yields an np.ndarray of per-sample scores.
- return_intermediates (bool, default False): When True, also return the per-evaluation list of (x, mask) rollouts (one inner list per evaluation).
- **kwargs (Any, default {}): Forwarded to each evaluation's call (e.g. temperature, top_p, chunk_size).
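The `reduce` modes can be pictured with the sketch below. `reduce_scores` is a hypothetical helper written only to illustrate the documented "mean"/"sum"/"none" semantics; the real evaluations may aggregate differently.

```python
from typing import Dict, Sequence, Union

import numpy as np


def reduce_scores(scores: Sequence[float], reduce: str = "mean") -> Union[float, np.ndarray]:
    """Hypothetical helper mirroring the documented `reduce` modes."""
    arr = np.asarray(scores, dtype=np.float64)
    if reduce == "mean":
        return float(arr.mean())  # one float per evaluation
    if reduce == "sum":
        return float(arr.sum())
    if reduce == "none":
        return arr  # keep the per-sample scores
    raise ValueError(f"unknown reduce mode: {reduce!r}")


# Per-sample scores for two hypothetical evaluations.
per_eval: Dict[str, list] = {"arith": [1.0, 0.0, 1.0, 1.0], "copy": [0.5, 0.5]}
results = {name: reduce_scores(s, "mean") for name, s in per_eval.items()}
print(results)  # {'arith': 0.75, 'copy': 0.5}
```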

run() -> None

Run all evaluations and save results to disk.
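`done` and `run()` suggest an idempotent pattern in which results on disk mark the job as complete. The sketch below is illustrative only: `ToyEvaluator`, the single `results.json` file, and the early return when `done` is True are all assumptions, not the library's documented behavior.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


class ToyEvaluator:
    """Hypothetical stand-in illustrating a done/run pattern."""

    def __init__(self, out_dir: Path) -> None:
        # Assumed layout: one JSON file holding all evaluation scores.
        self.results_path = out_dir / "results.json"

    @property
    def done(self) -> bool:
        # Results already on disk mean there is nothing left to do.
        return self.results_path.exists()

    def evaluate(self) -> Dict[str, float]:
        # Placeholder scores; a real evaluator would run the model here.
        return {"arith": 0.75}

    def run(self) -> None:
        if self.done:
            return  # assumption: skip work whose results were already saved
        self.results_path.parent.mkdir(parents=True, exist_ok=True)
        self.results_path.write_text(json.dumps(self.evaluate()))


with tempfile.TemporaryDirectory() as tmp:
    ev = ToyEvaluator(Path(tmp))
    print(ev.done)  # False
    ev.run()
    print(ev.done)  # True
```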

GPT

Bases: Module

embed(idx: jax.Array, deterministic: bool = False, **kwargs: Any) -> Any

Compute token and positional embeddings given inputs.
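The summary does not say how positions are encoded. Assuming learned absolute positional embeddings as in GPT-2, `embed` reduces to two table lookups and a broadcast add. This NumPy sketch uses hypothetical weight names `wte` and `wpe` and omits dropout (the `deterministic` flag); it is not the module's actual code.

```python
import numpy as np


def embed(idx: np.ndarray, wte: np.ndarray, wpe: np.ndarray) -> np.ndarray:
    """Token embeddings plus learned positional embeddings (sketch)."""
    _, T = idx.shape
    tok = wte[idx]                # (B, T, C): look up each token id
    pos = wpe[np.arange(T)]       # (T, C): one vector per position
    return tok + pos[None, :, :]  # broadcast positions over the batch


rng = np.random.default_rng(0)
vocab_size, block_size, n_embd = 10, 8, 4
wte = rng.normal(size=(vocab_size, n_embd))
wpe = rng.normal(size=(block_size, n_embd))
idx = np.array([[1, 2, 3], [4, 5, 6]])
x = embed(idx, wte, wpe)
print(x.shape)  # (2, 3, 4)
```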

decode(x: jax.Array, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Any

Compute decoded residual channels given embeddings.
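`decode` accepts a `padding_mask` (True for valid tokens). A common way to use such a mask in a decoder is to AND it with a causal mask over the key axis. The sketch below shows that construction in NumPy; whether `decode` combines masks exactly this way, and the (B, 1, T, T) layout, are assumptions.

```python
import numpy as np


def attention_mask(padding_mask: np.ndarray) -> np.ndarray:
    """Combine a causal mask with a key padding mask.

    padding_mask: bool array of shape (B, T), True for valid tokens.
    Returns a bool array of shape (B, 1, T, T); True means "may attend".
    """
    _, T = padding_mask.shape
    causal = np.tril(np.ones((T, T), dtype=bool))  # query i sees keys <= i
    keys_ok = padding_mask[:, None, None, :]       # (B, 1, 1, T): hide padded keys
    return causal[None, None, :, :] & keys_ok


pm = np.array([[True, True, False]])  # last token is padding
m = attention_mask(pm)
print(m[0, 0].astype(int))  # rows are queries, columns are keys
```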

unembed(x: jax.Array) -> Any

Compute the output distribution.

loss(logits: jax.Array, targets: jax.Array) -> jax.Array

Compute cross-entropy loss given logits and targets.

__call__(idx: jax.Array, targets: Optional[jax.Array] = None, padding_mask: Optional[jax.Array] = None, deterministic: bool = False, **kwargs: Any) -> Tuple[jax.Array, Optional[jax.Array]]

Parameters:

- idx (Array, required): Input token indices of shape (B, T).
- targets (Optional[Array], default None): Target token indices of shape (B, T). Use -1 to ignore positions.
- padding_mask (Optional[Array], default None): Boolean array of shape (B, T); True for valid tokens, False for padding tokens.
- deterministic (bool, default False): If False, applies dropout.

Returns:

- logits (Array): Output logits of shape (B, T, vocab_size).
- loss (Optional[Array]): Cross-entropy loss if targets are provided, else None.
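Because `targets` uses -1 to mark ignored positions, the loss can be sketched as a masked cross-entropy. This NumPy version averages only over non-ignored positions; that normalization is an assumption, since the source does not say how the mean is taken.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy over positions whose target is not -1 (sketch)."""
    # Numerically stable log-softmax over the vocab axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    valid = targets != -1
    safe_targets = np.where(valid, targets, 0)  # any index works; masked out below
    nll = -np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    return float((nll * valid).sum() / np.maximum(valid.sum(), 1))


logits = np.zeros((1, 3, 4))      # uniform over a vocab of 4
targets = np.array([[2, 0, -1]])  # last position is ignored
print(cross_entropy(logits, targets))  # ≈ log(4) ≈ 1.386
```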

Module

Bases: Module

plot(intermediates: Any) -> Dict[str, Any] staticmethod

Turn intermediates into figures, returned as a Dict[str, Any].

sharding() -> List[Tuple[str, Optional[Any]]]

Return the sharding configuration for this module.

Returns:

List[Tuple[str, Optional[Any]]]: A list of (name, axis) pairs mapping axes to sharding dimensions; the axis may be None.

components() -> List[Type[Any]] classmethod

Return the types of constituent parts of this module.

Returns:

List[Type[Any]]: The types of this module's constituent parts.

gather() -> List[Type[Any]] classmethod

Collect all constituent parts of this module via depth-first search.

Returns:

List[Type[Any]]: A list of all constituent part types.
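`gather` can be read as a recursive walk over `components()`. The sketch below uses hypothetical `Part`, `Block`, `Attention`, and `MLP` classes; pre-order traversal, excluding the root class itself, and the absence of deduplication are assumptions.

```python
from typing import List, Type


class Part:
    """Hypothetical base class exposing components()/gather() like Module."""

    @classmethod
    def components(cls) -> List[Type["Part"]]:
        return []  # leaf parts have no constituents

    @classmethod
    def gather(cls) -> List[Type["Part"]]:
        # Depth-first, pre-order: record each child, then recurse into it.
        found: List[Type["Part"]] = []
        for child in cls.components():
            found.append(child)
            found.extend(child.gather())
        return found


class Attention(Part):
    pass


class MLP(Part):
    pass


class Block(Part):
    @classmethod
    def components(cls) -> List[Type[Part]]:
        return [Attention, MLP]


class Model(Part):
    @classmethod
    def components(cls) -> List[Type[Part]]:
        return [Block]


print([c.__name__ for c in Model.gather()])  # ['Block', 'Attention', 'MLP']
```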

Evaluate(spec: ExecutionSpec)

Bases: Evaluator[GPT]

Run the configured evaluation suite against a GPT checkpoint.

Invoke with --restore <ckpt>; the checkpoint's state is loaded via InferenceJob.restore_from_path, and the evaluator-specific fields (encoding, evaluations, sampling RNG) are wired up on top.

BackboneEvaluate(spec: ExecutionSpec)

Bases: Evaluate

Run the configured evaluation suite against a HuggingFace backbone.

configure(cls: Any, **overrides: Any) -> Any

Hydrates a dataclass from the current Configurate context.

Must be called within a Configurate context manager.

Parameters:

- cls (Any, required): Dataclass type to instantiate.
- **overrides (Any, default {}): Per-instance overrides applied after config-derived values.

Returns:

Any: Instantiated dataclass with values from the current config.
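A minimal sketch of the hydration pattern, assuming the active context can be modeled as a flat dict held in a `contextvars.ContextVar`. `configurate`, `_current`, and `SamplerConfig` are hypothetical names; the real Configurate context may be structured differently (for example, nested or per-class sections).

```python
import contextvars
import dataclasses
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Type, TypeVar

T = TypeVar("T")

# Hypothetical stand-in for the active Configurate context.
_current: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar("config")


@contextmanager
def configurate(values: Dict[str, Any]) -> Iterator[None]:
    token = _current.set(values)
    try:
        yield
    finally:
        _current.reset(token)  # restore the previous context on exit


def configure(cls: Type[T], **overrides: Any) -> T:
    try:
        cfg = _current.get()
    except LookupError:
        raise RuntimeError("configure() must be called inside a config context") from None
    names = {f.name for f in dataclasses.fields(cls)}
    # Config-derived values first; per-instance overrides win.
    kwargs = {k: v for k, v in cfg.items() if k in names}
    kwargs.update(overrides)
    return cls(**kwargs)


@dataclasses.dataclass
class SamplerConfig:
    temperature: float = 1.0
    top_p: float = 1.0


with configurate({"temperature": 0.7, "top_p": 0.9, "unrelated": 3}):
    print(configure(SamplerConfig, top_p=0.5))
# SamplerConfig(temperature=0.7, top_p=0.5)
```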

job(key: str) -> Callable[[T], T]

Register a job class under the given key.
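The return type `Callable[[T], T]` indicates that `job(key)` is applied as a class decorator that hands the class back unchanged. A sketch with a hypothetical module-level `JOBS` dict follows; where the real module stores registered jobs is not documented here.

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T")

# Hypothetical registry mapping keys to job classes.
JOBS: Dict[str, type] = {}


def job(key: str) -> Callable[[T], T]:
    """Class decorator: record `cls` under `key` and return it unchanged."""

    def register(cls: T) -> T:
        JOBS[key] = cls  # type: ignore[assignment]
        return cls

    return register


@job("toy-eval")
class ToyEval:
    pass


print(JOBS["toy-eval"] is ToyEval)  # True
```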