Merged
Changes from 1 commit
Commits (20)
7c9e41d  fix(rollout_corr): compute metrics in actor for bypass mode and fix t… (szrlee, Nov 10, 2025)
96ae2be  docs(rollout_corr): move to algo/ and add pure_rs preset (szrlee, Nov 10, 2025)
c0ea9bd  feat(rollout_corr): add batch normalization option for IS weights (szrlee, Nov 10, 2025)
7de6c5f  docs(rollout_corr_math): use REINFORCE in aggregation loss examples f… (szrlee, Nov 10, 2025)
2b34cfe  refactor(rollout_corr): simplify metrics computation by removing unus… (szrlee, Nov 10, 2025)
0c42f85  docs(rollout_corr): add prominent cross-references between usage and … (szrlee, Nov 10, 2025)
fef8a48  docs(rollout_corr_math): add dedicated section for batch normalization (szrlee, Nov 10, 2025)
08cc9c7  fix: docstring of compute_policy_loss_with_rollout_correction (tongyx361, Nov 11, 2025)
437a4ab  feat: reuse need_recomputation instead of bypass_mode (tongyx361, Nov 11, 2025)
5f9a53b  feat: improve comments (tongyx361, Nov 11, 2025)
b2f6370  feat: improve comments (tongyx361, Nov 11, 2025)
79cdbf2  feat: refactor bypass_recomputing_logprobs (tongyx361, Nov 11, 2025)
62e3270  feat(rollout_corr): align batch normalization with IS aggregation level (szrlee, Nov 11, 2025)
b5c19ff  docs(rollout_corr): rename decoupled mode presets for clarity and upd… (szrlee, Nov 11, 2025)
11f9aa0  fix(rollout_corr): correct metrics computation to run in decoupled mo… (szrlee, Nov 11, 2025)
58565cb  docs(rollout_corr): rename presets for clarity and consistency (szrlee, Nov 11, 2025)
8bb1a0e  refactor(rollout_corr): rename config vars for semantic clarity (szrlee, Nov 11, 2025)
6002c00  refactor(rollout_corr): update implementation to use renamed config v… (szrlee, Nov 11, 2025)
7f9ba9c  Merge branch 'main' into pr/szrlee/4070 (tongyx361, Nov 11, 2025)
56f69bf  fix: ppo_trainer config format (tongyx361, Nov 11, 2025)
refactor(rollout_corr): update implementation to use renamed config variables

Complete the variable renaming initiated in commit 8bb1a0e:
  • bypass_old_logprob_for_rollout → bypass_mode
  • use_pure_rollout_correction → use_policy_gradient

Changes:
- Update implementation files (rollout_corr_helper, ray_trainer, core_algos)
- Update example script (run_with_rollout_corr.sh)
- Improve docstring classifications ("Bypass + PPO/PG loss")
- Clarify documentation terminology (IS/RS independence)
- Update mode names ("Bypass + Policy Gradient mode")

All references to the old variable names have been removed from the codebase.
szrlee committed Nov 11, 2025
commit 6002c00c6f39d3bee05495af2e8e08ba4047c558
14 changes: 7 additions & 7 deletions docs/algo/rollout_corr.md
@@ -74,7 +74,7 @@ This critical implementation mistake that leads to RL training collapse was iden
**Mathematically correct approaches:**
- **Decoupled mode**: Three policies (π_rollout, π_old, π_θ) with IS correction from π_rollout to π_old
- **Bypass mode**: Two policies (π_rollout = π_old, π_θ) using actual rollout policy as PPO anchor
- **Pure IS mode**: Two policies (π_rollout, π_θ) with IS correction and no PPO clipping
- **Bypass + Policy Gradient mode**: Two policies (π_rollout, π_θ) with IS/RS correction and no PPO clipping

See [Mathematical Formulations](rollout_corr_math.md#38-common-implementation-mistake) for detailed explanation.

@@ -152,7 +152,7 @@ algorithm:
rollout_rs_threshold_lower: null # RS lower threshold (auto-reciprocal if null)
rollout_token_veto_threshold: null # Per-token veto threshold (null = disabled)
bypass_mode: false # Skip old_log_prob computation
use_policy_gradient: false # Pure policy gradient with IS
use_policy_gradient: false # Use policy gradient loss (vs PPO loss)

# REQUIRED: Enable log prob calculation
actor_rollout_ref:
@@ -512,7 +512,7 @@ algorithm:
rollout_is_threshold: 2.0
rollout_rs: null
bypass_mode: true # Required
use_policy_gradient: true # Pure IS loss
use_policy_gradient: true # Use policy gradient loss (no PPO clipping)
```

**Properties:**
@@ -646,7 +646,7 @@ The framework provides **two operating modes** for computing π_old, which can b
|---------------|----------------------------------|------------------------------|----------------|---------------|-------------|
| **Decoupled** | `false` | `false` | Decoupled | PPO | Computes `old_log_prob` separately via `actor.compute_log_prob()` |
| **Bypass** | `true` | `false` | Bypass | PPO | Sets `old_log_prob = rollout_log_prob`, PPO clips against rollout policy |
| **Pure IS** | `true` | `true` | Bypass | Pure Policy Gradient | Bypass mode with pure IS loss (no PPO clipping) |
| **Bypass + PG** | `true` | `true` | Bypass | Policy Gradient | Bypass mode with policy gradient loss (no PPO clipping) |

### Operating Mode Details

@@ -752,7 +752,7 @@ The aggregation level can be chosen **independently** of the operating mode. Any

**Trade-off**: PPO clips against rollout policy instead of true old policy

**Alternative**: Set `use_policy_gradient: true` for pure policy gradient with IS (no clipping)
**Alternative**: Set `use_policy_gradient: true` for policy gradient loss with IS/RS (no clipping)

## Usage

@@ -1144,8 +1144,8 @@ algorithm:
rollout_is: token # Explicit IS correction in loss
rollout_is_threshold: 2.0
rollout_rs: null # Optional: can add rejection sampling
bypass_mode: true # Required for pure mode
use_policy_gradient: true # Use pure policy gradient with IS
bypass_mode: true # Required for policy gradient mode
use_policy_gradient: true # Use policy gradient loss (no PPO clipping)
```
**No PPO clipping, pure policy gradient with IS correction**

20 changes: 11 additions & 9 deletions docs/algo/rollout_corr_math.md
@@ -262,7 +262,7 @@ The operating mode determines how the proximal policy $\pi_{\text{old}}$ is comp

---

### 3.2 Loss Functions: PPO vs Pure IS
### 3.2 Loss Functions: PPO vs Policy Gradient

#### 3.2.1 PPO Loss (with Clipping)

@@ -284,28 +284,30 @@ where:
- Limits policy update magnitude
- Standard in RL training

#### 3.2.2 Pure IS Loss (Policy Gradient)
#### 3.2.2 Policy Gradient Loss (with IS/RS Correction)

**Configuration:** `use_policy_gradient = true` (requires `bypass_mode = true`)

**Loss function:**
**Loss function** (example with sequence-level IS):

$$
L_{\text{PureIS}}(\theta) = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ w_{\text{seq}}(\theta) \cdot \sum_{t \in T} \log \pi_{\theta}(a_t|s_t) \cdot A_t \right]
L_{\text{PG}}(\theta) = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ w_{\text{seq}}(\theta) \cdot \sum_{t \in T} \log \pi_{\theta}(a_t|s_t) \cdot A_t \right]
$$

where:
- $w_{\text{seq}}(\theta) = \min\left( \prod_{t \in T} \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}, C_{\text{IS}} \right)$: Sequence-level IS weight
- IS weight is **detached from gradient** (treated as constant)
- $w_{\text{seq}}(\theta)$: Sample weight (IS or RS, see §3.3-3.4 for details)
- For IS: $w_{\text{seq}}(\theta) = \min\left( \prod_{t \in T} \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}, C_{\text{IS}} \right)$
- For RS: $w_{\text{seq}}(\theta) \in \{0, 1\}$ (binary rejection mask)
- Weight is **detached from gradient** (treated as constant)

**Effective gradient:**

$$
\nabla_\theta L_{\text{PureIS}} = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ \text{stopgrad}(w_{\text{seq}}(\theta)) \cdot \sum_{t \in T} \nabla_\theta \log \pi_{\theta}(a_t|s_t) \cdot A_t \right]
\nabla_\theta L_{\text{PG}} = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ \text{stopgrad}(w_{\text{seq}}(\theta)) \cdot \sum_{t \in T} \nabla_\theta \log \pi_{\theta}(a_t|s_t) \cdot A_t \right]
$$

**Properties:**
- **Algorithm**: Off-policy REINFORCE + IS
- **Algorithm**: Off-policy REINFORCE + IS/RS correction
- **No PPO clipping**: Pure policy gradient
- **Always uses bypass mode**: Direct $\pi_\theta$ to $\pi_{\text{rollout}}$ comparison
- **Fast**: Single forward pass
@@ -576,7 +578,7 @@ where $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ (i
**Correct alternatives:**
1. **Decoupled mode**: Three policies with IS correction from $\pi_{\text{rollout}}$ to $\pi_{\text{old}}$
2. **Bypass mode**: Two policies using $\pi_{\text{rollout}}$ as both behavior policy and proximal policy
3. **Pure IS mode**: Two policies with IS correction and no PPO clipping
3. **Bypass + Policy Gradient mode**: Two policies with IS/RS correction and no PPO clipping

**Implementation:** `compute_policy_loss()` in [core_algos.py](../../verl/trainer/ppo/core_algos.py#L812-L884)

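The updated §3.2.2 text describes a REINFORCE loss with a detached, capped sequence-level weight. For readability, here is a minimal PyTorch-style sketch of that computation; tensor names, shapes, and the `is_cap` constant are illustrative and not taken from verl's implementation:

```python
import torch


def pg_loss_with_seq_is(log_prob, rollout_log_prob, advantages, response_mask, is_cap=2.0):
    """REINFORCE loss with a detached, capped sequence-level IS weight (illustrative sketch).

    All inputs are (batch, seq_len) tensors; response_mask is 1.0 on response tokens, 0.0 elsewhere.
    """
    # w_seq = min( prod_t pi_theta(a_t|s_t) / pi_rollout(a_t|s_t), C_IS ), computed in log space
    seq_log_ratio = ((log_prob - rollout_log_prob) * response_mask).sum(dim=-1)
    w_seq = torch.exp(seq_log_ratio).clamp(max=is_cap).detach()  # stop-gradient: weight treated as a constant

    # sum_t log pi_theta(a_t|s_t) * A_t for each sequence
    pg_term = (log_prob * advantages * response_mask).sum(dim=-1)

    # negative empirical expectation over sequences drawn from pi_rollout
    return -(w_seq * pg_term).mean()
```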
12 changes: 6 additions & 6 deletions examples/rollout_correction/run_with_rollout_corr.sh
@@ -25,9 +25,9 @@ rollout_rs_threshold_lower="null" # RS lower threshold
# Veto mechanism (optional, independent of IS/RS)
rollout_token_veto_threshold="null" # Per-token veto threshold (null to disable)

# Pure IS mode (pure policy gradient, no PPO clipping)
bypass_old_logprob_for_rollout="true" # Required for pure_is mode
use_pure_rollout_correction="true" # Use pure policy gradient with IS
# Policy Gradient loss mode (bypass mode with policy gradient loss, no PPO clipping)
bypass_mode="true" # Required for policy gradient mode
use_policy_gradient="true" # Use policy gradient loss (works with IS/RS/both)

# ==============================================================================
# Model and Data Configuration
@@ -75,8 +75,8 @@ python3 -m verl.trainer.main_ppo \
algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \
algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \
algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \
algorithm.rollout_correction.bypass_old_logprob_for_rollout=${bypass_old_logprob_for_rollout} \
algorithm.rollout_correction.use_pure_rollout_correction=${use_pure_rollout_correction} \
algorithm.rollout_correction.bypass_mode=${bypass_mode} \
algorithm.rollout_correction.use_policy_gradient=${use_policy_gradient} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=${learning_rate} \
actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
@@ -95,7 +95,7 @@ echo " - Algorithm: RLOO (REINFORCE Leave-One-Out)"
echo " - Advantage estimator: ${adv_estimator}"
echo " - IS mode: ${rollout_is} (self-normalized: ${rollout_is_batch_normalize})"
echo " - IS threshold: ${rollout_is_threshold}"
echo " - Pure IS mode: ${use_pure_rollout_correction} (bypass: ${bypass_old_logprob_for_rollout})"
echo " - Policy gradient mode: ${use_policy_gradient} (bypass: ${bypass_mode})"
echo ""
echo "Monitor these key metrics in wandb:"
echo " - rollout_corr/rollout_is_mean (should be ~1.0 before batch norm)"
12 changes: 6 additions & 6 deletions verl/trainer/ppo/core_algos.py
@@ -1609,14 +1609,14 @@ def compute_policy_loss_with_rollout_correction(

Usage:
This function is called by the actor when:
- bypass_old_logprob_for_rollout=True (trainer uses rollout_log_prob as old_log_prob)
- use_pure_rollout_correction=True (actor uses this function instead of compute_policy_loss)
- bypass_mode=True (trainer uses rollout_log_prob as old_log_prob)
- use_policy_gradient=True (actor uses this function instead of compute_policy_loss)

Example config:
algorithm:
rollout_correction:
bypass_old_logprob_for_rollout: true
use_pure_rollout_correction: true
bypass_mode: true
use_policy_gradient: true
rollout_is: "token"
rollout_is_threshold: 2.0
rollout_rs: "token"
@@ -1702,7 +1702,7 @@ def compute_policy_loss_rollout_correction_wrapper(
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Wrapper for compute_policy_loss_with_rollout_correction to match PolicyLossFn interface.

This function is used when algorithm.rollout_correction.use_pure_rollout_correction=True.
This function is used when algorithm.rollout_correction.use_policy_gradient=True.
In this mode, the trainer has already set old_log_prob=rollout_log_prob (bypass mode).

Args:
@@ -1717,7 +1717,7 @@
assert config is not None, "config is required for rollout_correction loss mode"

# Extract rollout_correction config
# In ray_trainer, when use_pure_rollout_correction=True, the rollout_correction config
# In ray_trainer, when use_policy_gradient=True, the rollout_correction config
# is embedded in actor config's policy_loss field
rollout_corr_config = config.policy_loss.get("rollout_correction", None) if hasattr(config, "policy_loss") else None

4 changes: 1 addition & 3 deletions verl/trainer/ppo/ray_trainer.py
@@ -1106,9 +1106,7 @@ def fit(self):
# - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
# Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get(
"bypass_old_logprob_for_rollout", False
)
bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
if bypass_recomputing_logprobs: # Use `rollout_log_probs`
from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction

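The one-line trainer change above switches the bypass check to the renamed `bypass_mode` key; combined with the helper updates in the next file, the resulting control flow is roughly as follows. This is a simplified sketch with a hypothetical function name, not the code added in this commit (see `apply_rollout_correction` below for the actual logic):

```python
# Simplified sketch of the bypass-mode dispatch (hypothetical helper, illustrative only).
def configure_bypass(rollout_corr_config, batch, policy_loss_config):
    bypass_mode = bool(rollout_corr_config and rollout_corr_config.get("bypass_mode", False))
    if not bypass_mode:
        return  # Decoupled mode: old_log_probs are recomputed via actor.compute_log_prob()

    if "rollout_log_probs" not in batch:
        raise ValueError("bypass_mode=True requires rollout_log_probs in batch")

    # Bypass mode: reuse rollout log-probs as the proximal anchor (pi_old = pi_rollout)
    batch["old_log_probs"] = batch["rollout_log_probs"]

    if rollout_corr_config.get("use_policy_gradient", False):
        # Bypass + Policy Gradient: REINFORCE-style loss with IS/RS correction, no PPO clipping
        policy_loss_config["loss_mode"] = "rollout_correction"
    # Otherwise: Bypass + PPO loss, clipping the ratio pi_theta / pi_rollout
```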
25 changes: 13 additions & 12 deletions verl/trainer/ppo/rollout_corr_helper.py
@@ -910,21 +910,22 @@ def apply_rollout_correction(
BYPASS MODE: Use rollout_log_probs as old_log_probs
Skips expensive actor forward pass for old_log_prob computation

Two sub-modes (controlled by use_pure_rollout_correction in actor):
1. PPO_IS mode (use_pure_rollout_correction=False, default):
- Actor uses standard PPO with old_log_prob=rollout_log_prob
- PPO clips ratio = π_current / π_rollout (not π_current / π_old)
Two sub-modes (controlled by use_policy_gradient):
1. Bypass + PPO loss (use_policy_gradient=False, default):
- Uses standard PPO loss function with old_log_prob=rollout_log_prob
- PPO clips ratio π_θ/π_rollout instead of π_θ/π_old

2. Pure rollout correction mode (use_pure_rollout_correction=True):
- Actor uses compute_policy_loss_with_rollout_correction()
- Pure policy gradient with IS correction (no PPO clipping)
2. Bypass + Policy Gradient loss (use_policy_gradient=True):
- Uses compute_policy_loss_with_rollout_correction()
- Policy gradient (REINFORCE-style) with IS/RS correction applied
- No PPO clipping

Note:
The implementation is copied from szrlee <szrlee@gmail.com>.
"""
if "rollout_log_probs" not in batch.batch:
raise ValueError(
"bypass_old_logprob_for_rollout=True requires rollout_log_probs in batch. "
"bypass_mode=True requires rollout_log_probs in batch. "
"Ensure rollout worker is configured to calculate_log_probs=true."
)

@@ -934,10 +934,10 @@
# Always pass rollout_correction config to actor for metrics computation
policy_loss_config["rollout_correction"] = rollout_corr_config

# Check if pure rollout correction mode is enabled
use_pure_rollout_correction = rollout_corr_config.get("use_pure_rollout_correction", False)
# Check if policy gradient loss mode is enabled
use_policy_gradient = rollout_corr_config.get("use_policy_gradient", False)

if use_pure_rollout_correction:
# Pure IS mode: Configure actor to use rollout_correction loss function
if use_policy_gradient:
# Policy gradient mode: Configure actor to use rollout_correction loss function
# This will use compute_policy_loss_with_rollout_correction (no PPO clipping)
policy_loss_config["loss_mode"] = "rollout_correction"