base

Evaluation framework for theseus trainers.

Provides abstract base classes for different evaluation types:

- RolloutEvaluation: Autoregressive generation tasks
- EncodingEvaluation: Next-token prediction accuracy
- PerplexityEvaluation: Dataset perplexity (returns 1/ppl, higher is better)
- PerplexityComparisonEvaluation: Multiple-choice via perplexity comparison

Also provides:

- Evaluator: InferenceJob subclass that runs multiple evaluations

Evaluation

Bases: ABC

Abstract base class for all evaluations.

name: str abstractmethod property

Name of this evaluation.

prefix() -> str

Prefix for metrics from this evaluation.

__len__() -> int abstractmethod

Number of samples in this evaluation.

__call__(inference: InferenceJob[Any, M], encoding: Any, **kwargs: Any) -> float abstractmethod

Run the evaluation and return a score.

find_accumulation_steps(dataset_size: int, max_batch_size: int, dp_replicate: int) -> Tuple[int, int] | Tuple[None, None] staticmethod

Find a batch size and accumulation step count that evenly divide the dataset.

Parameters:

- dataset_size (int): Total number of samples. Required.
- max_batch_size (int): Maximum per-device batch size. Required.
- dp_replicate (int): Data parallel replication factor. Required.

Returns:

- Tuple[int, int] | Tuple[None, None]: (batch_size, accumulation_steps), or (None, None) if no valid size is found.
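For intuition, here is a minimal sketch of the kind of divisibility search this helper performs; the function below is an illustration under the assumption that the largest workable per-device batch size is preferred, not the library implementation:

```python
def find_accumulation_steps_sketch(dataset_size: int, max_batch_size: int, dp_replicate: int):
    # Try larger per-device batch sizes first so fewer accumulation steps are needed.
    for batch_size in range(max_batch_size, 0, -1):
        per_step = batch_size * dp_replicate  # samples consumed per step across all replicas
        if dataset_size % per_step == 0:
            return batch_size, dataset_size // per_step
    return None, None
```

For example, with dataset_size=1000, max_batch_size=64, and dp_replicate=2, this sketch returns (50, 10).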

RolloutEvaluation

Bases: Evaluation

Evaluation using autoregressive generation.

score(ys: list[str], y_hats: list[str]) -> float

Compute score from generated results.

Parameters:

- ys (list[str]): Ground truth strings. Required.
- y_hats (list[str]): Generated results. Required.

Returns:

- float: Score (higher is better).

check(y: str, y_hat: str) -> bool

Check if y_hat matches y.

Parameters:

- y (str): Ground truth. Required.
- y_hat (str): Generated result. Required.

Returns:

- bool: Whether y_hat matches y.

clean(y_hat: str) -> str abstractmethod

Clean generated result before checking.

Parameters:

- y_hat (str): Generated result, which can include the prompt. Required.

Returns:

- str: Cleaned/normalized result ready for comparison.

get(indx: int) -> Tuple[str, str] abstractmethod

Get sample at index.

Returns:

- Tuple[str, str]: (input_string, expected_output_string).

max_new_tokens(inference: InferenceJob[Any, M]) -> int

Maximum tokens to generate. Override in subclasses for shorter rollouts.

Default is full block_size, but most evaluations only need ~10-100 tokens.

__call__(inference: InferenceJob[Any, M], encoding: Any, temperature: float = 0.0, top_p: float = 1.0, chunk_size: int = 200, **kwargs: Any) -> float

Run evaluation.

Parameters:

- inference (InferenceJob[Any, M]): InferenceJob instance for running inference. Required.
- encoding (Any): Tokenizer with encode_batch/decode_batch methods. Required.
- temperature (float): Sampling temperature (0.0 for greedy). Default: 0.0.
- top_p (float): Nucleus sampling threshold. Default: 1.0.
- chunk_size (int): Number of batches per JIT chunk. Default: 200.

Returns:

- float: Evaluation score.
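As an illustration of the subclassing pattern, here is a minimal sketch of a concrete RolloutEvaluation. The class name, sample data, cleaning rule, and import path are all invented for the example:

```python
from typing import Tuple

from base import RolloutEvaluation  # import path is an assumption

# Hypothetical (prompt, expected answer) pairs; not shipped with the library.
_SAMPLES = [("Q: 2 + 2 = ? A:", "4"), ("Q: capital of France? A:", "Paris")]

class ToyRolloutEval(RolloutEvaluation):
    @property
    def name(self) -> str:
        return "toy_rollout"

    def __len__(self) -> int:
        return len(_SAMPLES)

    def get(self, indx: int) -> Tuple[str, str]:
        return _SAMPLES[indx]

    def clean(self, y_hat: str) -> str:
        # Strip whitespace and keep only the first line of the generated continuation.
        stripped = y_hat.strip()
        return stripped.splitlines()[0].strip() if stripped else ""

    def max_new_tokens(self, inference) -> int:
        # Short answers only; no need to generate a full block.
        return 16
```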

EncodingEvaluation

Bases: Evaluation

Evaluation using next-token prediction accuracy.

score(xs: list[str], y_hats: list[str]) -> float

Compute score from input and model predictions.

Parameters:

- xs (list[str]): Input strings. Required.
- y_hats (list[str]): Model predictions (argmax of logits, shifted by 1). Required.

Returns:

- float: Score (higher is better).
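For intuition about the shift mentioned above, a toy illustration (independent of the library) of how next-token accuracy aligns argmax predictions with the following input token:

```python
# Toy token ids; predictions[i] is the model's argmax for the token after position i.
tokens      = [5, 9, 2, 7]
predictions = [9, 1, 7, 3]

correct = sum(p == t for p, t in zip(predictions[:-1], tokens[1:]))
accuracy = correct / (len(tokens) - 1)  # 2 of 3 shifted positions match -> ~0.67
```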

check(x: str, y_hat: str) -> bool

Check if prediction is correct given input.

Parameters:

- x (str): Input string. Required.
- y_hat (str): Model prediction (cleaned, decoded argmax). Required.

Returns:

- bool: Whether the prediction is correct.

clean(y_hat: str) -> str abstractmethod

Clean model prediction before checking.

Parameters:

- y_hat (str): Raw decoded model prediction. Required.

Returns:

- str: Cleaned/normalized result ready for comparison.

get(indx: int) -> str abstractmethod

Get input string at index.

__call__(inference: InferenceJob[Any, M], encoding: Any, chunk_size: int = 200, **kwargs: Any) -> float

Run evaluation.

Parameters:

- inference (InferenceJob[Any, M]): InferenceJob instance for running inference. Required.
- encoding (Any): Tokenizer with encode_batch/decode_batch methods. Required.
- chunk_size (int): Number of batches per JIT chunk. Default: 200.

Returns:

- float: Evaluation score.

PerplexityEvaluation

Bases: Evaluation

Evaluation that computes dataset perplexity and returns 1/ppl (higher is better).

Runs a blockwise forward pass like EncodingEvaluation, computes the mean negative log-likelihood over all non-padding tokens, and returns 1/perplexity.
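The returned score is simply the reciprocal of perplexity; restated in plain Python as a definition sketch (not library code):

```python
import math

def inverse_perplexity(token_nlls: list[float]) -> float:
    # Mean negative log-likelihood over non-padding tokens; exp(-mean) == 1 / perplexity.
    mean_nll = sum(token_nlls) / len(token_nlls)
    return math.exp(-mean_nll)
```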

get(indx: int) -> str abstractmethod

Get input string at index.

__call__(inference: InferenceJob[Any, M], encoding: Any, chunk_size: int = 200, **kwargs: Any) -> float

Run evaluation.

Parameters:

- inference (InferenceJob[Any, M]): InferenceJob instance for running inference. Required.
- encoding (Any): Tokenizer with an encode_batch method. Required.
- chunk_size (int): Number of batches per JIT chunk. Default: 200.

Returns:

- float: 1/perplexity (higher is better).

PerplexityComparisonEvaluation

Bases: Evaluation

Evaluation using perplexity comparison for multiple-choice tasks.
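The selection rule behind this kind of evaluation can be sketched as follows, assuming a per-continuation negative log-likelihood has already been computed (an illustration of the general technique, not this module's code):

```python
def pick_choice(continuation_nlls: list[float]) -> int:
    # The continuation with the lowest NLL (equivalently, lowest perplexity) is selected.
    return min(range(len(continuation_nlls)), key=lambda i: continuation_nlls[i])
```

Accuracy is then the fraction of samples for which the selected index equals the correct index.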

get(indx: int) -> Tuple[str, list[str], int] abstractmethod

Get sample at index.

Returns:

- Tuple[str, list[str], int]: (prefix, list_of_continuations, correct_index).

__call__(inference: InferenceJob[Any, M], encoding: Any, chunk_size: int = 200, **kwargs: Any) -> float

Run evaluation.

Parameters:

- inference (InferenceJob[Any, M]): InferenceJob instance for running inference. Required.
- encoding (Any): Tokenizer with encode/encode_batch methods. Required.
- chunk_size (int): Number of batches per JIT chunk. Default: 200.

Returns:

- float: Accuracy score.

EvaluatorConfig(evaluations: List[str] = field('eval/evaluations')) dataclass

Configuration for Evaluator.
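A hedged construction example; the evaluation names below are placeholders, since the set of registered evaluation names depends on the project:

```python
config = EvaluatorConfig(evaluations=["toy_rollout", "wikitext_ppl"])  # placeholder names
```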

Evaluator(spec: ExecutionSpec)

Bases: InferenceJob[EvaluatorConfig, M], Generic[M]

InferenceJob that runs evaluations and saves results.

Created from a trainer or checkpoint, holds a list of evaluations, and runs them when run() is called.

Example:

    evaluator = Evaluator.from_trainer(trainer, evaluations, encoding)
    evaluator()  # Runs evaluations and saves results

Direct initialization is not supported; use from_trainer() or from_checkpoint().

done: bool property

Check if evaluation results already exist.

from_trainer(trainer: BaseTrainer[Any, Any]) -> Evaluator[M] classmethod

Create Evaluator from trainer.

Parameters:

- trainer (BaseTrainer[Any, Any]): BaseTrainer instance to get inference state from. Required.

Returns:

- Evaluator[M]: Evaluator instance ready to run evaluations.

from_checkpoint(suffix: str | Path, spec: ExecutionSpec) -> Tuple[Evaluator[M], Any] classmethod

Create Evaluator from checkpoint.

Parameters:

- suffix (str | Path): Checkpoint suffix. Required.
- spec (ExecutionSpec): ExecutionSpec with topology. Required.

Returns:

- Tuple[Evaluator[M], Any]: (evaluator, config) tuple.
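A usage sketch combining from_checkpoint, done, and run(); the checkpoint suffix and spec are illustrative values, not defaults:

```python
# Illustrative values: "step_001000" and `spec` must come from your own setup.
evaluator, config = Evaluator.from_checkpoint("step_001000", spec)
if not evaluator.done:
    evaluator.run()  # runs all evaluations and saves results to disk
```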

run() -> None

Run all evaluations and save results to disk.