fix: improve robustness and consistency in mismatch_helper
- Add input validation assertions for eos_mask and response_mask
- Fix variance and ESS to use consistent mean from clamped weights
- Remove redundant eos_mask.any() checks after initial assertion
- Update docstring to accurately describe mixed clamping strategy
szrlee committed Oct 9, 2025
commit 3461684c4d36fa9ce61dedf8df8e33602403985e
verl/trainer/ppo/mismatch_helper.py (72 changes: 34 additions & 38 deletions)
```diff
@@ -239,9 +239,13 @@ def compute_is_metrics(
     Reference:
         When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
 
-    This function computes metrics that reflect the TRUE distribution (before clamping)
-    while avoiding numerical overflow by working in log-space when possible.
+    This function computes metrics using a mix of true unclamped values (for max/min/fractions
+    in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)
+    to balance accuracy with numerical stability and avoid overflow.
     """
+    # Validate that we have at least one valid sample
+    assert eos_mask.any(), "Expected at least one valid sample in eos_mask"
+
     metrics = {}
     device = rollout_is_weights.device
```
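The "mixed strategy" the new docstring describes is worth unpacking: extreme-value metrics stay exact because they can be taken in log-space, where nothing is exponentiated until the very end, while moment-based metrics (mean/std/ESS) use clamped weights so that squaring cannot overflow. A minimal sketch with hypothetical tensors and thresholds, not verl code:

```python
import torch

# Hypothetical per-token log importance ratios and a validity mask.
log_ratios = torch.tensor([[-0.2, 0.1, 30.0], [0.4, -1.5, 0.0]])
eos_mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)

# Exact max/min via log-space: exp() touches only two scalars, so even a
# log-ratio of 30.0 (~1e13 as a weight) is never squared or averaged.
true_max = log_ratios.masked_fill(~eos_mask, float("-inf")).max().exp()
true_min = log_ratios.masked_fill(~eos_mask, float("inf")).min().exp()

# Safety-clamped weights for the moment-based metrics.
lower, upper = 1e-4, 10.0  # hypothetical thresholds
w = log_ratios.exp().clamp(min=lower, max=upper)[eos_mask]
mean_clamped = w.mean()
std_clamped = (w.square().mean() - mean_clamped.square()).clamp(min=0.0).sqrt()

print(true_max, true_min, mean_clamped, std_clamped)
```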

```diff
@@ -291,39 +295,30 @@ def compute_is_metrics(
     metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(rollout_is_below_threshold.float(), eos_mask)
 
     # Max/min for token level
-    if eos_mask.any():
-        mask_bool = eos_mask.bool()
-        metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
-        metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()
-    else:
-        metrics["rollout_is_max"] = torch.tensor(0.0, device=device)
-        metrics["rollout_is_min"] = torch.tensor(0.0, device=device)
+    mask_bool = eos_mask.bool()
+    metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
+    metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()
 
     # Compute standard deviation using clamped weights to avoid overflow
-    if eos_mask.any():
-        mask_count = eos_mask.sum()
-        if mask_count > 1:
-            # Use clamped weights for variance to avoid squaring huge values
-            weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
-            rollout_is_var = (
-                verl_F.masked_mean(weights_for_std.square(), eos_mask) - metrics["rollout_is_mean"].square()
-            )
-            metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
-        else:
-            metrics["rollout_is_std"] = torch.tensor(0.0, device=device)
+    mask_count = eos_mask.sum()
+    if mask_count > 1:
+        # Use clamped weights for variance to avoid squaring huge values
+        weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
+        # Use mean from clamped weights for consistency
+        mean_clamped = verl_F.masked_mean(weights_for_std, eos_mask)
+        rollout_is_var = verl_F.masked_mean(weights_for_std.square(), eos_mask) - mean_clamped.square()
+        metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
+    else:
+        metrics["rollout_is_std"] = torch.tensor(0.0, device=device)
 
     # Effective sample size (use clamped weights to avoid overflow)
-    if eos_mask.any():
-        weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
-        is_weights_normalized = weights_for_ess / (metrics["rollout_is_mean"] + 1e-8)
-        metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), eos_mask)
-    else:
-        metrics["rollout_is_eff_sample_size"] = torch.tensor(1.0, device=device)
+    weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
+    mean_for_ess = verl_F.masked_mean(weights_for_ess, eos_mask)
+    is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)
+    metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), eos_mask)
 
     # Per-sequence breakdown metrics
-    if rollout_is_weights.dim() > 1 and eos_mask.any():
+    if rollout_is_weights.dim() > 1:
         # Compute mean IS weight per sequence
         seq_mean_weights = verl_F.masked_mean(rollout_is_weights, eos_mask, axis=-1)
```
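To see why the commit insists on one consistent mean, here is a self-contained example with made-up weights (again, not verl code). Both the variance E[w²] − E[w]² and the ESS fraction 1 / E[(w/E[w])²] are only well behaved when every expectation is taken over the same clamped weights; mixing in a mean computed from unclamped weights can drive the variance negative or push the ESS fraction outside its valid range of [1/n, 1].

```python
import torch

# Hypothetical IS weights for 4 valid tokens, already clamped into [0.1, 10.0].
w = torch.tensor([0.5, 1.0, 1.5, 10.0])
mean_clamped = w.mean()  # 3.25

# Consistent variance: both moments come from the same clamped weights, so
# E[w^2] - E[w]^2 >= 0 up to rounding (clamp(min=0.0) mops up the rest).
var = (w.square().mean() - mean_clamped.square()).clamp(min=0.0)
std = var.sqrt()

# ESS fraction: 1 / E[(w / E[w])^2] = (sum w)^2 / (n * sum w^2).
# It equals 1.0 for uniform weights; heavy tails drag it toward 1/n.
ess_frac = 1.0 / (w / (mean_clamped + 1e-8)).square().mean()
print(std, ess_frac)  # ~3.91, ~0.41

# The inconsistency this commit fixes: normalizing by a mean computed from
# UNCLAMPED weights (inflated here by an outlier to 100.0) yields a value
# far outside [1/n, 1], which is meaningless as an ESS fraction.
bad_mean = torch.tensor(100.0)
print(1.0 / (w / bad_mean).square().mean())  # ~386.5
```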

```diff
@@ -344,17 +339,15 @@ def compute_is_metrics(
         metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()
 
     # Percentile metrics for better distribution understanding
-    if eos_mask.any():
-        # Get all valid IS weights
-        flat_weights = rollout_is_weights[eos_mask.bool()]
-
-        if flat_weights.numel() > 0:
-            # Compute key percentiles
-            metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
-            metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50)  # median
-            metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
-            metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95)
-            metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99)
+    # Get all valid IS weights
+    flat_weights = rollout_is_weights[eos_mask.bool()]
+    # Compute key percentiles (guaranteed to have elements due to assertion at function start)
+    assert flat_weights.numel() > 0, "flat_weights should not be empty"
+    metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
+    metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50)  # median
+    metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
+    metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95)
+    metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99)
 
     return metrics
```
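The percentile block relies on a common PyTorch pattern: boolean indexing with eos_mask.bool() flattens just the valid positions into a 1-D tensor, which torch.quantile can consume directly. A tiny illustration with hypothetical values:

```python
import torch

weights = torch.tensor([[0.9, 1.1, 5.0], [1.0, 0.2, 7.0]])
eos_mask = torch.tensor([[1, 1, 0], [1, 1, 0]])  # last token of each row is padding

flat = weights[eos_mask.bool()]  # 1-D tensor of the 4 valid weights
assert flat.numel() > 0          # mirrors the guard in the diff
print(torch.quantile(flat, 0.50))  # median of [0.9, 1.1, 1.0, 0.2] -> 0.95
```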

```diff
@@ -396,6 +389,9 @@ def compute_mismatch_metrics(
     Reference:
         - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
     """
+    # Validate that we have at least one valid token
+    assert response_mask.any(), "Expected at least one valid token in response_mask"
+
     metrics = {}
 
     # 1. Training policy perplexity (always available)
```
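The same guard pattern carries over to compute_mismatch_metrics: assert once up front, then compute masked statistics without per-metric fallbacks. As an illustration of the kind of masked metric that follows (the perplexity in step 1), here is a generic masked-perplexity sketch; the function name, tensors, and exact formula layout are assumptions for illustration, not verl's implementation:

```python
import torch

def masked_ppl(logprobs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Perplexity over valid tokens: exp(-mean log-prob). A generic sketch."""
    assert mask.any(), "Expected at least one valid token in response_mask"
    mean_logprob = (logprobs * mask).sum() / mask.sum()
    return torch.exp(-mean_logprob)

logprobs = torch.tensor([[-0.1, -0.3, -9.0], [-0.2, -0.4, -9.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]])  # padding excluded
print(masked_ppl(logprobs, mask))  # exp(0.25) ~= 1.284
```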