Commit d651fce

refactor(rollout_is): move config to algorithm, add metrics-only mode
Move rollout_is configuration from actor to algorithm for better separation of concerns. Add rollout_is flag to enable metrics-only mode (default: false) for monitoring before enabling weight application.

Changes:
- Move 6 rollout_is params from ActorConfig to AlgoConfig
- Update ray_trainer to read from algorithm config
- Simplify dp_actor and megatron_actor (data checks only)
- Update docs, examples, and tests
- Net -44 lines of code

Migration: actor_rollout_ref.actor.rollout_is_* → algorithm.rollout_is_*
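The migration path named in the commit message can be sketched as a small helper. This is only an illustration: it assumes the configuration is available as a plain nested dict, and `migrate_rollout_is_config` is a hypothetical name, not a function in verl. The six keys are the ones this commit removes from the actor section of the generated trainer YAML.

```python
import copy

# The six parameters moved from the actor config to the algorithm config.
ROLLOUT_IS_KEYS = (
    "rollout_is",
    "rollout_is_threshold",
    "rollout_is_threshold_lower",
    "rollout_is_level",
    "rollout_is_mode",
    "rollout_is_veto_threshold",
)


def migrate_rollout_is_config(config: dict) -> dict:
    """Return a copy of `config` with rollout_is_* keys moved to `algorithm`.

    Hypothetical helper for illustration only; the input is left untouched.
    """
    migrated = copy.deepcopy(config)
    actor = migrated.get("actor_rollout_ref", {}).get("actor", {})
    algorithm = migrated.setdefault("algorithm", {})
    for key in ROLLOUT_IS_KEYS:
        if key in actor:
            # A value already present under `algorithm` takes precedence.
            algorithm.setdefault(key, actor.pop(key))
    return migrated
```

Note that after migrating, `rollout_is` defaults to false, so an old config that relied only on a threshold would need `rollout_is: true` set explicitly to keep applying weights to the loss.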
1 parent 6720e98 commit d651fce

16 files changed, +204 -184 lines changed

docs/advance/rollout_is_migration.md

Lines changed: 27 additions & 21 deletions
@@ -36,10 +36,12 @@ The old implementation:
 ### **Added (New Implementation)**

 ```yaml
-# New Rollout IS configuration
+# New Rollout IS configuration (all in algorithm config)
 algorithm:
-  rollout_is: true
+  # Main control: set threshold to enable (null = disabled)
   rollout_is_threshold: 2.0
+  # Whether to apply weights to loss (default: false = metrics only)
+  rollout_is: true
   rollout_is_threshold_lower: null  # Auto-reciprocal
   rollout_is_level: token
   rollout_is_mode: truncate
@@ -121,11 +123,17 @@ The new implementation:

 ## Configuration Parameters

+### `algorithm.rollout_is_threshold` (float or null)
+**Main on/off switch.** Upper threshold for IS weights.
+- `null` = disabled (no computation, no metrics)
+- `float` value (e.g., 2.0) = enabled (compute weights and metrics)
+
 ### `algorithm.rollout_is` (bool)
-Enable/disable IS correction. Default: `False`
+Whether to apply IS weights to policy loss. Default: `False`
+- `true` = apply weights to loss (full IS correction)
+- `false` = compute metrics only (useful for monitoring before enabling)

-### `algorithm.rollout_is_threshold` (float or null)
-Upper threshold for IS weights. Set to `null` to disable IS completely.
+**Recommended threshold ranges:**
 - Token level: 1.5 - 5.0
 - Sequence level: 2.0 - 10.0
 - Geometric level: 1.0002 - 1.001
@@ -164,8 +172,8 @@ actor_rollout_ref:
 **After (New):**
 ```yaml
 algorithm:
-  rollout_is: true
-  rollout_is_threshold: 2.0
+  rollout_is_threshold: 2.0  # Main control
+  rollout_is: true  # Apply to loss (default: false)
   rollout_is_level: token
   rollout_is_mode: truncate

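The two switches documented in this file's diff combine into three behaviours: disabled, metrics only, and full correction. A minimal sketch of that decision table follows; `resolve_rollout_is_mode` and the returned labels are hypothetical names chosen for illustration and do not appear in verl.

```python
from typing import Optional


def resolve_rollout_is_mode(
    rollout_is_threshold: Optional[float],
    rollout_is: bool = False,
) -> str:
    """Map the two config values to a behaviour label (illustrative only)."""
    if rollout_is_threshold is None:
        # No threshold: nothing is computed and no metrics are logged.
        return "disabled"
    if rollout_is:
        # Threshold set and flag on: weights are applied to the policy loss.
        return "apply_to_loss"
    # Threshold set, flag off (the default): weights and metrics are computed only.
    return "metrics_only"
```

Under this reading, setting `rollout_is: true` without a threshold still leaves the feature off, because the threshold is the main on/off switch.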
@@ -430,41 +438,39 @@ Monitor metrics for 1-2 epochs before adjusting parameters.

 ## Configuration Examples

-### Example 1: Token-level with Truncate
+### Example 1: Full IS Correction
 ```yaml
 algorithm:
-  rollout_is: true
   rollout_is_threshold: 2.0
+  rollout_is: true  # Apply weights to loss
   rollout_is_level: token
   rollout_is_mode: truncate
 ```

-### Example 2: Geometric Mean with Clip
+### Example 2: Metrics Only (Monitoring Mode)
 ```yaml
 algorithm:
-  rollout_is: true
-  rollout_is_threshold: 1.0002
-  rollout_is_threshold_lower: 0.9998
-  rollout_is_level: geometric
-  rollout_is_mode: clip
+  rollout_is_threshold: 2.0
+  rollout_is: false  # Compute metrics, don't apply weights
+  rollout_is_level: token
+  rollout_is_mode: truncate
 ```

-### Example 3: Wider Threshold with Clip
+### Example 3: Geometric Mean with Clip
 ```yaml
 algorithm:
+  rollout_is_threshold: 1.0002
   rollout_is: true
-  rollout_is_threshold: 3.0
-  rollout_is_threshold_lower: 0.33
-  rollout_is_level: token
+  rollout_is_threshold_lower: 0.9998
+  rollout_is_level: geometric
   rollout_is_mode: clip
-  rollout_is_veto_threshold: 1e-5
 ```

 ### Example 4: Asymmetric Thresholds
 ```yaml
 algorithm:
-  rollout_is: true
   rollout_is_threshold: 5.0
+  rollout_is: true
   rollout_is_threshold_lower: 0.8
   rollout_is_level: token
   rollout_is_mode: clip

examples/rollout_importance_sampling/README.md

Lines changed: 22 additions & 22 deletions
@@ -19,8 +19,10 @@ Rollout Importance Sampling corrects for distribution mismatch when:

 ```yaml
 algorithm:
-  rollout_is: true
+  # Main control: set threshold to enable (null = disabled)
   rollout_is_threshold: 2.0
+  # Whether to apply weights to policy loss (true) or just compute metrics (false)
+  rollout_is: true
   rollout_is_level: token
   rollout_is_mode: truncate

@@ -56,66 +58,64 @@ bash examples/rollout_importance_sampling/run_with_rollout_is.sh

 ### Key Parameters

-- `rollout_is`: Enable/disable IS correction (boolean)
-- `rollout_is_threshold`: Upper threshold for IS weights (float or null to disable)
+- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**
+- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false.
 - `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)
 - `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4)

 ## Configuration Examples

-### Example 1: Token-level with Truncate
+### Example 1: Full IS Correction (Apply Weights)

 ```yaml
 algorithm:
-  rollout_is: true
   rollout_is_threshold: 2.0
+  rollout_is: true  # Apply to loss
   rollout_is_level: token
   rollout_is_mode: truncate
   rollout_is_veto_threshold: 1e-4
 ```

-### Example 2: Geometric Mean with Clip
+### Example 2: Metrics Only (No Weight Application)
+
+```yaml
+algorithm:
+  rollout_is_threshold: 2.0
+  rollout_is: false  # Compute metrics only, don't apply to loss
+  rollout_is_level: token
+  rollout_is_mode: truncate
+```
+
+### Example 3: Geometric Mean with Clip

 ```yaml
 algorithm:
-  rollout_is: true
   rollout_is_threshold: 1.0002
+  rollout_is: true
   rollout_is_threshold_lower: 0.9998
   rollout_is_level: geometric
   rollout_is_mode: clip
   rollout_is_veto_threshold: 1e-4
 ```

-### Example 3: Sequence-level with Truncate
+### Example 4: Sequence-level with Truncate

 ```yaml
 algorithm:
-  rollout_is: true
   rollout_is_threshold: 5.0
+  rollout_is: true
   rollout_is_threshold_lower: null  # Auto-reciprocal: 0.2
   rollout_is_level: sequence
   rollout_is_mode: truncate
   rollout_is_veto_threshold: 1e-4
 ```

-### Example 4: Wider Threshold with Clip
-
-```yaml
-algorithm:
-  rollout_is: true
-  rollout_is_threshold: 3.0
-  rollout_is_threshold_lower: 0.33
-  rollout_is_level: token
-  rollout_is_mode: clip
-  rollout_is_veto_threshold: 1e-5
-```
-
 ### Example 5: Asymmetric Thresholds

 ```yaml
 algorithm:
-  rollout_is: true
   rollout_is_threshold: 5.0
+  rollout_is: true
   rollout_is_threshold_lower: 0.8
   rollout_is_level: token
   rollout_is_mode: clip

examples/rollout_importance_sampling/run_with_rollout_is.sh

Lines changed: 7 additions & 5 deletions
@@ -8,12 +8,13 @@ set -xeuo pipefail
 # Rollout Importance Sampling Configuration
 # ==============================================================================

-# Enable rollout IS
-rollout_is=True
-
-# Upper threshold for IS weights
+# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
 rollout_is_threshold=2.0

+# Whether to apply IS weights to policy loss
+# true = apply weights to loss, false = compute metrics only
+rollout_is=true
+
 # Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5)
 rollout_is_threshold_lower=null
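
The auto-reciprocal rule in the comment above (a null lower threshold means 1/upper) can be sketched as follows. This is an illustrative helper under that stated rule only; `resolve_rollout_is_bounds` is a hypothetical name and the positivity check is an added assumption, not behaviour confirmed by this commit.

```python
from typing import Optional, Tuple


def resolve_rollout_is_bounds(
    upper: float,
    lower: Optional[float] = None,
) -> Tuple[float, float]:
    """Return (lower, upper) bounds for IS weights (illustrative only)."""
    if upper <= 0:
        # Assumption: a reciprocal bound only makes sense for a positive threshold.
        raise ValueError("upper threshold must be positive")
    if lower is None:
        # null lower threshold: fall back to the reciprocal of the upper one.
        lower = 1.0 / upper
    return lower, upper
```

With the script's `rollout_is_threshold=2.0` and a null lower threshold, this gives bounds of 0.5 and 2.0; an explicit lower value such as 0.8 is passed through unchanged, which is how the asymmetric-threshold examples are expressed.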

@@ -87,9 +88,10 @@ python3 -m verl.trainer.main_ppo \
 echo "Training completed!"
 echo ""
 echo "Rollout IS Configuration:"
+echo " - Threshold: ${rollout_is_threshold}"
+echo " - Apply to loss: ${rollout_is}"
 echo " - Level: ${rollout_is_level}"
 echo " - Mode: ${rollout_is_mode}"
-echo " - Threshold: ${rollout_is_threshold}"
 echo ""
 echo "Monitor these key metrics in wandb:"
 echo " - mismatch/rollout_is_mean (should be ~1.0)"

recipe/dapo/dapo_ray_trainer.py

Lines changed: 5 additions & 5 deletions
@@ -304,6 +304,11 @@ def fit(self):
 values = self.critic_wg.compute_values(batch)
 batch = batch.union(values)

+# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
+batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
+# IS and mismatch metrics already have mismatch/ prefix
+metrics.update(is_metrics)
+
 with marked_timer("adv", timing_raw, "brown"):
     # compute advantages, executed on the driver process
     norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
@@ -316,11 +321,6 @@ def fit(self):
     norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
 )

-# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
-batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
-# IS and mismatch metrics already have mismatch/ prefix
-metrics.update(is_metrics)
-
 # update critic
 if self.use_critic:
     with marked_timer("update_critic", timing_raw, "pink"):

recipe/one_step_off_policy/ray_trainer.py

Lines changed: 5 additions & 5 deletions
@@ -577,6 +577,11 @@ def fit(self):
 else:
     batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

+# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
+batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
+# IS and mismatch metrics already have mismatch/ prefix
+metrics.update(is_metrics)
+
 # compute advantages, executed on the driver process

 norm_adv_by_std_in_grpo = self.config.algorithm.get(
@@ -593,11 +598,6 @@ def fit(self):
     config=self.config.algorithm,
 )

-# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
-batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
-# IS and mismatch metrics already have mismatch/ prefix
-metrics.update(is_metrics)
-
 # update critic
 if self.use_critic:
     with marked_timer("update_critic", timing_raw, color="pink"):

tests/trainer/ppo/test_rollout_is_integration.py

Lines changed: 55 additions & 8 deletions
@@ -40,25 +40,25 @@ def sample_data(self):

     @pytest.fixture
     def config_with_rollout_is(self):
-        """Create config with rollout IS enabled."""
+        """Create config for policy loss computation.
+
+        Note: rollout_is config has been moved to algorithm config.
+        This config only needs fields used by policy loss (clip_ratio, etc).
+        """
         config = ActorConfig(
             strategy="fsdp",
             rollout_n=1,
             ppo_micro_batch_size=2,
-            rollout_is=True,
-            rollout_is_threshold=2.0,
-            rollout_is_level="token",
-            rollout_is_mode="truncate",
-            rollout_is_veto_threshold=1e-4,
             clip_ratio=0.2,
         )
         return config

     def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
         """Test that policy loss computation works with rollout IS weights.

-        Note: In production, IS weights are computed centrally in the trainer.
-        This test simulates that by computing weights before passing to policy loss.
+        Note: In production, IS weights are computed centrally in the trainer
+        (before advantage computation) and passed to policy loss.
+        This test simulates that workflow.
         """
         # First compute IS weights (as trainer would do centrally)
         rollout_is_weights_proto, _ = compute_rollout_importance_weights(
@@ -189,6 +189,53 @@ def test_veto_mechanism(self):
         assert metrics["mismatch/rollout_is_veto_fraction"] > 0
         assert metrics["mismatch/rollout_is_veto_fraction"] <= 1.0

+    def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
+        """Test metrics-only mode: compute IS weights/metrics but don't apply to loss.
+
+        This tests the use case where rollout_is_threshold is set (enables computation)
+        but rollout_is=False (disables weight application to policy loss).
+        """
+        # Compute IS weights (as trainer would do)
+        rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights(
+            old_log_prob=sample_data["old_log_prob"],
+            rollout_log_prob=sample_data["rollout_log_prob"],
+            response_mask=sample_data["response_mask"],
+            rollout_is_level="token",
+            rollout_is_mode="truncate",
+            rollout_is_threshold=2.0,
+        )
+
+        # Metrics should be computed
+        assert len(is_metrics) > 0
+        assert "mismatch/rollout_is_mean" in is_metrics
+
+        # In metrics-only mode, we compute loss WITHOUT applying weights
+        # (simulating rollout_is=False)
+        pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(
+            old_log_prob=sample_data["old_log_prob"],
+            log_prob=sample_data["log_prob"],
+            advantages=sample_data["advantages"],
+            response_mask=sample_data["response_mask"],
+            loss_agg_mode="token-mean",
+            config=config_with_rollout_is,
+            rollout_is_weights=None,  # Don't apply weights
+        )
+
+        # Compare to loss WITH weights (rollout_is=True)
+        rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
+        pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(
+            old_log_prob=sample_data["old_log_prob"],
+            log_prob=sample_data["log_prob"],
+            advantages=sample_data["advantages"],
+            response_mask=sample_data["response_mask"],
+            loss_agg_mode="token-mean",
+            config=config_with_rollout_is,
+            rollout_is_weights=rollout_is_weights,
+        )
+
+        # Losses should be different (weights have an effect)
+        assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)
+

 if __name__ == "__main__":
     pytest.main([__file__, "-v", "-s"])

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 6 additions & 6 deletions
@@ -75,12 +75,6 @@ actor_rollout_ref:
     clip_ratio_c: 3.0
     loss_agg_mode: token-mean
     entropy_coeff: 0
-    rollout_is: false
-    rollout_is_threshold: null
-    rollout_is_threshold_lower: null
-    rollout_is_level: token
-    rollout_is_mode: truncate
-    rollout_is_veto_threshold: 0.0001
     use_kl_loss: false
     use_torch_compile: true
     kl_loss_coef: 0.001
@@ -487,6 +481,12 @@ algorithm:
   pf_ppo:
     reweight_method: pow
     weight_pow: 2.0
+  rollout_is_threshold: null
+  rollout_is_threshold_lower: null
+  rollout_is_level: token
+  rollout_is_mode: truncate
+  rollout_is_veto_threshold: 0.0001
+  rollout_is: false
 trainer:
   balance_batch: true
   total_epochs: 30
