Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7c9e41d
fix(rollout_corr): compute metrics in actor for bypass mode and fix t…
szrlee Nov 10, 2025
96ae2be
docs(rollout_corr): move to algo/ and add pure_rs preset
szrlee Nov 10, 2025
c0ea9bd
feat(rollout_corr): add batch normalization option for IS weights
szrlee Nov 10, 2025
7de6c5f
docs(rollout_corr_math): use REINFORCE in aggregation loss examples f…
szrlee Nov 10, 2025
2b34cfe
refactor(rollout_corr): simplify metrics computation by removing unus…
szrlee Nov 10, 2025
0c42f85
docs(rollout_corr): add prominent cross-references between usage and …
szrlee Nov 10, 2025
fef8a48
docs(rollout_corr_math): add dedicated section for batch normalization
szrlee Nov 10, 2025
08cc9c7
fix: docstring of compute_policy_loss_with_rollout_correction
tongyx361 Nov 11, 2025
437a4ab
feat: reuse need_recomputation instead of bypass_mode
tongyx361 Nov 11, 2025
5f9a53b
feat: improve comments
tongyx361 Nov 11, 2025
b2f6370
feat: improve comments
tongyx361 Nov 11, 2025
79cdbf2
feat: refactor bypass_recomputing_logprobs
tongyx361 Nov 11, 2025
62e3270
feat(rollout_corr): align batch normalization with IS aggregation level
szrlee Nov 11, 2025
b5c19ff
docs(rollout_corr): rename decoupled mode presets for clarity and upd…
szrlee Nov 11, 2025
11f9aa0
fix(rollout_corr): correct metrics computation to run in decoupled mo…
szrlee Nov 11, 2025
58565cb
docs(rollout_corr): rename presets for clarity and consistency
szrlee Nov 11, 2025
8bb1a0e
refactor(rollout_corr): rename config vars for semantic clarity
szrlee Nov 11, 2025
6002c00
refactor(rollout_corr): update implementation to use renamed config v…
szrlee Nov 11, 2025
7f9ba9c
Merge branch 'main' into pr/szrlee/4070
tongyx361 Nov 11, 2025
56f69bf
fix: ppo_trainer config format
tongyx361 Nov 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
docs(rollout_corr): rename presets for clarity and consistency
Standardize preset naming to clearly indicate mode and algorithm:

Decoupled mode (3 policies):
  token_is → decoupled_token_is
  seq_is → decoupled_seq_is
  seq_is_rs → decoupled_seq_is_rs
  geo_rs → decoupled_geo_rs

Policy gradient (bypass mode):
  pure_is → pg_is
  pure_rs → pg_rs

Changes:
- Remove aliases (token_tis, seq_mis, geo_mis)
- Add disabled() to tables
- Update docstrings and examples
- Improve YAML comments
- Add missing rollout_is_batch_normalize parameter

All 8 presets now consistent across code and docs.
  • Loading branch information
szrlee committed Nov 11, 2025
commit 58565cb0ec5a0ca5d0660cbc89a1e9d875a1faf7
23 changes: 12 additions & 11 deletions docs/algo/rollout_corr.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ config = RolloutCorrectionConfig.decoupled_geo_rs()
config = RolloutCorrectionConfig.ppo_is_bypass()

# Advanced: Pure policy gradient with IS
config = RolloutCorrectionConfig.pure_is()
config = RolloutCorrectionConfig.pg_is()

# Advanced: Pure policy gradient with rejection sampling (bypass + pure + geometric RS)
config = RolloutCorrectionConfig.pure_rs()
config = RolloutCorrectionConfig.pg_rs()

# Metrics only (no correction)
config = RolloutCorrectionConfig.disabled()
Expand Down Expand Up @@ -290,10 +290,11 @@ This section provides detailed guidance on choosing and using the verified prese
| `decoupled_seq_is_rs()` | Decoupled | sequence | sequence | Sequence IS + sequence RS |
| `decoupled_geo_rs()` | Decoupled | - | geometric + veto | Geometric RS + veto, no IS weights |
| `ppo_is_bypass()` | Bypass | - | - | Bypass mode, skips old_log_prob |
| `pure_rs()` | Bypass | - | geometric + veto | Pure policy gradient with RS (no IS weights) |
| `pure_is()` | Bypass | sequence | - | Pure policy gradient with IS |
| `pg_rs()` | Bypass | - | geometric + veto | Policy gradient with RS (no IS weights) |
| `pg_is()` | Bypass | sequence | - | Policy gradient with IS |
| `disabled()` | - | - | - | Metrics only, no correction |

