Config System¶
This document covers theseus/config.py — the lightweight, dataclass-driven configuration layer that ties together model architecture, training hyperparameters, and runtime settings into a single OmegaConf YAML.
Overview¶
| Primitive | What it does |
|---|---|
| field("path/key", default=...) | Marks a dataclass field as config-driven |
| build(*classes) | Collects all marked fields across a class tree and produces an OmegaConf schema |
| configuration(cfg) | Context manager that activates a loaded config |
| configure(cls) | Hydrates a dataclass instance from the active config |
| patch() | Context manager for temporarily mutating the active config |
field() — Declaring Config-Driven Fields¶
Any dataclass field annotated with field("path/key") participates in the config system:
from dataclasses import dataclass
from theseus.config import field

@dataclass
class SelfAttention(Module):
    n_embd: int = field("architecture/n_embd", default=2048)
    n_head: int = field("architecture/n_head", default=16)
    block_size: int = field("architecture/block_size", default=512)
    dropout: float = field("architecture/dropout", default=0.0)
Under the hood, field() is a thin wrapper around dataclasses.field() that stashes the path string in the field's metadata dict under the key "th_config_field". That metadata entry is the only marker that build() and configure() look for.
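The metadata mechanism described above can be sketched with the standard library alone. This is a minimal illustration, not the code in theseus/config.py: sketch_field and config_paths are hypothetical names, and the example class is a plain dataclass rather than a Module subclass.

```python
import dataclasses
from dataclasses import dataclass

CONFIG_KEY = "th_config_field"  # metadata key named in the text above


def sketch_field(path, default=dataclasses.MISSING):
    """Hypothetical stand-in for field(): wrap dataclasses.field() and
    record the config path in the field's metadata."""
    return dataclasses.field(default=default, metadata={CONFIG_KEY: path})


@dataclass
class SelfAttention:  # plain dataclass for illustration only
    n_embd: int = sketch_field("architecture/n_embd", default=2048)
    n_head: int = sketch_field("architecture/n_head", default=16)


def config_paths(cls):
    """Map field name -> config path for every marked field of a dataclass."""
    return {
        f.name: f.metadata[CONFIG_KEY]
        for f in dataclasses.fields(cls)
        if CONFIG_KEY in f.metadata
    }


paths = config_paths(SelfAttention)
```

Because the marker lives in ordinary dataclass metadata, any tool that can call dataclasses.fields() can discover the config paths without importing anything else.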
The path string uses / as a separator and maps directly to a nested YAML structure:
"architecture/n_embd" → architecture: { n_embd: ... }.
build() — Generating the OmegaConf Schema¶
build(*classes) collects config fields from every class you pass in (plus any dataclass-typed sub-fields) and returns a fully structured OmegaConf config with defaults filled in.
Internally this runs through three stages.
Stage 1 — DFS class expansion (generate_canonical_config)¶
build() calls generate_canonical_config(*classes). Before collecting fields it recursively expands any field whose type is itself a dataclass:
BaseTrainerConfig
└─ has field `optimizer_cfg: AdamWConfig` ← dataclass type → recurse
AdamWConfig
└─ lr: float = field("optimization/lr", ...)
This means you can compose arbitrarily nested config dataclasses and build() sees the full flat set of annotated fields.
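As a rough sketch of the expansion step, assuming fields are marked through the "th_config_field" metadata key: the helper names cfg_field and collect_fields are hypothetical, and the real generate_canonical_config may differ in traversal order and bookkeeping.

```python
import dataclasses
from dataclasses import dataclass

CONFIG_KEY = "th_config_field"


def cfg_field(path, default):
    """Hypothetical marker helper, standing in for theseus' field()."""
    return dataclasses.field(default=default, metadata={CONFIG_KEY: path})


@dataclass
class AdamWConfig:
    lr: float = cfg_field("optimization/lr", 3e-4)


@dataclass
class BaseTrainerConfig:
    batch_size: int = cfg_field("training/batch_size", 512)
    # Unmarked field whose type is itself a dataclass: triggers recursion.
    optimizer_cfg: AdamWConfig = dataclasses.field(default_factory=AdamWConfig)


def collect_fields(cls, seen=None):
    """Depth-first walk returning (path, type, default) for each marked field."""
    seen = set() if seen is None else seen
    if cls in seen:          # guard against visiting a class twice
        return []
    seen.add(cls)
    found = []
    for f in dataclasses.fields(cls):
        if CONFIG_KEY in f.metadata:
            found.append((f.metadata[CONFIG_KEY], f.type, f.default))
        if dataclasses.is_dataclass(f.type):
            found.extend(collect_fields(f.type, seen))
    return found


collected = collect_fields(BaseTrainerConfig)
```

The result is a flat list, so the later stages never need to know how deeply the config dataclasses were nested.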
Stage 2 — Key union (LUB)¶
Multiple classes can declare the same path key. This is intentional: SelfAttention, GroupedSelfAttention, and GPT all declare "architecture/n_embd" because they all need it at construction time. generate_canonical_config deduplicates by computing the least upper bound type:
- One class → use that type directly
- Multiple classes, same type → same type
- Multiple classes, different types → Union[t1, t2, ...]
For default values, the first non-None value wins. If two classes declare the same key with conflicting defaults, None is used and the field becomes ??? (mandatory-missing) in the generated YAML.
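The merge rules can be illustrated with a small helper. This sketch rests on one assumed reading of the default rule: a single distinct non-None default wins, and two different non-None defaults collapse to None. merge_declarations is a hypothetical name, not part of theseus.

```python
from typing import Union


def merge_declarations(decls):
    """Merge every (type, default) pair declared for ONE path key."""
    types = []
    for t, _ in decls:
        if t not in types:           # dedupe while preserving order
            types.append(t)
    merged_type = types[0] if len(types) == 1 else Union[tuple(types)]

    defaults = []
    for _, d in decls:
        if d is not None and d not in defaults:
            defaults.append(d)
    # Assumption: exactly one distinct non-None default is kept;
    # zero or conflicting defaults give None (rendered as ??? in the YAML).
    merged_default = defaults[0] if len(defaults) == 1 else None
    return merged_type, merged_default


same = merge_declarations([(int, 2048), (int, 2048)])
mixed = merge_declarations([(int, 512), (float, None)])
conflict = merge_declarations([(int, 1), (int, 2)])
```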
Stage 3 — Nested dict construction¶
nest_slash_keys() converts the flat {"architecture/n_embd": 2048, "training/lr": 3e-4} dict into a properly nested Python dict: {"architecture": {"n_embd": 2048}, "training": {"lr": 3e-4}}.
The nested dict is then handed to OmegaConf.create(), and struct mode is enabled on the result (for example via OmegaConf.set_struct(cfg, True)) so that unknown keys are rejected. This produces the final config object.
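The nesting step is simple enough to sketch in a few lines. The function below is an illustrative reimplementation under the name nest_slash_keys_sketch (hypothetical), not the library's own nest_slash_keys(), and it leaves out the OmegaConf step entirely.

```python
def nest_slash_keys_sketch(flat):
    """Turn {"a/b/c": value} pairs into nested dicts keyed by path segment."""
    nested = {}
    for path, value in flat.items():
        *parents, leaf = path.split("/")
        node = nested
        for part in parents:
            # Create intermediate dicts on demand, reusing existing ones.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


nested = nest_slash_keys_sketch(
    {
        "architecture/n_embd": 2048,
        "architecture/n_head": 16,
        "training/lr": 3e-4,
    }
)
```

Keys that share a prefix end up under the same parent dict, which is what gives the generated YAML its grouped sections.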
Generated YAML¶
The CLI command theseus configure <job> run.yaml calls build() on the job's registered config classes and writes the result:
architecture:
  n_embd: 2048
  n_head: 16
  block_size: 512
  dropout: 0.0
training:
  batch_size: 512
  per_device_batch_size: -1
  tokens: 1000000000
optimization:
  lr: 0.0003
Users edit this file (or pass key=value overrides on the CLI) before running.
configuration() / configure() — The Context Guard¶
Config values are activated via a context manager, never via global state that leaks across threads:
from omegaconf import OmegaConf
from theseus.config import configuration, configure

cfg = OmegaConf.load("run.yaml")
with configuration(cfg):
    attention = configure(SelfAttention)  # reads architecture.n_embd etc.
    trainer_args = configure(BaseTrainerConfig)
configuration(cfg) uses a ContextVar (_current_config) — Python's per-async-task / per-thread context isolation primitive. Setting the config in one coroutine or thread doesn't affect others.
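A minimal sketch of such a context guard, using only contextvars and contextlib. It assumes the variable defaults to None when no config is active, which the source does not state; the copied context below stands in for a separate asyncio task, which runs in a copy of its creator's context.

```python
from contextlib import contextmanager
from contextvars import ContextVar, copy_context

# Assumption: None means "no config is active".
_current_config = ContextVar("_current_config", default=None)


@contextmanager
def configuration(cfg):
    """Activate cfg for the duration of the block, then restore the old value."""
    token = _current_config.set(cfg)
    try:
        yield cfg
    finally:
        _current_config.reset(token)


cfg_a = {"name": "a"}
cfg_b = {"name": "b"}


def other_task():
    # Deliberately never reset: the change stays inside the copied context.
    _current_config.set(cfg_b)
    return _current_config.get()


with configuration(cfg_a):
    inside = _current_config.get()
    seen_by_other = copy_context().run(other_task)
    still_a = _current_config.get()

after = _current_config.get()
```

Resetting through the token, rather than setting the variable back to None, is what lets configuration() blocks nest safely.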
configure(cls) calls hydrate(cls, config) which:
- Flattens the OmegaConf object back to "a/b/c" → value pairs.
- Iterates over fields(cls), matching each th_config_field key to the flat config.
- For fields whose type is itself a dataclass, recurses via hydrate(sub_cls, config).
- For plain Module subclasses (Flax), calls cls(**init_kwargs) directly.
- For pure-Python dataclasses, uses OmegaConf.structured + OmegaConf.merge to get typed coercion.
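The first three steps above can be sketched on plain dicts and dataclasses. This is a simplified illustration: flatten, hydrate_sketch, and cfg_field are hypothetical names, the config is a nested dict rather than an OmegaConf object, and the typed coercion that OmegaConf.structured and OmegaConf.merge provide is omitted.

```python
import dataclasses
from dataclasses import dataclass

CONFIG_KEY = "th_config_field"


def cfg_field(path, default):
    """Hypothetical marker helper, standing in for theseus' field()."""
    return dataclasses.field(default=default, metadata={CONFIG_KEY: path})


def flatten(nested, prefix=""):
    """Flatten nested dicts into {"a/b/c": value} pairs."""
    flat = {}
    for key, value in nested.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def hydrate_sketch(cls, config):
    """Build cls from config, recursing into dataclass-typed fields."""
    flat = flatten(config)
    kwargs = {}
    for f in dataclasses.fields(cls):
        path = f.metadata.get(CONFIG_KEY)
        if path is not None and path in flat:
            kwargs[f.name] = flat[path]
        elif dataclasses.is_dataclass(f.type):
            kwargs[f.name] = hydrate_sketch(f.type, config)
        # Anything else falls back to the dataclass default.
    return cls(**kwargs)


@dataclass
class AdamWConfig:
    lr: float = cfg_field("optimization/lr", 3e-4)


@dataclass
class BaseTrainerConfig:
    batch_size: int = cfg_field("training/batch_size", 512)
    optimizer_cfg: AdamWConfig = dataclasses.field(default_factory=AdamWConfig)


trainer = hydrate_sketch(
    BaseTrainerConfig,
    {"training": {"batch_size": 64}, "optimization": {"lr": 0.001}},
)
```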
Inside a Flax setup()¶
configure() is frequently called inside nn.Module.setup():
class GPT(Module):
    def setup(self):
        self.blocks = [configure(TransformerBlock) for _ in range(self.n_layers)]
        self.ln_f = configure(LayerNorm)
This works because setup() is called inside model.init(...), which is always inside a with configuration(cfg): block established by the trainer.
patch() — Temporary Config Mutations¶
patch() is a context manager that temporarily disables struct-mode, allowing free-form additions:
import jax
from theseus.config import patch

# model and dummy_input are assumed to be defined by the caller
with patch() as cfg:
    cfg.architecture.n_layers = 32
    cfg.architecture.attention_bias = False
    model.init(jax.random.PRNGKey(0), dummy_input)
If a configuration() context is already active, patch() mutates it directly (and re-seals it on exit). If there's no active context, it creates a fresh empty config scoped to the block. This is used by from_pretrained()-style loaders to inject architecture fields read from a checkpoint into the active config.
End-to-End Flow¶
theseus configure gpt/pretrain run.yaml
│
▼
job_obj.config() # returns [BaseTrainerConfig, GPT, ...]
│
build(*classes)
│ DFS expand dataclass-typed fields
│ collect all th_config_field metadata
│ union types / pick defaults (LUB)
│ nest_slash_keys()
│ OmegaConf.create(...) + set_struct=True
▼
run.yaml written
─────────────────────────────────────────────────
theseus run my-run run.yaml ./output
│
▼
cfg = OmegaConf.load("run.yaml")
with configuration(cfg):
    job_cls(spec)()
│
▼
trainer.__init__()
│ configure(BaseTrainerConfig) → hydrates training args
│ configure(GPT) → creates the model
└ configure(SelfAttention) → called inside GPT.setup()