Merged
Changes from 1 commit
Commits
20 commits
7c9e41d
fix(rollout_corr): compute metrics in actor for bypass mode and fix t…
szrlee Nov 10, 2025
96ae2be
docs(rollout_corr): move to algo/ and add pure_rs preset
szrlee Nov 10, 2025
c0ea9bd
feat(rollout_corr): add batch normalization option for IS weights
szrlee Nov 10, 2025
7de6c5f
docs(rollout_corr_math): use REINFORCE in aggregation loss examples f…
szrlee Nov 10, 2025
2b34cfe
refactor(rollout_corr): simplify metrics computation by removing unus…
szrlee Nov 10, 2025
0c42f85
docs(rollout_corr): add prominent cross-references between usage and …
szrlee Nov 10, 2025
fef8a48
docs(rollout_corr_math): add dedicated section for batch normalization
szrlee Nov 10, 2025
08cc9c7
fix: docstring of compute_policy_loss_with_rollout_correction
tongyx361 Nov 11, 2025
437a4ab
feat: reuse need_recomputation instead of bypass_mode
tongyx361 Nov 11, 2025
5f9a53b
feat: improve comments
tongyx361 Nov 11, 2025
b2f6370
feat: improve comments
tongyx361 Nov 11, 2025
79cdbf2
feat: refactor bypass_recomputing_logprobs
tongyx361 Nov 11, 2025
62e3270
feat(rollout_corr): align batch normalization with IS aggregation level
szrlee Nov 11, 2025
b5c19ff
docs(rollout_corr): rename decoupled mode presets for clarity and upd…
szrlee Nov 11, 2025
11f9aa0
fix(rollout_corr): correct metrics computation to run in decoupled mo…
szrlee Nov 11, 2025
58565cb
docs(rollout_corr): rename presets for clarity and consistency
szrlee Nov 11, 2025
8bb1a0e
refactor(rollout_corr): rename config vars for semantic clarity
szrlee Nov 11, 2025
6002c00
refactor(rollout_corr): update implementation to use renamed config v…
szrlee Nov 11, 2025
7f9ba9c
Merge branch 'main' into pr/szrlee/4070
tongyx361 Nov 11, 2025
56f69bf
fix: ppo_trainer config format
tongyx361 Nov 11, 2025
fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs

Fix three critical issues in rollout correction metrics computation:

1. Missing rollout_corr_config parameter in the compute_rollout_correction_and_add_to_batch()
   call at ray_trainer.py line 1178

2. Trainer computes meaningless metrics in bypass mode since old=rollout,
   resulting in KL≈0 and weights≈1.0 that don't reflect actual drift

3. No metrics computed for bypass+non-pure mode during actor training
   (bypass+pure already computes metrics in the pure loss function)

Solution:
- Add compute_rollout_corr_metrics_from_logprobs() helper function to compute
  metrics using current policy vs rollout policy log probabilities
- Always pass rollout_correction config to actor in bypass mode for metrics
- Skip trainer metrics in bypass mode, compute meaningful metrics in actor
- Actor computes per-microbatch metrics showing drift as training progresses

Behavior by mode (see the dispatch sketch after this list):
- Bypass+non-pure: Actor computes metrics (π_current vs π_rollout)
- Bypass+pure: Pure loss function computes metrics internally
- Decoupled: Trainer computes metrics (π_old vs π_rollout)
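
The dispatch between these modes is summarized below. This is a minimal illustrative sketch only: who_computes_metrics is a hypothetical name, and the real checks are spread across ray_trainer.py, rollout_corr_helper.py, and the actor workers.

def who_computes_metrics(rollout_corr_config: dict, loss_mode: str) -> str:
    # Hypothetical summary of where rollout_corr metrics are produced in each mode.
    bypass = rollout_corr_config.get("bypass_old_logprob_for_rollout", False)
    if not bypass:
        # Decoupled: trainer compares pi_old vs pi_rollout once per batch
        return "trainer: compute_rollout_correction_and_add_to_batch"
    if loss_mode == "rollout_correction":
        # Bypass + pure: the pure loss function returns metrics inside pg_metrics
        return "loss fn: compute_policy_loss_with_rollout_correction"
    # Bypass + non-pure: actor compares pi_current vs pi_rollout per micro-batch
    return "actor: compute_rollout_corr_metrics_from_logprobs"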

Files changed:
- verl/trainer/ppo/rollout_corr_helper.py: Add metrics helper (usage sketch below), always pass config
- verl/trainer/ppo/ray_trainer.py: Fix missing param, skip bypass metrics
- verl/workers/actor/dp_actor.py: Add rollout_log_probs selection, compute metrics
- verl/workers/actor/megatron_actor.py: Add rollout_log_probs selection, compute metrics
- verl/trainer/ppo/core_algos.py: Remove outdated documentation
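
For reference, a minimal standalone call to the new helper; the tensor shapes and values are purely illustrative, and the dict config mirrors the defaults the helper falls back to:

import torch

from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs

# Illustrative inputs: batch of 2 responses, 4 tokens each
log_prob = -torch.rand(2, 4)          # current policy log-probs
rollout_log_prob = -torch.rand(2, 4)  # rollout policy log-probs
response_mask = torch.ones(2, 4)      # every token treated as valid

metrics = compute_rollout_corr_metrics_from_logprobs(
    log_prob=log_prob,
    rollout_log_prob=rollout_log_prob,
    response_mask=response_mask,
    rollout_corr_config={"rollout_is": "token", "rollout_is_threshold": 2.0},
)
# Keys carry the "rollout_corr/" prefix (off-policy KL/PPL diagnostics plus IS-weight stats)
print(sorted(metrics))
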
szrlee committed Nov 10, 2025
commit 7c9e41daa02f14c96facf52651cb944d581c9219
4 changes: 0 additions & 4 deletions verl/trainer/ppo/core_algos.py
@@ -1630,10 +1630,6 @@ def compute_policy_loss_with_rollout_correction(
rollout_rs_threshold: 2.0
rollout_rs_threshold_lower: 0.5

Performance:
- Memory: Saves ~1MB per batch (no old_log_prob storage)
- Speed: ~15-20% faster (skips actor.compute_log_prob())
- Variance: Higher than PPO (no clipping safety net)
"""
# Import rollout correction helper
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask
12 changes: 9 additions & 3 deletions verl/trainer/ppo/ray_trainer.py
@@ -1171,10 +1171,16 @@ def fit(self):
# Compute rollout correction weights centrally (once per batch)
# This corrects for off-policy issues (policy mismatch, model staleness, etc.)
# Also computes off-policy diagnostic metrics (KL, PPL, etc.)
# Skip in bypass mode since old=rollout (IS weights=1.0, KL=0, rejection does nothing)
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)
bypass_mode = rollout_corr_config.get("bypass_old_logprob_for_rollout", False)
if not bypass_mode:
# Compute IS weights, apply rejection sampling, compute metrics
batch, is_metrics = compute_rollout_correction_and_add_to_batch(
batch, rollout_corr_config
)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)

# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
70 changes: 69 additions & 1 deletion verl/trainer/ppo/rollout_corr_helper.py
@@ -829,6 +829,71 @@ def compute_rollout_correction_and_add_to_batch(
return batch, rollout_corr_metrics


def compute_rollout_corr_metrics_from_logprobs(
log_prob: torch.Tensor,
rollout_log_prob: torch.Tensor,
response_mask: torch.Tensor,
rollout_corr_config: Optional[RolloutCorrectionConfig | dict] = None,
) -> dict[str, float]:
"""Compute rollout correction metrics from log probabilities during training.

This function is used in the actor to compute metrics using the CURRENT policy
log probabilities versus rollout log probabilities, allowing tracking of the
off-policy gap as training progresses.

It computes two categories of metrics:
1. Off-policy diagnostics (KL, PPL, χ²) from log probabilities
2. IS weight statistics (mean, std, ESS, etc.) computed fresh from log_probs

Args:
log_prob: Current policy log probabilities, shape (batch_size, seq_length)
rollout_log_prob: Rollout policy log probabilities, shape (batch_size, seq_length)
response_mask: Valid token mask, shape (batch_size, seq_length)
rollout_corr_config: Rollout correction config containing rollout_is mode and threshold

Returns:
Dictionary of metrics with "rollout_corr/" prefix
"""
# Compute off-policy diagnostic metrics
offpolicy_metrics = compute_offpolicy_metrics(
old_log_prob=log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
)

# Determine rollout_is mode and threshold from config
if rollout_corr_config is not None:
if isinstance(rollout_corr_config, dict):
rollout_is_mode = rollout_corr_config.get("rollout_is", "token")
rollout_is_threshold = rollout_corr_config.get("rollout_is_threshold", 2.0)
else:
rollout_is_mode = getattr(rollout_corr_config, "rollout_is", "token")
rollout_is_threshold = getattr(rollout_corr_config, "rollout_is_threshold", 2.0)
else:
rollout_is_mode = "token"
rollout_is_threshold = 2.0

# Compute IS weights fresh from log probabilities
log_ratio = log_prob - rollout_log_prob
_, is_metrics = compute_rollout_correction_weights(
log_ratio=log_ratio,
response_mask=response_mask,
rollout_is=rollout_is_mode,
rollout_is_threshold=rollout_is_threshold,
)

# Merge all metrics with rollout_corr/ prefix
all_metrics = {**offpolicy_metrics, **is_metrics}
metrics_with_prefix = {}
for key, value in all_metrics.items():
if isinstance(value, torch.Tensor):
metrics_with_prefix[f"rollout_corr/{key}"] = value.item()
else:
metrics_with_prefix[f"rollout_corr/{key}"] = value

return metrics_with_prefix


def maybe_apply_rollout_correction(
batch: DataProto,
rollout_corr_config: Optional[RolloutCorrectionConfig] = None,
@@ -865,14 +930,17 @@ def maybe_apply_rollout_correction(

# Use rollout log probs as old log probs (zero-cost substitution)
batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"]

# Always pass rollout_correction config to actor for metrics computation
policy_loss_config["rollout_correction"] = rollout_corr_config

# Check if pure rollout correction mode is enabled
use_pure_rollout_correction = rollout_corr_config.get("use_pure_rollout_correction", False)

if use_pure_rollout_correction:
# Pure IS mode: Configure actor to use rollout_correction loss function
# This will use compute_policy_loss_with_rollout_correction (no PPO clipping)
policy_loss_config["loss_mode"] = "rollout_correction"
policy_loss_config["rollout_correction"] = rollout_corr_config

return False

32 changes: 27 additions & 5 deletions verl/workers/actor/dp_actor.py
@@ -377,6 +377,9 @@ def update_policy(self, data: DataProto):
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
# Include rollout_log_probs for computing rollout_corr metrics in bypass mode
if "rollout_log_probs" in data.batch.keys():
select_keys.append("rollout_log_probs")

has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -443,11 +446,6 @@ def update_policy(self, data: DataProto):
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = model_inputs.get("rollout_is_weights", None)

# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
# are computed centrally in ray_trainer.py for consistency and efficiency.
# This ensures metrics are computed uniformly across all batches at the trainer level
# and avoids redundant computation across workers and micro-batches.

# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
policy_loss_fn = get_policy_loss_fn(loss_mode)
@@ -464,6 +462,30 @@
)
micro_batch_metrics.update(pg_metrics)

# Compute rollout_corr metrics during training (for monitoring drift)
# This computes metrics using CURRENT policy log_prob vs rollout_log_prob
# to track off-policy gap as training progresses (different from trainer metrics
# which use old_log_prob and only show gap at start of training)
# Skip if using pure rollout correction mode (metrics already in pg_metrics)
# Only computed in bypass mode where config is passed to actor
if loss_mode != "rollout_correction":
rollout_log_prob = model_inputs.get("rollout_log_probs", None)
rollout_corr_config = (
self.config.policy_loss.get("rollout_correction", None)
if hasattr(self.config, "policy_loss")
else None
)
if rollout_log_prob is not None and rollout_corr_config is not None:
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs

rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
log_prob=log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_corr_config=rollout_corr_config,
)
micro_batch_metrics.update(rollout_corr_metrics)

if entropy_coeff != 0:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

28 changes: 28 additions & 0 deletions verl/workers/actor/megatron_actor.py
@@ -320,6 +320,9 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
# Include rollout_log_probs for computing rollout_corr metrics in bypass mode
if "rollout_log_probs" in data.batch.keys():
select_keys.append("rollout_log_probs")
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if self.has_multi_modal_inputs:
data = data.select(select_keys, ["multi_modal_inputs"])
@@ -461,6 +464,31 @@ def loss_func(output, data, meta_info):
rollout_is_weights=rollout_is_weights,
)
stats.update(pg_metrics)

# Compute rollout_corr metrics during training (for monitoring drift)
# This computes metrics using CURRENT policy log_prob vs rollout_log_prob
# to track off-policy gap as training progresses (different from trainer metrics
# which use old_log_prob and only show gap at start of training)
# Skip if using pure rollout correction mode (metrics already in pg_metrics)
# Only computed in bypass mode where config is passed to actor
if loss_mode != "rollout_correction":
rollout_log_prob = data.get("rollout_log_probs", None)
rollout_corr_config = (
self.config.policy_loss.get("rollout_correction", None)
if hasattr(self.config, "policy_loss")
else None
)
if rollout_log_prob is not None and rollout_corr_config is not None:
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs

rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
log_prob=log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_corr_config=rollout_corr_config,
)
stats.update(rollout_corr_metrics)

stats["actor/pg_loss"] = pg_loss.detach().item()
policy_loss = pg_loss
