[BREAKING][rollout, trainer, algo] feat: comprehensive rollout importance sampling implementation#3694
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a comprehensive Rollout Importance Sampling (IS) system, which is a significant improvement over the old TIS implementation. The changes are extensive, covering core logic, configuration, documentation, and testing. The new implementation is well-structured, particularly the core logic in mismatch_helper.py which demonstrates careful consideration for numerical stability and memory efficiency. The addition of extensive metrics is also a great enhancement for monitoring and debugging.
I have two main points of feedback. The first is a critical concern regarding configuration management, where parameters for Rollout IS are defined in multiple places, potentially leading to silent failures if not synchronized correctly. The second is a high-severity issue related to input validation for the veto threshold, which could cause training to fail with NaN losses if misconfigured. Addressing these points will improve the robustness and usability of this new feature.
706a0e9 to
4a88913
Compare
Replace legacy tis_imp_ratio_cap with a complete rollout importance sampling (IS) implementation to correct distribution mismatch between rollout policy and training policy. **Breaking Change**: Removes tis_imp_ratio_cap parameter. Users must migrate to new rollout_is_* parameters. Backward compatibility achieved via configuration equivalence (see migration guide). **New Features:** - Three aggregation levels: token (biased, stable), sequence (unbiased, high variance), geometric (balanced) - Two bounding modes: truncate (TIS - upper only), clip (CIS - dual bounds) - Veto mechanism to reject sequences with catastrophic outliers - 37+ comprehensive metrics (26 rollout IS + 11 mismatch diagnostics) - Log-space computation for numerical stability - Memory-efficient implementation to prevent CUDA OOM **Core Changes:** - NEW: verl/trainer/ppo/mismatch_helper.py (445 lines) - compute_rollout_importance_weights(): main IS computation - compute_mismatch_metrics(): PPL, KL divergence diagnostics - All functions reference "When Speed Kills Stability" in docstrings - MODIFIED: verl/trainer/ppo/core_algos.py - Replaced 5-line TIS (lines 962-967) with 30-line rollout IS - Added metrics collection and return path - MODIFIED: verl/workers/actor/dp_actor.py - Integrated rollout IS metrics collection into training loop - Added mismatch metrics computation **Configuration:** - Added 6 new algorithm parameters: rollout_is, rollout_is_threshold, rollout_is_threshold_lower, rollout_is_level, rollout_is_mode, rollout_is_veto_threshold - Updated AlgoConfig, ActorConfig, actor.yaml, ppo_trainer.yaml - Migrated recipe/dapo/run_dapo_qwen2.5_32b_tis.sh to demonstrate equivalence **Documentation:** - NEW: ROLLOUT_IS_MIGRATION.md (324 lines) - comprehensive migration guide - NEW: examples/rollout_importance_sampling/ - example script + README - UPDATED: docs/examples/config.rst - parameter documentation **Testing:** - NEW: test_rollout_is.py (262 lines) - standalone unit tests - NEW: tests/trainer/ppo/test_rollout_is_integration.py (291 lines) - All tests passing ✓ **Migration (Backward Compatible Behavior):** Before: ```yaml actor: tis_imp_ratio_cap: 2.0 After (equivalent behavior): algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_threshold_lower: null # No lower bound rollout_is_level: token rollout_is_mode: truncate rollout_is_veto_threshold: null # No veto The new implementation with these settings produces identical behavior to the old TIS, while enabling advanced features when needed. References: When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda Off-policy RL: https://fengyao.notion.site/off-policy-rl
4a88913 to
9999052
Compare
66a5759 to
eff1b95
Compare
Address PR feedback and improve rollout IS implementation efficiency. **File Organization:** - Move ROLLOUT_IS_MIGRATION.md to docs/advance/ - Move test_rollout_is.py to tests/trainer/ppo/ - Change metrics prefix from "actor/" to "mismatch/" **Architecture Improvements:** - Keep all policy loss functions returning 4 values (no API changes) - Move rollout IS weight computation from inside policy loss to dp_actor.py - Compute IS weights and metrics once before calling policy loss - Metrics computation separated to dp_actor.py (separation of concerns) **API Cleanup:** - Remove unnecessary rollout_log_probs parameter from all 6 policy loss functions (not needed since IS weights are pre-computed in dp_actor.py) - Add rollout_is_weights parameter to all policy loss functions instead **Bug Fixes:** - Fix geo_mean shape handling: rollout_is_weights is always 2D (batch_size, seq_length) regardless of rollout_is_level (expanded via .expand_as() for sequence/geometric levels) - Simplify geo_mean IS weight aggregation logic **Efficiency Gains:** - Eliminate duplicate IS weight computation (~50% reduction in overhead) - Apply IS weights to pg_losses before aggregation in all loss functions - All policy loss functions now support IS correction - Simplified integration tests (302 → 175 lines) Net: -101 lines, cleaner architecture, better performance.
eff1b95 to
87759e5
Compare
Move mismatch diagnostic metrics (KL, PPL, etc.) from worker to trainer level for consistency and to avoid duplication. Keep IS weight metrics in worker as they are tightly coupled with loss computation. Changes: - Add compute_mismatch_metrics_batch() to metric_utils.py - Centralized function matching existing metrics pattern - Imports compute_mismatch_metrics() from mismatch_helper.py - Adds 'mismatch/' prefix to all metrics - Handles missing data gracefully - Update ray_trainer.py to call centralized function - Import compute_mismatch_metrics_batch - Call after other metrics collection - Remove duplicate mismatch diagnostics from dp_actor.py - Keep rollout IS metrics (coupled with IS weights used in loss) - Remove redundant mismatch diagnostic computation - Add clear documentation comment - Update mismatch_helper.py documentation - Clarify usage patterns (helper vs centralized) - Keep compute_mismatch_metrics() for internal use and tests Architecture: mismatch_helper.py: Contains logic (single source of truth) metric_utils.py: Centralized wrapper for trainer ray_trainer.py: Calls wrapper (centralized entry point) dp_actor.py: Computes IS metrics (operational, coupled with loss) Tests: Import from mismatch_helper (no changes needed) Benefits: - Mismatch diagnostics computed once at trainer level - IS metrics stay in worker (operationally necessary) - No duplicate computation - Consistent with existing metrics pattern - Tests require no changes
Pr 3694 megatron
- Add input validation assertions for eos_mask and response_mask - Fix variance and ESS to use consistent mean from clamped weights - Remove redundant eos_mask.any() checks after initial assertion - Update docstring to accurately describe mixed clamping strategy
Move importance sampling weight computation from workers to trainer for efficiency and consistency, following open_verl architecture. - ray_trainer: compute IS weights centrally after compute_advantage(), add to batch via union(), collect metrics with "mismatch/" prefix - dp_actor/megatron_actor: remove local IS computation (~75 lines), assert and extract pre-computed weights from batch - test_rollout_is_integration: update test to use pre-computed weights Benefits: single computation point (CPU, not N×M on GPU), all workers use identical weights, cleaner data flow enforced via assertions.
Move rollout_is configuration from actor to algorithm for better separation of concerns. Add rollout_is flag to enable metrics-only mode (default: false) for monitoring before enabling weight application. Changes: - Move 6 rollout_is params from ActorConfig to AlgoConfig - Update ray_trainer to read from algorithm config - Simplify dp_actor and megatron_actor (data checks only) - Update docs, examples, and tests - Net -44 lines of code Migration: actor_rollout_ref.actor.rollout_is_* → algorithm.rollout_is_*
798345c to
d651fce
Compare
|
|
||
| elif rollout_is_mode == "clip": | ||
| # Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold] | ||
| clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold) |
There was a problem hiding this comment.
Why MIS is implemented under clip mode? People may consider clip as clamp function, not a mask.
# Rollout Importance Sampling Framework related to #3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
|
Hello, does anyone know why I'm getting this error? Traceback (most recent call last): max_response_length 32k, rollout.n 16, train_batch_size 16 |
…ance sampling implementation (verl-project#3694) This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ```yaml actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 actor_rollout_ref: rollout: calculate_log_probs: true ``` **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- **Step 1: Update your configuration file** ```yaml actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ```bash pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
# Rollout Importance Sampling Framework related to verl-project/verl#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
…ance sampling implementation (verl-project#3694) # Rollout Importance Sampling Framework ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`clip`** (CIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction` (clip mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clipping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "clip" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level clipping (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: clip ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, clip - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, clip) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/clip) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: clip` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py` --------- Co-authored-by: Yan Bai <bayan@nvidia.com>
…ect#3750) # Rollout Importance Sampling Framework related to verl-project#3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
UPDATED at Oct. 30: Now migrated to rollout correction framework in https://verl.readthedocs.io/en/latest/algo/rollout_corr.html Related to #4070
Rollout Importance Sampling Framework
Related to #3750
Summary
This PR introduces a comprehensive Rollout Importance Sampling (IS) framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning.
This work is motivated by the analysis in our blog post, When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch. If you find this implementation useful in your research, please consider citing:
Problem Statement
When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to:
This framework addresses these issues through principled importance sampling correction.
Key Features & Improvements
1. Flexible Aggregation Levels
Three methods for calculating IS weights:
token: Per-token importance ratiossequence: Product of per-token ratiosgeometric: Geometric mean of ratios2. Advanced Bounding Modes
Two strategies to control weight variance:
truncate(TIS): Caps weights at upper threshold only, preserving gradientsmask(MIS): Zeros out weights outside bounds, more aggressive filtering3. Comprehensive Diagnostics
Detailed metrics to monitor distribution mismatch and training health:
Rollout IS Metrics (automatically prefixed with
mismatch/):rollout_is_eff_sample_size,rollout_is_meanrollout_is_p25,rollout_is_p50,rollout_is_p75,rollout_is_p95,rollout_is_p99,rollout_is_max,rollout_is_min,rollout_is_stdrollout_is_veto_fraction,rollout_is_catastrophic_token_fraction,rollout_is_masked_fraction(mask mode)rollout_is_seq_mean,rollout_is_seq_std,rollout_is_seq_max,rollout_is_seq_min, etc.Mismatch Metrics (computed efficiently within IS weight computation):
mismatch_kl(forward KL),mismatch_k3_kl(K3 estimator for stability)mismatch_training_ppl,mismatch_rollout_ppl,mismatch_ppl_ratiomismatch_log_ppl_diff,mismatch_log_ppl_abs_diff,mismatch_log_ppl_diff_max,mismatch_log_ppl_diff_min4. Outlier Mitigation
5. Numerical Stability
6. Memory Efficiency
7. Metrics-Only Mode
algorithm.rollout_isflag (independent of weight computation)8. Universal PPO Support
API and Configuration Changes
Migration from Legacy TIS
❌ Before (REMOVED)
The legacy implementation:
✅ After (New Framework)
Configuration moved to
algorithmsection for better organization:Configuration Examples
1. Token-level truncation (recommended starting point)
2. Sequence-level masking (more aggressive)
3. Metrics-only mode (monitoring without correction)
Example script:
bash examples/rollout_importance_sampling/run_with_rollout_is.shCode Changes Overview
New Files (4 files, 1,442 lines)
verl/trainer/ppo/mismatch_helper.py(459 lines)docs/advance/rollout_is_migration.md(642 lines)examples/rollout_importance_sampling/README.md(242 lines)examples/rollout_importance_sampling/run_with_rollout_is.sh(99 lines)Modified Core Files (9 files)
verl/trainer/ppo/core_algos.py(~50 lines changed)tis_imp_ratio_cap)rollout_is_weightsparameter to all policy loss functionscompute_policy_loss_vanillacompute_policy_loss_gspocompute_policy_loss_gpgcompute_policy_loss_clip_covcompute_policy_loss_kl_covcompute_policy_loss_geo_meangeo_mean(sequence-level aggregation)verl/trainer/ppo/ray_trainer.py(~52 lines added)compute_rollout_importance_weights_and_add_to_batch()algorithm.rollout_isverl/trainer/config/algorithm.py(+18 lines)rollout_is_threshold(main on/off switch)rollout_is(weight application control)rollout_is_threshold_lowerrollout_is_levelrollout_is_moderollout_is_veto_thresholdverl/workers/config/actor.py(-1 line)tis_imp_ratio_capparameterverl/workers/actor/dp_actor.py(~26 lines changed)rollout_is_weightsparameterverl/workers/actor/megatron_actor.py(~15 lines changed)rollout_is_weightsparameterConfiguration Files (4 files updated)
verl/trainer/config/ppo_trainer.yamlverl/trainer/config/ppo_megatron_trainer.yamlverl/trainer/config/_generated_ppo_trainer.yamlverl/trainer/config/_generated_ppo_megatron_trainer.yamlTesting (2 files, 530 lines)
tests/trainer/ppo/test_rollout_is.py(289 lines)mismatch_helper.pytests/trainer/ppo/test_rollout_is_integration.py(241 lines)Updated Recipes (2 files)
recipe/dapo/dapo_ray_trainer.py(+5 lines)recipe/dapo/run_dapo_qwen2.5_32b_tis.sh(~42 lines changed)Documentation Updates (2 files)
docs/examples/config.rst(~22 lines changed)docs/index.rst(+1 line)Implementation Highlights
Centralized Architecture
The new design follows a clean separation of concerns:
Key Design Decisions
Centralized Computation: IS weights computed once on driver, not per worker
Configuration in Algorithm: Moved from actor config to algorithm config
Two-Level Control:
rollout_is_threshold: Enables/disables entire system (null = off)rollout_is: Controls weight application (true = apply, false = metrics only)Metrics Consolidation: Mismatch metrics computed within IS weight computation
Universal PPO Support: Single interface for all PPO variants
Migration Guide
For Users of Legacy TIS
Step 1: Update your configuration file
Step 2: Monitor metrics
The first time you run with the new configuration, check these metrics:
mismatch/rollout_is_eff_sample_size: Should be > 80% of batch sizemismatch/rollout_is_veto_fraction: Should be < 5%mismatch/rollout_is_mean: Should be close to 1.0Step 3: Tune if needed
If effective sample size is too low:
rollout_is_thresholdrollout_is_mode: maskwith appropriate lower boundrollout_is_level: sequencefor more aggressive correctionFor detailed guidance, see
docs/advance/rollout_is_migration.md.For New Users
Start with recommended defaults:
Run the example script to see it in action:
Testing
Unit Tests
test_rollout_is.pyIntegration Tests
test_rollout_is_integration.pyRunning Tests
Metrics Reference
Rollout IS Metrics (all prefixed with
mismatch/)rollout_is_eff_sample_sizerollout_is_meanrollout_is_stdrollout_is_p25rollout_is_p50rollout_is_p75rollout_is_p95rollout_is_p99rollout_is_maxrollout_is_minrollout_is_veto_fractionrollout_is_catastrophic_token_fractionrollout_is_masked_fractionMismatch Metrics (all prefixed with
mismatch/)mismatch_klmismatch_k3_klmismatch_training_pplmismatch_rollout_pplmismatch_ppl_ratiomismatch_log_ppl_diffmismatch_log_ppl_abs_diffmismatch_log_ppl_diff_maxmismatch_log_ppl_diff_minmismatch_training_log_pplmismatch_rollout_log_pplPerformance Impact
Memory
Computation
Training Stability
Breaking Changes
Important
This PR contains BREAKING CHANGES to the configuration API.
Removed
actor_rollout_ref.actor.tis_imp_ratio_cap: No longer supportedMigration Required
All users of the legacy TIS implementation must update their configuration files. See the migration guide above or
docs/advance/rollout_is_migration.mdfor detailed instructions.Backward Compatibility
tis_imp_ratio_capwill raise validation errorsPre-Submission Checklist
[{modules}] {type}: {description}(checked by CI)[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling frameworkci-requestchannelReferences
docs/advance/rollout_is_migration.mdexamples/rollout_importance_sampling/tests/trainer/ppo/test_rollout_is*.py