
ppo


On-policy PPO trainer.

A custom train state caches a frozen reference policy (snapshotted at init) so that the loss can apply a KL penalty against it. The training step itself is unchanged relative to BaseTrainer: the loop calls self.batch() once per step, which we override to:
(a) roll out the current policy via the existing Evaluator machinery for every RL component;
(b) build an (N, k) per-rollout, per-component reward matrix in which each row's source-component column holds that component's score and all other columns are zero;
(c) optionally transform each component via self.reward_postprocess(R);
(d) gather each rollout's scalar from its row's source column and distribute it over tokens via reward-to-go;
(e) cache old_log_probs for the importance ratio.
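Steps (b) and (d) can be sketched in NumPy as below. This is a minimal illustration, not the trainer's code: the function names are hypothetical, and it assumes the per-rollout scalar is placed on each response's last valid token (given by a 0/1 mask) before reward-to-go is accumulated with a discount gamma. With gamma = 1, every valid token receives the rollout's full reward.

```python
import numpy as np


def build_reward_matrix(scores: np.ndarray, source: np.ndarray, k: int) -> np.ndarray:
    """(N,) scores + (N,) source ids -> (N, k) matrix, zero outside the source column."""
    n = scores.shape[0]
    R = np.zeros((n, k), dtype=np.float64)
    R[np.arange(n), source] = scores
    return R


def gather_source(R: np.ndarray, source: np.ndarray) -> np.ndarray:
    """Per-rollout scalar R[n, source[n]]; no cross-channel aggregation."""
    return np.take_along_axis(R, source[:, None], axis=1)[:, 0]


def reward_to_go(reward: np.ndarray, mask: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Put reward[n] on row n's last valid token, then G_t = sum_{t' >= t} gamma^(t'-t) r_t'."""
    n, t = mask.shape
    valid = mask.astype(bool)
    # Index of the last valid token in each row (assumes every row has one).
    last = t - 1 - np.argmax(valid[:, ::-1], axis=1)
    per_token = np.zeros((n, t))
    per_token[np.arange(n), last] = reward
    rtg = np.zeros((n, t))
    running = np.zeros(n)
    for i in range(t - 1, -1, -1):  # accumulate backwards in time
        running = per_token[:, i] + gamma * running
        rtg[:, i] = running
    return rtg * valid  # zero out prompt/padding positions


scores = np.array([0.8, 0.3, 1.0])
source = np.array([0, 1, 0])  # component that produced each rollout
mask = np.array([[1, 1, 1, 0],
                 [1, 1, 0, 0],
                 [1, 1, 1, 1]])
R = build_reward_matrix(scores, source, k=2)
token_rewards = reward_to_go(gather_source(R, source), mask)
```

Here token_rewards is [[0.8, 0.8, 0.8, 0], [0.3, 0.3, 0, 0], [1, 1, 1, 1]]: each rollout's score reaches every one of its valid tokens.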

Subclasses (e.g. GRPOTrainer) customize advantage computation by overriding _advantages_from_rewards.
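As an illustration of what such an override might compute, here is one common GRPO formulation (group-relative advantages) sketched in NumPy. The function name and signature are hypothetical, since the real _advantages_from_rewards signature is not documented here, and the sketch assumes the rollouts for each prompt are stored contiguously in groups of group_size.

```python
import numpy as np


def grpo_advantages(rewards: np.ndarray, group_size: int, eps: float = 1e-6) -> np.ndarray:
    """Hypothetical helper: (r - group mean) / (group std + eps) for each rollout.

    Assumes rewards has shape (N,), N is divisible by group_size, and the
    group_size rollouts sampled for one prompt are contiguous.
    """
    groups = rewards.reshape(-1, group_size)
    mean = groups.mean(axis=1, keepdims=True)
    std = groups.std(axis=1, keepdims=True)
    return ((groups - mean) / (std + eps)).reshape(-1)


adv = grpo_advantages(np.array([1.0, 0.0, 2.0, 2.0]), group_size=2)
```

Under this formula a group in which every rollout earns the same reward gets zero advantage, so it contributes no policy-gradient signal.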

PPOTrainer(spec: ExecutionSpec)

Bases: BaseTrainer[BaseTrainerConfig, M], Generic[M]

PPO trainer scaffold.

reward_postprocess(scores: np.ndarray) -> np.ndarray

Optional per-component reward transformation, shape-preserving: (N, k) -> (N, k). Most subclasses do NOT need to override this; the default identity is correct whenever each component's raw score is already the desired training signal.

scores[n, i] is component i's score on rollout n. Components have disjoint datasets, so each rollout is produced and scored by exactly one component. Only the source column is populated in each row; the other columns are zero (structural padding that gives the matrix a uniform shape, not a signal). Column order matches the order of RLEvaluatorConfig.components, i.e. the order of self.rl_inference.evaluations.
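Because only the source column carries signal, a per-component mean has to divide by the number of rollouts each component actually produced, not by N. The following is a hypothetical sketch of such a masked mean, assuming the source indices are available as an (N,) integer array; the actual bookkeeping in the trainer may differ.

```python
import numpy as np


def per_component_means(R: np.ndarray, source: np.ndarray) -> np.ndarray:
    """Mean score of each component over the rollouts it produced (NaN if none)."""
    n, k = R.shape
    src_mask = np.zeros((n, k))
    src_mask[np.arange(n), source] = 1.0  # 1 only in each row's source column
    counts = src_mask.sum(axis=0)
    # Summing R directly is only correct because non-source columns are zero.
    sums = R.sum(axis=0)
    return np.where(counts > 0, sums / np.maximum(counts, 1.0), np.nan)


R = np.array([[0.8, 0.0], [0.0, 0.3], [1.0, 0.0]])
source = np.array([0, 1, 0])
means = per_component_means(R, source)  # [0.9, 0.3]
```

A naive R.mean(axis=0) would instead give [0.6, 0.1], diluted by the padding rows.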

Override this only to scale, clip, or otherwise transform individual channels (e.g. scores[:, 0] *= 2.0 to upweight the first component, or clipping a noisy reward to [0, 1]). Overrides must preserve the zero padding in non-source columns if the per-component logging means are to remain meaningful; see the source-mask trick in forward().
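A sketch of an override that follows these rules, written as a free function so it is self-contained (in a subclass it would be the reward_postprocess method). It assumes k >= 2 and that component 1 is the noisy one. Both transforms map 0 to 0, so the padding survives.

```python
import numpy as np


def reward_postprocess(scores: np.ndarray) -> np.ndarray:
    """Example override: upweight component 0, clip noisy component 1 to [0, 1]."""
    out = scores.astype(np.float64)  # astype copies, so the caller's array is untouched
    out[:, 0] *= 2.0  # 2 * 0 == 0: padding preserved
    out[:, 1] = np.clip(out[:, 1], 0.0, 1.0)  # clip(0) == 0: padding preserved
    return out


scores = np.array([[0.8, 0.0],
                   [0.0, 1.7],
                   [0.0, -0.2]])
R = reward_postprocess(scores)  # [[1.6, 0.0], [0.0, 1.0], [0.0, 0.0]]
```

By contrast, a shift such as out - 0.5 would turn padding zeros into -0.5 and corrupt the per-component means, since reward_postprocess does not receive the source indices needed to undo it.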

Called from _refill_buffer, outside the gradient path, so subclasses can freely read instance state here. After this returns, the per-rollout scalar fed to PPO is simply R[n, source[n]]; there is no cross-channel aggregation.

log_prob_step(state: train_state.TrainState, batch: PyTree[jax.Array]) -> jax.Array staticmethod

Computes log π(y_t | x_{≤t}) by scanning over S micro-batches; mirrors val_step's shape.
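The computation can be sketched in NumPy, with a Python loop standing in for the scan. logits_fn is a hypothetical stand-in for the model's apply function, and the sketch assumes each row of tokens holds T + 1 ids, so input position t predicts the next token (y_t is token t + 1).

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    z = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def token_log_probs(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """(B, T, V) logits, (B, T) targets -> (B, T) log-probs of the targets."""
    return np.take_along_axis(log_softmax(logits), targets[..., None], axis=-1)[..., 0]


def log_prob_step_sketch(logits_fn, tokens: np.ndarray, num_micro: int) -> np.ndarray:
    """(N, T + 1) token ids -> (S, N // S, T) per-token log-probs, S = num_micro."""
    n = tokens.shape[0]
    if n % num_micro:
        raise ValueError("batch size must be divisible by num_micro")
    micro = tokens.reshape(num_micro, n // num_micro, tokens.shape[1])
    out = []
    for mb in micro:  # stands in for the scan over micro-batches
        inputs, targets = mb[:, :-1], mb[:, 1:]
        out.append(token_log_probs(logits_fn(inputs), targets))
    return np.stack(out)


# Toy bigram "model": next-token logits depend only on the current token.
rng = np.random.default_rng(0)
V = 5
W = rng.normal(size=(V, V))
tokens = rng.integers(0, V, size=(4, 6))
lp = log_prob_step_sketch(lambda x: W[x], tokens, num_micro=2)
```

Splitting into micro-batches only bounds peak memory: reshaping lp back to (N, T) gives the same values as a single full-batch pass.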

batch(slice: str = 'train') -> PyTree[np.ndarray]

Generate one PPO training batch by reusing the Evaluator dynamics.

Returns a dict of NumPy arrays shaped to feed directly into the parent train()'s _to_global(_reshape_batch(...)).

forward(state: train_state.TrainState, params: PyTree[jax.Array], batch: PyTree[jax.Array], key: Optional[jax.Array] = None, deterministic: bool = False, intermediates: bool = False) -> Any staticmethod

PPO clipped surrogate loss plus a KL penalty against the frozen reference policy.
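A NumPy sketch of this loss, under stated assumptions: clip_eps and kl_coef are illustrative defaults, the KL term uses the non-negative "k3" estimator exp(r) - r - 1 with r = log π_ref - log π, and the loss is averaged over valid (masked-in) tokens. The trainer's actual coefficients, estimator, and normalization may differ.

```python
import numpy as np


def ppo_loss(logp, old_logp, ref_logp, advantages, mask, clip_eps=0.2, kl_coef=0.05):
    """Masked mean of -clipped_surrogate + kl_coef * KL; all inputs shaped (N, T)."""
    ratio = np.exp(logp - old_logp)  # importance ratio pi / pi_old
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    surrogate = np.minimum(unclipped, clipped)  # pessimistic of the two
    log_r = ref_logp - logp
    kl = np.exp(log_r) - log_r - 1.0  # k3 estimator, >= 0 per token
    per_token = -surrogate + kl_coef * kl
    return float((per_token * mask).sum() / max(mask.sum(), 1.0))


z = np.zeros((1, 3))
A = np.array([[1.0, 2.0, 0.0]])
m = np.array([[1.0, 1.0, 0.0]])
loss = ppo_loss(z, z, z, A, m)  # on-policy, at the reference: -(1 + 2) / 2
```

Taking the minimum of the clipped and unclipped terms means the objective stops rewarding moves of the ratio beyond 1 ± clip_eps in the direction the advantage favors.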

BackbonedPPOTrainer(spec: ExecutionSpec)

Bases: BackbonedTrainer, PPOTrainer[Module]

PPO trainer that initializes from a pretrained HuggingFace backbone.

Mirrors BackbonedContrastiveTrainer: it combines BackbonedTrainer's _init_model (which loads HF weights without calling cls.MODEL.gather()) with PPOTrainer's state, data, and forward overrides. The MRO resolves cleanly because the two parents override disjoint sets of methods.
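The MRO argument can be checked with stub classes. These are hypothetical stand-ins that keep only the override pattern (the generic type parameters and real method bodies are dropped):

```python
class BaseTrainer:
    def _init_model(self) -> str:
        return "base init"

    def forward(self) -> str:
        return "base loss"


class BackbonedTrainer(BaseTrainer):
    def _init_model(self) -> str:  # stands in for loading pretrained HF weights
        return "load HF weights"


class PPOTrainer(BaseTrainer):
    def forward(self) -> str:  # stands in for the PPO clipped surrogate + KL
        return "ppo loss"


class BackbonedPPOTrainer(BackbonedTrainer, PPOTrainer):
    pass


trainer = BackbonedPPOTrainer()
print([c.__name__ for c in BackbonedPPOTrainer.__mro__])
# ['BackbonedPPOTrainer', 'BackbonedTrainer', 'PPOTrainer', 'BaseTrainer', 'object']
```

If both parents overrode the same method, the MRO would still resolve, but silently to BackbonedTrainer's version because it is listed first; the disjointness is what makes the combination unambiguous.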