ppo
On-policy PPO trainer.
Custom train state caches a frozen reference policy (snapshotted at init) so
the loss can apply a KL penalty against it. The training step is unchanged
relative to BaseTrainer — the loop calls self.batch() per step, which we
override to (a) roll out the current policy via the existing Evaluator
machinery for every RL component, (b) build an (N, k) per-rollout per-component
reward matrix where each row's source-component column holds that component's
score and other columns are zero, (c) optionally transform per-component via
self.reward_postprocess(R), (d) gather the per-rollout scalar from each row's
source column and smear it per-token via reward-to-go, and (e) cache
old_log_probs for the importance ratio.
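A minimal NumPy sketch of step (d). It assumes each rollout's scalar reward is credited to its last response token and then propagated backward as a (possibly discounted) reward-to-go. The function name `reward_to_go`, the `response_mask` layout, and the `gamma` parameter are illustrative assumptions, not this module's API.

```python
import numpy as np


def reward_to_go(rewards: np.ndarray, response_mask: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Spread one scalar reward per rollout over its response tokens (hypothetical helper).

    rewards:       (N,) per-rollout scalars, e.g. R[n, source[n]].
    response_mask: (N, T) with 1 on response tokens and 0 on prompt/padding.
    gamma:         assumed discount; 1.0 gives every response token the full reward.
    """
    n, t = response_mask.shape
    mask = response_mask.astype(bool)
    # Credit each rollout's reward to its last response token.
    per_token = np.zeros((n, t), dtype=np.float64)
    has_response = mask.any(axis=1)
    last = t - 1 - np.argmax(mask[:, ::-1], axis=1)
    rows = np.nonzero(has_response)[0]
    per_token[rows, last[rows]] = rewards[rows]
    # Reward-to-go: G_t = r_t + gamma * G_{t+1}, accumulated right to left.
    rtg = np.zeros_like(per_token)
    running = np.zeros(n)
    for j in range(t - 1, -1, -1):
        running = per_token[:, j] + gamma * running
        rtg[:, j] = running
    # Zero prompt and padding positions so only response tokens carry signal.
    return rtg * mask
```

With `gamma=1.0`, every response token of rollout n receives the same value `rewards[n]`, which is the "smear" described above. For example, the mask `[[0, 1, 1, 0], [0, 0, 1, 1]]` with rewards `[1.0, 2.0]` yields `[[0, 1, 1, 0], [0, 0, 2, 2]]`.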
Subclasses (e.g. GRPOTrainer) customize advantage computation by overriding
_advantages_from_rewards.
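As an illustration of such an override, GRPO-style trainers commonly replace a learned baseline with group-relative normalization: each rollout's reward is standardized against the other rollouts sampled for the same prompt. The sketch below assumes rollouts arrive as contiguous groups of `group_size`. Its name and signature are hypothetical and need not match what `_advantages_from_rewards` actually receives.

```python
import numpy as np


def group_relative_advantages(rewards: np.ndarray, group_size: int, eps: float = 1e-6) -> np.ndarray:
    """Hypothetical GRPO-style advantage: (r - group mean) / (group std + eps).

    rewards: (N,) per-rollout scalars. Assumes N is divisible by group_size and
             each consecutive block of group_size rollouts shares one prompt.
    """
    if rewards.shape[0] % group_size != 0:
        raise ValueError("number of rollouts must be a multiple of group_size")
    groups = rewards.reshape(-1, group_size)
    mean = groups.mean(axis=1, keepdims=True)
    std = groups.std(axis=1, keepdims=True)
    # A group whose rewards are all equal gets zero advantage (0 / eps).
    return ((groups - mean) / (std + eps)).reshape(-1)
```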
PPOTrainer(spec: ExecutionSpec)
Bases: BaseTrainer[BaseTrainerConfig, M], Generic[M]
PPO trainer scaffold.
reward_postprocess(scores: np.ndarray) -> np.ndarray
Optional per-component reward transformation. Shape-preserving (N, k) -> (N, k). Most subclasses do NOT need to override this — the default identity is correct whenever each component's raw score is already the desired training signal.
scores[n, i] is component i's score on rollout n. Components
have disjoint datasets, so each rollout is produced and scored by exactly
one component: only the source column is populated per row, other columns
are zero (structural padding so the matrix has uniform shape, not a
signal). Column order matches RLEvaluatorConfig.components order —
i.e. the order of self.rl_inference.evaluations.
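The (N, k) layout can be sketched in NumPy: each row has one populated column, indexed by the rollout's source component. The sketch also shows why per-component statistics must be masked by source, because a plain column mean would average in the zero padding from other components' rows. `build_score_matrix` and `per_component_means` are hypothetical helpers, and this masking is an assumption about the general idea, not a copy of forward().

```python
import numpy as np


def build_score_matrix(source: np.ndarray, scores: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: scatter each rollout's score into its source column.

    source: (N,) int component index per rollout (column order = components order).
    scores: (N,) raw score assigned by that rollout's own component.
    """
    n = source.shape[0]
    matrix = np.zeros((n, k), dtype=np.float64)
    matrix[np.arange(n), source] = scores  # non-source columns stay zero (padding)
    return matrix


def per_component_means(matrix: np.ndarray, source: np.ndarray) -> np.ndarray:
    """Mean score per component, over only the rollouts that component produced."""
    k = matrix.shape[1]
    mask = source[:, None] == np.arange(k)[None, :]  # (N, k), True in each row's source column
    counts = mask.sum(axis=0)
    sums = np.where(mask, matrix, 0.0).sum(axis=0)
    # Components with no rollouts in this batch report NaN instead of a misleading 0.
    return np.where(counts > 0, sums / np.maximum(counts, 1), np.nan)
```

For `source = [0, 1, 0]` and scores `[1.0, 0.5, 3.0]`, the masked mean for component 0 is 2.0, while an unmasked mean of column 0 (`[1, 0, 3]`) would give 4/3.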
Override only to scale, clip, or otherwise transform individual channels
(e.g. scores[:, 0] *= 2.0 to upweight the first component, or to
clip a noisy reward to [0, 1]). Must preserve the zero-padding for
non-source columns if you want per-component logging means to remain
meaningful — see the source-mask trick in forward().
Called from _refill_buffer outside the gradient path; subclasses
can freely read instance state here. After this returns, the per-
rollout scalar fed to PPO is just R[n, source[n]] — there is no
cross-channel aggregation.
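An override along the lines described above might look like the following. It is written as a standalone function so it runs without the trainer; in a subclass it would be a method taking `self`. The specific columns and constants are illustrative. Both transforms map 0 to 0, so the non-source padding survives. The final helper is the R[n, source[n]] lookup, and its name is hypothetical.

```python
import numpy as np


def reward_postprocess(scores: np.ndarray) -> np.ndarray:
    """Example override: shape-preserving (N, k) -> (N, k) transform."""
    out = scores.astype(np.float64, copy=True)  # do not mutate the caller's array
    out[:, 0] *= 2.0  # upweight component 0; padding zeros stay 0
    out[:, 1] = np.clip(out[:, 1], 0.0, 1.0)  # clip a noisy reward; 0 stays inside [0, 1]
    return out


def per_rollout_reward(R: np.ndarray, source: np.ndarray) -> np.ndarray:
    """Pick each row's source column; no cross-channel aggregation."""
    return R[np.arange(R.shape[0]), source]
```

By contrast, a transform such as adding a constant offset to a column would turn that column's padding into nonzero values and skew per-component logging means.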
log_prob_step(state: train_state.TrainState, batch: PyTree[jax.Array]) -> jax.Array
staticmethod
Scan log π(y_t | x_≤t) over S micro-batches, mirroring the shape of val_step.
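A NumPy sketch of the per-token log-probability computation, with a Python loop standing in for the scan over S micro-batches. It assumes logits arrive as (S, B, T, V) and that the logits at position t score the token at position t + 1 (the usual causal-LM shift). The model call, exact shapes, and masking inside log_prob_step may differ.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def per_token_log_probs(logits: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """logits: (S, B, T, V); tokens: (S, B, T) ints. Returns (S, B, T - 1).

    Entry [s, b, t] is log pi(tokens[s, b, t + 1] | tokens[s, b, : t + 1]).
    """
    outputs = []
    for s in range(logits.shape[0]):  # stands in for a scan over micro-batches
        logp = log_softmax(logits[s, :, :-1, :])  # predictions for positions 1..T-1
        targets = tokens[s, :, 1:]
        # Select the log-prob of the token that actually came next.
        outputs.append(np.take_along_axis(logp, targets[..., None], axis=-1)[..., 0])
    return np.stack(outputs)
```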
batch(slice: str = 'train') -> PyTree[np.ndarray]
Generate one PPO training batch by reusing the Evaluator rollout machinery.
Returns a numpy dict shaped to feed straight into the parent train()'s
_to_global(_reshape_batch(...)).
forward(state: train_state.TrainState, params: PyTree[jax.Array], batch: PyTree[jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, intermediates: bool = False) -> Any
staticmethod
PPO clipped surrogate loss + KL penalty against the frozen reference.
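The loss can be sketched in NumPy as follows, assuming per-token arrays of current, old (cached at rollout time), and frozen-reference log-probs, plus per-token advantages and a response mask. The clip range `clip_eps`, the coefficient `kl_coef`, the masked token-mean normalization, and the choice of the non-negative "k3" estimator exp(Δ) - Δ - 1 with Δ = log π_ref - log π are assumptions. The trainer may use a different estimator, coefficient, or reduction.

```python
import numpy as np


def ppo_kl_loss(
    logp: np.ndarray,
    old_logp: np.ndarray,
    ref_logp: np.ndarray,
    advantages: np.ndarray,
    mask: np.ndarray,
    clip_eps: float = 0.2,
    kl_coef: float = 0.05,
) -> float:
    """Masked token-mean of the negated clipped surrogate plus a KL penalty."""
    ratio = np.exp(logp - old_logp)  # importance ratio pi_theta / pi_old
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    surrogate = np.minimum(unclipped, clipped)  # pessimistic (lower) bound
    delta = ref_logp - logp
    kl = np.exp(delta) - delta - 1.0  # k3 estimate of KL(pi_theta || pi_ref), always >= 0
    per_token = -surrogate + kl_coef * kl  # minimize: maximize surrogate, penalize drift
    return float((per_token * mask).sum() / np.maximum(mask.sum(), 1.0))
```

With a ratio of 1.5 and positive advantages, the clip caps the per-token objective at 1.2 × advantage. With negative advantages, the minimum keeps the unclipped (more pessimistic) term.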
BackbonedPPOTrainer(spec: ExecutionSpec)
Bases: BackbonedTrainer, PPOTrainer[Module]
PPO trainer that initializes from a pretrained HuggingFace backbone.
Mirrors BackbonedContrastiveTrainer: pulls in BackbonedTrainer's
_init_model (loads HF weights, no cls.MODEL.gather()) plus PPOTrainer's
state/data/forward overrides. The MRO resolves cleanly because both
parents only override disjoint methods.
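The MRO argument can be illustrated with stand-in classes. These reuse the trainer names but are stubs, not the real implementations, and they assume both parents derive from a common BaseTrainer. Because BackbonedTrainer and PPOTrainer override disjoint methods, the combined class takes _init_model from the first parent and batch/forward from the second.

```python
class BaseTrainer:
    def _init_model(self):
        return "BaseTrainer._init_model"

    def batch(self):
        return "BaseTrainer.batch"

    def forward(self):
        return "BaseTrainer.forward"


class BackbonedTrainer(BaseTrainer):
    def _init_model(self):  # stub: the real one loads pretrained HF weights
        return "BackbonedTrainer._init_model"


class PPOTrainer(BaseTrainer):
    def batch(self):  # stub: the real one rolls out the current policy
        return "PPOTrainer.batch"

    def forward(self):  # stub: the real one computes the PPO loss
        return "PPOTrainer.forward"


class BackbonedPPOTrainer(BackbonedTrainer, PPOTrainer):
    pass


trainer = BackbonedPPOTrainer()
print([c.__name__ for c in BackbonedPPOTrainer.__mro__])
# ['BackbonedPPOTrainer', 'BackbonedTrainer', 'PPOTrainer', 'BaseTrainer', 'object']
```

If both parents overrode the same method, the one listed first in the bases would silently win, which is why the disjointness matters.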