**Note:** All presets use PPO loss except `pure_is()` and `pure_rs()` which use pure policy gradient (both require `use_pure_rollout_correction=True`).
**Note:** All presets use PPO loss except `pg_is()` and `pg_rs()` which use policy gradient (both require `use_pure_rollout_correction=True`).

#### Other Supported Combinations (Manual Configuration Required)

Expand All @@ -308,7 +309,7 @@ See [detailed configuration examples below](#additional-useful-configurations-no
- Any aggregation level (token/sequence/geometric) works in either decoupled or bypass mode
- All combinations are fully supported by the implementation
- Rejection sampling is independent of IS weighting
- Pure RS (`pure_rs`) uses bypass + geometric RS with `use_pure_rollout_correction=True` (no IS weights)
- Pure RS (`pg_rs`) uses bypass + geometric RS with `use_pure_rollout_correction=True` (no IS weights)

---

Expand Down Expand Up @@ -490,11 +491,11 @@ algorithm:

---

### 6. Pure IS (Off-Policy REINFORCE) (`pure_is`)
### 6. Policy Gradient with IS (`pg_is`)

**Configuration:**
```python
config = RolloutCorrectionConfig.pure_is(threshold=2.0)
config = RolloutCorrectionConfig.pg_is(threshold=2.0)
```

**Components:**
Expand Down Expand Up @@ -523,11 +524,11 @@ algorithm:

---

### 7. Pure Policy Gradient with Rejection Sampling (`pure_rs`)
### 7. Policy Gradient with Rejection Sampling (`pg_rs`)

**Configuration:**
```python
config = RolloutCorrectionConfig.pure_rs(
config = RolloutCorrectionConfig.pg_rs(
rs_threshold=1.001,
veto_threshold=1e-4
)
Expand Down Expand Up @@ -1309,7 +1310,7 @@ Rollout Correction provides a unified framework for handling general off-policy
- ✅ Supports diverse scenarios: policy mismatch, staleness, replay buffers, off-policy algorithms
- ✅ Numerical stability with safety bounds and rejection mechanisms
- ✅ Comprehensive diagnostics: KL, perplexity, χ² divergence
- ✅ Flexible methods from token-level (token_is) to sequence-level (seq_is_rs)
- ✅ Flexible methods from token-level to sequence-level aggregation
- ✅ Memory-efficient implementation

## References
Expand Down
17 changes: 9 additions & 8 deletions docs/algo/rollout_corr_math.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ The transition dynamics $p(s_{t+1}|s_t, a_t)$ and initial state $p(s_0)$ cancel
- **Off-policy capable**: Can learn from any behavior policy via importance sampling
- **No trust region**: Policy updates not constrained

**Implementation in verl:** The `pure_is` method implements off-policy REINFORCE with truncated importance sampling.
**Implementation in verl:** The `pg_is` method implements off-policy REINFORCE with truncated importance sampling.

### 1.2 PPO: Adding Trust Region Control

Expand Down Expand Up @@ -501,10 +501,11 @@ where $\bar{w}_j = \frac{1}{T_j}\sum_{t=1}^{T_j} w_{j,t} \cdot m_{j,t}$ is the p
| `decoupled_seq_is_rs()` | Decoupled | sequence | sequence | Sequence IS + sequence RS |
| `decoupled_geo_rs()` | Decoupled | - | geometric + veto | Geometric RS + veto, no IS weights |
| `ppo_is_bypass()` | Bypass | - | - | Bypass mode, skips old_log_prob |
| `pure_rs()` | Bypass | - | geometric + veto | Pure policy gradient with RS (no IS weights) |
| `pure_is()` | Bypass | sequence | - | Pure policy gradient with IS |
| `pg_rs()` | Bypass | - | geometric + veto | Policy gradient with RS (no IS weights) |
| `pg_is()` | Bypass | sequence | - | Policy gradient with IS |
| `disabled()` | - | - | - | Metrics only, no correction |

**Note:** All presets use PPO loss except `pure_is()` and `pure_rs()` which use pure policy gradient (both require `use_pure_rollout_correction=True`).
**Note:** All presets use PPO loss except `pg_is()` and `pg_rs()` which use policy gradient (both require `use_pure_rollout_correction=True`).

#### Additional Supported Combinations (Manual Configuration)

Expand Down Expand Up @@ -546,7 +547,7 @@ config = RolloutCorrectionConfig(
- Rejection sampling can be added to any combination
- Veto is independent and can be added to any combination
- Geometric aggregation is typically used for RS only (not IS weighting)
- Pure RS (`pure_rs`) uses bypass + geometric RS with `use_pure_rollout_correction=True` for pure policy gradient (no IS weights)
- Pure RS (`pg_rs`) uses bypass + geometric RS with `use_pure_rollout_correction=True` for pure policy gradient (no IS weights)
- All combinations in the table above are valid and supported by the implementation

---
Expand Down Expand Up @@ -655,8 +656,8 @@ $$

| Method | Theory | Policies | PPO Clip | IS Correction | Correctness | Speed |
|--------|--------|----------|----------|---------------|-------------|-------|
| `pure_is` | Off-policy REINFORCE | 2 (rollout, θ) | ❌ | ✅ Seq-level | ✅ Correct | **Fast** |
| `pure_rs` | Pure PG + Geo RS | 2 (rollout, θ) | ❌ | Rejection only | ✅ Correct | **Fast** |
| `pg_is` | Off-policy REINFORCE | 2 (rollout, θ) | ❌ | ✅ Seq-level | ✅ Correct | **Fast** |
| `pg_rs` | Pure PG + Geo RS | 2 (rollout, θ) | ❌ | Rejection only | ✅ Correct | **Fast** |
| Naive LLM-RL | Incorrect PPO usage | 2 (old, θ) | ✅ | ❌ | ⚠️ Incorrect | Standard |
| `ppo_is_bypass` | PPO (rollout as prox) | 2 (rollout, θ) | ✅ | ❌ | ✅ Correct | **Fast** |
| `decoupled_token_is` | Decoupled PPO | 3 (rollout, old, θ) | ✅ | ✅ Token-level | ✅ Correct | Standard |
Expand All @@ -674,7 +675,7 @@ $$
**Algorithm properties:**
- **Batch size invariance**: Decoupled mode with three policies (`decoupled_token_is`, `decoupled_seq_is`) achieves batch size invariance
- **Computational efficiency**: Bypass mode (`ppo_is_bypass`) skips `old_log_prob` computation
- **Pure policy gradient**: `pure_is` implements off-policy REINFORCE without PPO clipping
- **Pure policy gradient**: `pg_is` implements off-policy REINFORCE without PPO clipping

### 5.3 Decoupled Mode vs Bypass Mode

Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ algorithm:
rollout_token_veto_threshold: null
bypass_old_logprob_for_rollout: false
use_pure_rollout_correction: false
rollout_is_batch_normalize: false
_target_: verl.trainer.config.AlgoConfig
gamma: 1.0
lam: 1.0
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ reward_model:
rollout_token_veto_threshold: null
bypass_old_logprob_for_rollout: false
use_pure_rollout_correction: false
rollout_is_batch_normalize: false
custom_reward_function:
path: null
name: compute_score
Expand Down
94 changes: 37 additions & 57 deletions verl/trainer/config/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,12 @@ class RolloutCorrectionConfig(BaseConfig):
config = RolloutCorrectionConfig()

# Use presets
config = RolloutCorrectionConfig.token_is() # Token-level IS
config = RolloutCorrectionConfig.seq_is_rs() # Sequence-level IS + rejection sampling
config = RolloutCorrectionConfig.seq_is() # Sequence-level IS
config = RolloutCorrectionConfig.decoupled_token_is() # Decoupled mode with token-level IS
config = RolloutCorrectionConfig.decoupled_seq_is_rs() # Decoupled mode with sequence IS + RS
config = RolloutCorrectionConfig.decoupled_seq_is() # Decoupled mode with sequence-level IS
config = RolloutCorrectionConfig.ppo_is_bypass() # Bypass mode
config = RolloutCorrectionConfig.pg_is() # Policy gradient with IS
config = RolloutCorrectionConfig.pg_rs() # Policy gradient with RS

Reference:
Liu, Li, Fu, Wang, Liu, Shen (2025)
Expand All @@ -156,46 +159,43 @@ class RolloutCorrectionConfig(BaseConfig):
rollout_is_batch_normalize: bool = False

@classmethod
def token_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Token-level Truncated Importance Sampling.
def decoupled_token_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Decoupled Mode with Token-level Importance Sampling.

IS weight correction at token level.
IS weight correction at token level in decoupled mode (three policies).

Args:
threshold (float): Upper threshold for IS weights. Default: 2.0

Returns:
RolloutCorrectionConfig configured for token-level IS
RolloutCorrectionConfig configured for decoupled mode with token-level IS
"""
return cls(rollout_is="token", rollout_is_threshold=threshold, rollout_rs=None)

@classmethod
def token_tis(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Alias for token_is()."""
return cls.token_is(threshold=threshold)
def decoupled_seq_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Decoupled Mode with Sequence-level Importance Sampling.

@classmethod
def seq_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Sequence-level Truncated Importance Sampling.
IS weight correction at sequence level in decoupled mode (three policies).

Args:
threshold (float): Upper threshold for IS weights. Default: 2.0

Returns:
RolloutCorrectionConfig configured for sequence-level IS
RolloutCorrectionConfig configured for decoupled mode with sequence-level IS
"""
return cls(rollout_is="sequence", rollout_is_threshold=threshold, rollout_rs=None)

@classmethod
def seq_is_rs(
def decoupled_seq_is_rs(
cls,
is_threshold: float = 2.0,
rs_threshold: float = 2.0,
rs_threshold_lower: Optional[float] = None,
) -> "RolloutCorrectionConfig":
"""Sequence-level IS with Rejection Sampling (MIS).
"""Decoupled Mode with Sequence-level IS + Rejection Sampling.

Sequence-level IS with sequence-level rejection sampling.
Sequence-level IS with sequence-level rejection sampling in decoupled mode.
Rejects entire sequences based on sequence-level IS weight.

Args:
Expand All @@ -205,7 +205,7 @@ def seq_is_rs(
If None, auto-computed as reciprocal of rs_threshold. Default: None

Returns:
RolloutCorrectionConfig configured for sequence IS + RS
RolloutCorrectionConfig configured for decoupled mode with sequence IS + RS
"""
return cls(
rollout_is="sequence",
Expand All @@ -216,24 +216,15 @@ def seq_is_rs(
)

@classmethod
def seq_mis(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Alias for seq_is_rs()."""
return cls.seq_is_rs(
is_threshold=threshold,
rs_threshold=threshold,
rs_threshold_lower=0,
)

@classmethod
def geo_rs(
def decoupled_geo_rs(
cls,
rs_threshold: float = 1.001,
rs_threshold_lower: Optional[float] = None,
veto_threshold: float = 1e-4,
) -> "RolloutCorrectionConfig":
"""Geometric Rejection Sampling with Veto.
"""Decoupled Mode with Geometric Rejection Sampling.

Uses geometric mean for rejection sampling at sequence level,
Uses geometric mean for rejection sampling at sequence level in decoupled mode,
with additional veto mechanism. Geometric mean is extremely sensitive to outliers,
requiring very tight thresholds close to 1.0.

Expand All @@ -244,7 +235,7 @@ def geo_rs(
veto_threshold (float): Per-token veto threshold. Default: 1e-4

Returns:
RolloutCorrectionConfig configured for geometric RS with veto
RolloutCorrectionConfig configured for decoupled mode with geometric RS + veto
"""
return cls(
rollout_is=None,
Expand All @@ -254,20 +245,6 @@ def geo_rs(
rollout_token_veto_threshold=veto_threshold,
)

@classmethod
def geo_mis(
cls,
rs_threshold: float = 1.001,
rs_threshold_lower: float = 0.999,
veto_threshold: float = 1e-4,
) -> "RolloutCorrectionConfig":
"""Alias for geo_rs()."""
return cls.geo_rs(
rs_threshold=rs_threshold,
rs_threshold_lower=rs_threshold_lower,
veto_threshold=veto_threshold,
)

@classmethod
def ppo_is_bypass(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""PPO with IS Correction in Bypass Mode.
Expand All @@ -290,17 +267,17 @@ def ppo_is_bypass(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
)

@classmethod
def pure_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Pure Policy Gradient with IS Correction.
def pg_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
"""Policy Gradient with IS Correction.

Uses pure policy gradient loss with explicit IS correction.
Uses policy gradient loss with explicit IS correction.
No PPO clipping.

Args:
threshold (float): Upper threshold for IS weights. Default: 2.0

Returns:
RolloutCorrectionConfig configured for pure IS mode
RolloutCorrectionConfig configured for PG with IS
"""
return cls(
rollout_is="sequence",
Expand All @@ -311,15 +288,15 @@ def pure_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
)

@classmethod
def pure_rs(
def pg_rs(
cls,
rs_threshold: float = 1.001,
rs_threshold_lower: Optional[float] = None,
veto_threshold: float = 1e-4,
) -> "RolloutCorrectionConfig":
"""Pure Rejection Sampling.
"""Policy Gradient with Rejection Sampling.

Pure rejection sampling (no IS weights) using geometric mean in bypass mode.
Policy gradient with rejection sampling (no IS weights) using geometric mean in bypass mode.
Skips old_log_prob computation for faster execution.

Args:
Expand All @@ -329,7 +306,7 @@ def pure_rs(
veto_threshold (float): Per-token veto threshold. Default: 1e-4

Returns:
RolloutCorrectionConfig configured for pure RS
RolloutCorrectionConfig configured for PG with RS
"""
return cls(
rollout_is=None,
Expand Down Expand Up @@ -374,10 +351,13 @@ class AlgoConfig(BaseConfig):
Addresses off-policy issues from policy mismatch, model staleness, and general distribution shifts.

Set to None to disable entirely. Use factory methods for common presets:
- RolloutCorrectionConfig.token_is() - Token-level IS
- RolloutCorrectionConfig.seq_is_rs() - Sequence-level IS + rejection sampling
- RolloutCorrectionConfig.seq_is() - Sequence-level IS (unbiased estimator)
- RolloutCorrectionConfig.geo_rs() - Geometric RS with veto
- RolloutCorrectionConfig.decoupled_token_is() - Decoupled mode with token-level IS
- RolloutCorrectionConfig.decoupled_seq_is() - Decoupled mode with sequence-level IS
- RolloutCorrectionConfig.decoupled_seq_is_rs() - Decoupled mode with sequence IS + RS
- RolloutCorrectionConfig.decoupled_geo_rs() - Decoupled mode with geometric RS + veto
- RolloutCorrectionConfig.ppo_is_bypass() - Bypass mode (skips old_log_prob)
- RolloutCorrectionConfig.pg_is() - Policy gradient with IS
- RolloutCorrectionConfig.pg_rs() - Policy gradient with RS

For backward compatibility, you can still pass a dict, which will be converted to
RolloutCorrectionConfig automatically.
Expand Down
24 changes: 22 additions & 2 deletions verl/trainer/config/algorithm/rollout_correction.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
# Rollout Correction: corrects distribution mismatch between rollout and training policies
# Override via CLI: algorithm.rollout_correction.rollout_is="token_is"
# Rollout Correction: corrects off-policy distribution shifts
# See documentation: docs/algo/rollout_corr.md
# Use presets: RolloutCorrectionConfig.decoupled_seq_is(), .pg_is(), etc.

# IS aggregation level: null (disabled), "token" (per-token), "sequence" (per-sequence)
rollout_is: null

# Upper threshold for IS weight truncation (typical: 2.0-5.0)
rollout_is_threshold: 2.0

# RS aggregation level: null (disabled), "token", "sequence", "geometric"
rollout_rs: null

# Upper threshold for rejection sampling (null = use rollout_is_threshold)
rollout_rs_threshold: null

# Lower threshold for rejection sampling (null = auto-compute as 1/upper)
rollout_rs_threshold_lower: null

# Per-token veto threshold for catastrophic outliers (null = disabled)
rollout_token_veto_threshold: null

# Operating mode: false = Decoupled (3 policies), true = Bypass (2 policies)
bypass_old_logprob_for_rollout: false

# Loss function: false = PPO with clipping, true = Policy gradient (no clipping)
use_pure_rollout_correction: false

# Batch normalize IS weights: false = raw weights, true = normalize to mean=1.0
rollout_is_batch_normalize: false