Commit 9e60f01
[doc,algo] feat: Rollout Correction - Fix Metrics, Add Documentation, and Add Batch Normalization (verl-project#4070)
## Overview

This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984.

---

## Bug Fixes

### 1. Metrics Computation Running in Wrong Mode ⚠️

**Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless.

**Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177)

```python
# BEFORE (incorrect - runs in bypass mode)
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
    batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
```

```python
# AFTER (correct - runs in decoupled mode only)
if (
    rollout_corr_config is not None
    and "rollout_log_probs" in batch.batch
    and not bypass_recomputing_logprobs
):  # Only in decoupled mode
    batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
```

**Impact**:
- IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies)
- In bypass mode (2 policies), the actor now correctly computes metrics from the evolving π_θ vs π_rollout

**Related Changes**:
- Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection)
- Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior)
- Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py) and [megatron_actor.py](verl/workers/actor/megatron_actor.py)

---
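For intuition, the quantity these metrics are built on is the per-token importance ratio between the recomputed old policy and the rollout policy. A minimal sketch (the helper name, truncation threshold, and values are illustrative assumptions, not verl's actual API):

```python
import math

def token_is_weights(old_logps, rollout_logps, max_weight=2.0):
    # Per-token importance ratio pi_old(t) / pi_rollout(t), computed in
    # log-space and truncated from above to cap the variance it can inject.
    return [min(math.exp(o - r), max_weight)
            for o, r in zip(old_logps, rollout_logps)]

# When the two policies agree exactly, every weight is 1.0.
weights = token_is_weights([-1.2, -0.7], [-1.2, -0.7])
```

In decoupled mode these ratios are well-defined because `old_log_prob` is recomputed by the training engine; in bypass mode the rollout policy itself stands in for π_old, so the same ratio degenerates to 1 and the metric must instead be taken against the evolving π_θ.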
## Configuration Refactor (Semantic Clarity)

### 2. Variable Renaming

Renamed config variables to accurately reflect their semantics:

| Old Name | New Name | Rationale |
|----------|----------|-----------|
| `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) |
| `use_pure_rollout_correction` | `use_policy_gradient` | Reflects the actual choice: policy gradient loss vs Q-function loss |

**Before** ([algorithm.py @ e8ad3cd](https://github.com/volcengine/verl/blob/e8ad3cdb/verl/trainer/config/algorithm.py)):

```python
bypass_old_logprob_for_rollout: bool = False  # Unclear what "bypass" means
use_pure_rollout_correction: bool = False  # "Pure" is vague
```

**After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)):

```python
bypass_mode: bool = False  # Clear: bypass or decoupled mode
use_policy_gradient: bool = False  # Clear: PG or Q-function loss
```

**Files Updated**:
- Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml)
- Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py)
- Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh)
- Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml)

---
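As a mental model for what the renamed `bypass_mode` flag controls (a hypothetical sketch of the decision, not the trainer's actual code path):

```python
def needs_old_logprob_recompute(bypass_mode: bool) -> bool:
    # bypass_mode=True  -> 2-policy setup: skip the extra forward pass and
    #                      let the rollout policy stand in for pi_old.
    # bypass_mode=False -> decoupled 3-policy mode: recompute old_log_prob
    #                      with the training engine before each update.
    return not bypass_mode
```

This is why the old name `bypass_old_logprob_for_rollout` was misleading: the flag does not merely skip one computation, it selects between two operating modes.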
## New Feature: Batch Normalization

### 3. IS Weight Batch Normalization

**Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159))

```python
rollout_is_batch_normalize: bool = False
```

**Purpose**:
- Normalizes importance sampling weights to have mean=1.0 within each batch
- Aligns normalization scope with the IS aggregation level (token/sequence/geometric)
- Helps stabilize training when policy drift is large

**Behavior**:
- `True`: IS weights normalized so mean=1.0 per batch (reduces variance)
- `False`: Raw truncated IS weights used (standard behavior, default)

**Documentation**:
- Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md)
- Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md)

---

## Documentation Overhaul

### 4. File Reorganization

**Moved documentation to `docs/algo/`**:
- `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions)
- `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions)

**Deleted redundant file**:
- `examples/rollout_correction/README.md` (-253 lines)

**Updated references**:
- [docs/index.rst](docs/index.rst): Updated paths
- [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references
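The IS-weight batch normalization added in §3 amounts to a single rescaling step. A sketch under stated assumptions (hypothetical helper applied after truncation; the real implementation additionally aligns the normalization with the configured aggregation level):

```python
def batch_normalize_is_weights(weights):
    # Rescale truncated IS weights so the batch mean is exactly 1.0.
    # This damps batch-to-batch scale swings when the rollout and training
    # policies drift apart, at the cost of a batch-dependent scale factor.
    mean = sum(weights) / len(weights)
    return [w / mean for w in weights]

normalized = batch_normalize_is_weights([1.0, 3.0])
```

With `rollout_is_batch_normalize=False` (the default), this step is skipped and the raw truncated weights are used directly.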
### 5. Preset Renaming for Clarity

Renamed presets to clearly indicate operating mode:

| Old Name | New Name | Operating Mode | Description |
|----------|----------|----------------|-------------|
| `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting |
| `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting |
| `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling |
| `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) |
| `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS |
| N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) |

**Naming Convention**:
- **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation)
- **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation)

### 6. Content Improvements

**Cross-References**:
- Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations)

**Clarified Loss Formulations**:
- Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md)
- **Rationale**: Separates IS weight mechanics from PPO clipping for clarity
- Added a note that the REINFORCE examples can be combined with PPO clipping

**New Sections**:
- Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md)
- Improved operating mode explanations throughout

---
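The REINFORCE-style formulation now used in the math doc's loss examples can be combined with a sequence-level IS weight as follows (an illustrative sketch; the function name and the omission of PPO clipping are my assumptions, not the documented API):

```python
def pg_is_loss(token_log_probs, advantages, seq_is_weight):
    # Sequence-level IS-weighted REINFORCE loss:
    #   L = -w_seq * sum_t A_t * log pi_theta(a_t | s_t)
    # negated because optimizers minimize. PPO clipping can be layered on
    # top of this, which is why the docs present the two separately.
    return -seq_is_weight * sum(a * lp
                                for lp, a in zip(token_log_probs, advantages))
```

Keeping the IS weight outside the per-token sum is what distinguishes sequence-level weighting (`pg_is`) from token-level weighting, where each term would carry its own ratio.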
## Code Quality Improvements

### 7. Enhanced Comments and Documentation

**Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)):
- Lines 1104-1107: Operating mode selection logic
- Lines 1175-1177: Metrics computation behavior explanation

**Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)):
- Enhanced docstrings for `compute_policy_loss_with_rollout_correction`
- Clarified when to use policy gradient vs Q-function loss

**Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)):
- Added comments explaining bypass mode metrics computation

### 8. Code Simplification

**Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)):
- Removed unnecessary config parameters from metrics computation
- Removed unused IS weight processing logic
- Simplified the metrics calculation flow

**Improved Variable Reuse**:
- Reused the `need_recomputation` variable instead of redundant bypass mode checks
- Reduced code duplication

---

## Commit History

<details>
<summary>18 commits (click to expand)</summary>

1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs
2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset
3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights
4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity
5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic
6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs
7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization
8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction
9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode
10. `5f9a53bf` - feat: improve comments
11. `b2f63709` - feat: improve comments
12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs
13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level
14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples
15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only
16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency
17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity
18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables

</details>

---

## Summary

This PR systematically improves the rollout correction implementation in three areas:

1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode
2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`)
3. **Feature Addition**: Added a batch normalization option for IS weights with comprehensive documentation

All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability.

---------

Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
1 parent 4ac02e6 commit 9e60f01

17 files changed (+988, -700 lines)

**docs/advance/fully_async.md** (2 additions & 2 deletions)

```diff
@@ -166,11 +166,11 @@
 
 During the training process, we observed that metrics and response lengths may become unstable in the later
 stages of training. To mitigate this issue, we can use
-the [Rollout Correction](https://verl.readthedocs.io/en/latest/advance/rollout_corr.html)
+the [Rollout Correction](https://verl.readthedocs.io/en/latest/algo/rollout_corr.html)
 technique for importance sampling and rejection sampling. To utilize Rollout Correction, we need to compute log_prob using
 the training engine, which requires enabling this switch.
 Additionally, when compute_prox_log_prob and Rollout Correction are enabled under mode d
-(async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`.
+(async stream pipeline with partial rollout), our implementation follows `Decoupled PPO` that is described in [Mathmatics of Rollout Correction](https://verl.readthedocs.io/en/latest/algo/rollout_corr_math.html).
 
 ### Supported Modes
```