diff --git a/.github/workflows/e2e_one_step_off_policy.yml b/.github/workflows/e2e_one_step_off_policy.yml
new file mode 100644
index 00000000000..2ac76f869c7
--- /dev/null
+++ b/.github/workflows/e2e_one_step_off_policy.yml
@@ -0,0 +1,144 @@
+# # Tests layout
+
+# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:
+# - `tests/trainer` for testing functionality related to `verl/trainer`
+# - `tests/models` for testing functionality related to `verl/models`
+# - ...
+
+# There are a few folders with `special_` prefix, created for special purposes:
+# - `special_distributed`: unit tests that must run with multiple GPUs
+# - `special_e2e`: end-to-end tests with training/generation scripts
+# - `special_npu`: tests for NPUs
+# - `special_sanity`: a suite of quick sanity tests
+# - `special_standalone`: a set of tests designed to run in dedicated environments
+
+# Accelerators for tests
+# - By default, tests run with GPUs available, except for those under `special_npu` and any test script whose name ends with `on_cpu.py`.
+# - Test scripts with the `on_cpu.py` suffix are run on CPU resources in a Linux environment.
+
+# # Workflow layout
+
+# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:
+# 1. A list of always-triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `pre-commit.yml`, `doc.yml`
+# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`
+# 3. End-to-end tests: `e2e_*.yml`
+# 4. Unit tests
+# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py`
+#   - `gpu_unit_tests.yml`, run pytest on all test scripts without the `on_cpu.py` suffix.
+#   - Since CPU/GPU unit tests by default run all tests under `tests`, please make sure tests are manually excluded from them when
+#     - a new workflow yaml is added to `.github/workflows`
+#     - new tests are added to the workflows mentioned in 2.
+
+
+name: e2e_one_step_off_policy
+
+on:
+ # Trigger the workflow on push or pull request,
+ # but only for the main branch
+ # For push, for now only anti-patterns are specified so it is more conservative
+ # and achieves higher coverage.
+ push:
+ branches:
+ - main
+ - v0.*
+ paths:
+ - "**/*.py"
+ - "!**/*.md"
+ - "!**/*.sh"
+ # Other entrypoints
+ - "!examples/*trainer*"
+ - "!tests/**"
+ - "!verl/trainer/main_*.py"
+ - "!verl/trainer/fsdp_sft_trainer.py"
+ - "!recipe/**"
+      - "recipe/one_step_off_policy/**"
+ pull_request:
+ branches:
+ - main
+ - v0.*
+ paths:
+ - "**/*.py"
+ - "!**/*.md"
+ - "!**/*.sh"
+ # Other entrypoints
+ - "!examples/**"
+ - "!tests/**"
+ - "!verl/trainer/main_*.py"
+ - "!verl/trainer/fsdp_sft_trainer.py"
+ # Other recipes
+ - "!recipe/**"
+ # Home
+      - "recipe/one_step_off_policy/**"
+ # Entrypoints
+ - ".github/workflows/e2e_one_step_off_policy.yml"
+ - "examples/data_preprocess/gsm8k.py"
+ - "tests/special_e2e/run_one_step_off_policy.sh"
+
+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+# Declare permissions just read content.
+permissions:
+ contents: read
+
+jobs:
+ # Test FSDP2 strategy
+ e2e_one_step_off_policy_fsdp2:
+ runs-on: [L20x8]
+ timeout-minutes: 50 # Increase timeout for async training
+ env:
+ HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+ HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+ NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
+ HF_ENDPOINT: "https://hf-mirror.com"
+ HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
+ ACTOR_STRATEGY: "fsdp2"
+ container:
+ image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1
+ options: --gpus all --shm-size=10g
+ steps:
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ fetch-depth: 0
+ - name: Install the current repository
+ run: |
+ pip3 install --no-deps -e .[test,gpu]
+ - name: Prepare GSM8K dataset
+ run: |
+ python3 examples/data_preprocess/gsm8k.py
+ - name: Running the E2E test with one_step_off_policy algorithm (FSDP2)
+ run: |
+ ray stop --force
+ bash tests/special_e2e/run_one_step_off_policy.sh
+
+ # Test Megatron strategy
+ e2e_one_step_off_policy_megatron:
+ runs-on: [L20x8]
+ timeout-minutes: 50 # Increase timeout for async training
+ env:
+ HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+ HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+ NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
+ HF_ENDPOINT: "https://hf-mirror.com"
+ HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
+ ACTOR_STRATEGY: "megatron"
+ container:
+ image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1
+ options: --gpus all --shm-size=10g
+ steps:
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ fetch-depth: 0
+ - name: Install the current repository
+ run: |
+ pip3 install --no-deps -e .[test,gpu]
+ - name: Prepare GSM8K dataset
+ run: |
+ python3 examples/data_preprocess/gsm8k.py
+ - name: Running the E2E test with one_step_off_policy algorithm (Megatron)
+ run: |
+ ray stop --force
+ bash tests/special_e2e/run_one_step_off_policy.sh
+
diff --git a/docs/advance/one_step_off.md b/docs/advance/one_step_off.md
new file mode 100644
index 00000000000..43edb199947
--- /dev/null
+++ b/docs/advance/one_step_off.md
@@ -0,0 +1,297 @@
+# Recipe: One Step Off Policy Async Trainer
+
+**Author:** [meituan-search](https://github.com/meituan-search)
+
+Last updated: 07/16/2025.
+
+## Introduction
+
+### Background
+
+The reinforcement learning training process currently implemented by verl is synchronous, adhering to the algorithmic
+workflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest
+model, and the model is updated after training completes. While this approach keeps training on-policy
+and stabilizes RL training, it suffers from severe efficiency issues:
+model updates must wait for the longest output in the generation phase to complete.
+While long-tail samples are being generated, the remaining GPUs sit idle, resulting in significant underutilization.
+The more severe the long-tail problem in sample generation, the lower the overall training efficiency.
+For example, in DAPO 32B training, the rollout phase accounts for approximately 70% of the total time,
+and increasing resources does not reduce the rollout duration.
+
+
+> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361
+
+### Solution
+
+We have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the
+generation and training processes, utilizing samples generated in the previous step for current training.
+It also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically
+assigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time
+during long-tail sample generation. Throughout this process, the generation parameters stay exactly one step behind
+the training parameters, i.e., training is one-step off-policy.
+
+
+> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](
+> https://arxiv.org/abs/2505.24298)
+
+Our core contributions include:
+
+1. **Parallel Generation and Training**:
+ Samples for the next batch are asynchronously generated while the current batch is being trained.
+
+2. **Resource Isolation**:
+ Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources
+ automatically assigned to training.
+
+3. **NCCL Parameter Synchronization**:
+ Employs NCCL communication primitives for seamless parameter transfer between generation and training modules.
+
+### Experimental Results
+
+- **Machine Configuration**: 2 nodes with 8 H20 GPUs each (16 GPUs in total)
+  - Generation: 4 GPUs
+  - Training: 12 GPUs
+- **Model**: Qwen2.5-Math-7B
+- **Rollout Configuration**:
+  - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens
+- **Algorithm**: DAPO
+- **Rollout Engine**: vLLM
+
+| training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean |
+|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
+| colocate sync | VLLM+FSDP2 | 749 | 321 | - | 247 | 88 | 286 | 19h18m | 0.5948 | 0.417 |
+| one-step-overlap async | VLLM+FSDP2 | 520 | - | 45 | 458 | 108 | 337 | 15h34m(+23%) | 0.6165 | 0.494 |
+| colocate sync | VLLM+Megatron | 699 | 207 | - | 162 | 119 | 344 | 18h21m | 0.605 | 0.4217 |
+| one-step-overlap async | VLLM+Megatron | 566 | - | 59 | 501 | 120 | 347 | 13h06m (+40%) | 0.6569 | 0.4038 |
+
+* colocate sync: step = gen + old_log_prob + update_actor
+* one-step-overlap async: step = max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
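+
+To make the step-time accounting concrete, here is a minimal sketch (illustrative only, not part of the recipe) that
+evaluates both formulas with the VLLM+FSDP2 timing columns above; the reported `step` values also include overhead
+that is not broken out in the table:
+
+```python
+def colocate_step_time(gen, old_log_prob, update_actor):
+    # synchronous mode: the three phases run back to back
+    return gen + old_log_prob + update_actor
+
+
+def one_step_off_step_time(wait_prev_gen, generate_sequences, old_log_prob, update_actor):
+    # async mode: rollout overlaps training, so a step lasts as long as the slower branch
+    return max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
+
+
+print(colocate_step_time(321, 88, 286))           # 695, close to the reported 749
+print(one_step_off_step_time(45, 458, 108, 337))  # 503, close to the reported 520
+```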
+
+
+
+> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg
+
+## Implementation
+
+### One Step Off Policy Async Pipeline
+
+Our **One Step Off Policy Async Pipeline** integrates into the existing training logic at minimal cost,
+without requiring any additional sample-storage management. The core mechanism uses `_async_gen_next_batch`
+for asynchronous rollout generation, while `_create_continuous_iterator` keeps the data stream continuous across
+epoch transitions.
+
+```python
+# iterator generator: yields batches across epochs so the one-step-off loop below stays simple
+def _create_continuous_iterator(self):
+ for epoch in range(self.config.trainer.total_epochs):
+ iterator = iter(self.train_dataloader)
+ for batch_dict in iterator:
+ yield epoch, batch_dict
+
+
+# read the next batch of samples, sync parameters, and launch async gen_seq
+def _async_gen_next_batch(self, continuous_iterator):
+ # read train_data
+ try:
+ epoch, batch_dict = next(continuous_iterator)
+ except StopIteration:
+ return None
+ batch = DataProto.from_single_dict(batch_dict)
+    gen_batch = batch_process(batch)
+ # sync weights from actor to rollout
+ self.sync_rollout_weights()
+ # async generation
+ gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
+ # future encapsulated
+ return GenerationBatchFuture(epoch, batch, gen_batch_output)
+
+
+continuous_iterator = self._create_continuous_iterator()
+# run rollout first to achieve one-step-off
+batch_data_future = self._async_gen_next_batch(continuous_iterator)
+
+while batch_data_future is not None:
+ # wait for the gen_seq result from the previous step
+ batch = batch_data_future.get()
+ # launch the next async call to generate sequences
+ batch_data_future = self._async_gen_next_batch(continuous_iterator)
+
+ # compute advantages
+ batch = critic.compute_values(batch)
+ batch = reference.compute_log_prob(batch)
+ batch = reward.compute_reward(batch)
+ batch = compute_advantages(batch)
+
+ # model update
+ critic_metrics = critic.update_critic(batch)
+ actor_metrics = actor.update_actor(batch)
+```
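+
+`GenerationBatchFuture` is used above but not shown. Below is a minimal sketch of what such a wrapper might look like,
+assuming `async_generate_sequences` returns a Ray object ref and that `DataProto.union` merges the rollout output back
+into the original batch; the names and details are illustrative rather than the recipe's exact implementation:
+
+```python
+import ray
+
+
+class GenerationBatchFuture:
+    """Illustrative wrapper pairing a training batch with its still-pending rollout."""
+
+    def __init__(self, epoch, batch, gen_batch_output):
+        self.epoch = epoch                        # epoch this batch was drawn from
+        self.batch = batch                        # original DataProto batch (prompts and metadata)
+        self.gen_batch_output = gen_batch_output  # pending rollout result, e.g. a Ray ObjectRef
+
+    def get(self):
+        # Block until the asynchronous rollout finishes, then merge the generated
+        # sequences back into the original batch for the downstream stages.
+        gen_batch_output = ray.get(self.gen_batch_output)
+        return self.batch.union(gen_batch_output)
+```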
+
+### Parameter Synchronization
+
+A particularly encouraging result is that our NCCL-based weight update for the rollout model performs very well:
+most of the time the latency is under 300 ms, which is negligible for RLHF.
+Although it is currently implemented only for FSDP and vLLM, we believe extending it to other backends is not complex.
+
+> **sync_rollout_weights**: The time for synchronizing parameters from the actor to the rollout is extremely fast and
+> can almost be ignored because it is implemented with NCCL.
+
+```python
+class ActorRolloutRefWorker:
+ # actor acquires the meta-info of model parameters for parameter sync
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def get_actor_weights_info(self):
+ params = self._get_actor_params()
+ ret = []
+ for key, tensor in params.items():
+ ret.append((key, tensor.size(), tensor.dtype))
+ self._weights_info = ret
+ return ret
+
+ # rollout sets the meta-info of model parameters for parameter sync
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def set_actor_weights_info(self, weights_info):
+ self._weights_info = weights_info
+
+
+class AsyncRayPPOTrainer(RayPPOTrainer):
+    def init_workers(self):
+        ...
+        # rollout obtains the meta-info of model parameters from the actor for parameter sync
+        weights_info = self.actor_wg.get_actor_weights_info()[0]
+        self.rollout_wg.set_actor_weights_info(weights_info)
+
+        # create an actor-rollout communication group for parameter sync
+        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
+        collective.create_collective_group(
+            actor_rollout_workers,
+            len(actor_rollout_workers),
+            list(range(0, len(actor_rollout_workers))),
+            backend="nccl",
+            group_name="actor_rollout",
+        )
+```
+
+```python
+# the driver process calls the actor and rollout worker groups respectively to sync parameters via NCCL
+def sync_rollout_weights(self):
+ self.actor_wg.sync_rollout_weights()
+ ray.get(self.rollout_wg.sync_rollout_weights())
+
+
+# fsdp model parameter sync
+@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+def sync_rollout_weights(self):
+ params = self._get_actor_params() if self._is_actor else None
+ if self._is_rollout:
+ inference_model = (
+ self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+ )
+ patch_vllm_moe_model_weight_loader(inference_model)
+ # Model parameters are broadcast tensor-by-tensor from actor to rollout
+ for key, shape, dtype in self._weights_info:
+ tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
+ if self._is_actor:
+ assert key in params
+ origin_data = params[key]
+ if hasattr(origin_data, "full_tensor"):
+ origin_data = origin_data.full_tensor()
+ if torch.distributed.get_rank() == 0:
+ tensor.copy_(origin_data)
+ from ray.util.collective import collective
+
+ collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
+ if self._is_rollout:
+ inference_model.load_weights([(key, tensor)])
+```
+
+## Usage
+
+### FSDP2 Configuration Example
+
+```shell
+# actor and rollout are placed separately (hybrid_engine disabled) and each is
+# given dedicated resources via the trainer.* and rollout.* settings below
+python3 -m recipe.one_step_off_policy.main_ppo \
+    --config-path=config \
+    --config-name='one_step_off_ppo_trainer.yaml' \
+    actor_rollout_ref.actor.strategy=fsdp2 \
+    actor_rollout_ref.hybrid_engine=False \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=6 \
+    rollout.nnodes=1 \
+    rollout.n_gpus_per_node=2
+```
+
+### Megatron Configuration Example
+
+```shell
+# actor and rollout are placed separately (hybrid_engine disabled) and each is
+# given dedicated resources via the trainer.* and rollout.* settings below
+python3 -m recipe.one_step_off_policy.main_ppo \
+    --config-path=config \
+    --config-name='one_step_off_ppo_megatron_trainer.yaml' \
+    actor_rollout_ref.actor.strategy=megatron \
+    actor_rollout_ref.hybrid_engine=False \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=6 \
+    rollout.nnodes=1 \
+    rollout.n_gpus_per_node=2
+```
+
+### Configuration Guidelines
+
+1. **GPU Count Relationships**
+   Maintain either of these relationships for an even batch distribution (see the sketch after these guidelines):
+   - `actor_rollout_ref.rollout.n` should be an integer divisor of
+     `trainer.n_gpus_per_node * trainer.nnodes`
+   - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by
+     `trainer.n_gpus_per_node * trainer.nnodes`
+
+   > Rationale: ensures training samples can be evenly distributed across the training GPUs when only part of the
+   > resources is used for generation.
+
+2. **Dynamic Resource Tuning**
+   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on
+   phase durations:
+ - **Ideal state**: Rollout and training phases have comparable durations
+ - **Diagnostic metrics**:
+ - Monitor `wait_prev_gen` duration
+ - Analyze `sequence_length` distribution
+ - **Adjustment strategy**:
+ - High `wait_prev_gen` + uniform sequence lengths → Increase rollout resources
+ - High `wait_prev_gen` + long-tail sequences → Optimize stopping criteria (resource increase won't help)
+   > **wait_prev_gen**: The time spent waiting for the previous rollout to finish (the part that is not fully
+   > overlapped).
+
+   **Resource Configuration Strategies:**
+ - **Resource-constrained scenario**: Optimize resource utilization by adjusting GPU allocation ratios,
+ keeping the number of nodes equal to allow training and rollout to share nodes;
+ - Configure `trainer.nnodes = rollout.nnodes` with
+ `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`. Control rollout resource
+ allocation by adjusting `n_gpus_per_node`.
+ - **Resource-abundant scenario**: Optimize performance by adjusting the number of nodes,
+ keeping the number of GPUs per node equal to enable independent scaling of training and rollout
+ parallelism.
+ - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control rollout resource allocation by
+       adjusting `trainer.nnodes` and `rollout.nnodes` to achieve optimal performance.
+ > **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The
+ > actual calculation depends on GPU capacity:
+ > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,
+ > the required node count is `max(trainer.nnodes, rollout.nnodes)`
+ > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,
+ > the required node count is `trainer.nnodes + rollout.nnodes`
+
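+The sketch below (illustrative only; the helper name and the default of 8 GPUs per node are ours, not part of the
+recipe) condenses the GPU count relationships and the node count rule above into a single check:
+
+```python
+def check_one_step_off_resources(rollout_n, train_batch_size,
+                                 trainer_nnodes, trainer_n_gpus_per_node,
+                                 rollout_nnodes, rollout_n_gpus_per_node,
+                                 physical_gpus_per_node=8):
+    n_train_gpus = trainer_nnodes * trainer_n_gpus_per_node
+    # either rollout.n divides the number of training GPUs ...
+    evenly_distributed = n_train_gpus % rollout_n == 0
+    # ... or the total number of responses divides evenly across the training GPUs
+    evenly_distributed = evenly_distributed or (rollout_n * train_batch_size) % n_train_gpus == 0
+    # the required node count depends on whether trainer and rollout GPUs fit on the same nodes
+    if trainer_n_gpus_per_node + rollout_n_gpus_per_node <= physical_gpus_per_node:
+        required_nodes = max(trainer_nnodes, rollout_nnodes)
+    else:
+        required_nodes = trainer_nnodes + rollout_nnodes
+    return evenly_distributed, required_nodes
+
+
+# Example: the 7B DAPO run in this recipe (2 nodes, 6 training GPUs + 2 rollout GPUs per node)
+print(check_one_step_off_resources(rollout_n=12, train_batch_size=512,
+                                   trainer_nnodes=2, trainer_n_gpus_per_node=6,
+                                   rollout_nnodes=2, rollout_n_gpus_per_node=2))
+# (True, 2)
+```
+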
+## Functional Support
+
+| Category           | Support Situation                                                                               |
+|--------------------|-------------------------------------------------------------------------------------------------|
+| train engine       | FSDP2<br/>Megatron                                                                                |
+| rollout engine     | vLLM                                                                                              |
+| AdvantageEstimator | GRPO<br/>GRPO_PASSK<br/>REINFORCE_PLUS_PLUS<br/>RLOO<br/>OPO<br/>REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |
+| Reward             | all                                                                                               |
diff --git a/docs/index.rst b/docs/index.rst
index 980066a7fbc..a486586945c 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -116,6 +116,7 @@ verl is fast with:
advance/dpo_extension
examples/sandbox_fusion_example
advance/rollout_trace.rst
+ advance/one_step_off
.. toctree::
:maxdepth: 1
diff --git a/recipe/one_step_off_policy/README.md b/recipe/one_step_off_policy/README.md
new file mode 100644
index 00000000000..2eef6997f8c
--- /dev/null
+++ b/recipe/one_step_off_policy/README.md
@@ -0,0 +1,297 @@
+# Recipe: One Step Off Policy Async Trainer
+
+**Author:** [meituan-search](https://github.com/meituan-search)
+
+Last updated: 07/16/2025.
+
+## Introduction
+
+### Background
+
+The reinforcement learning training process currently implemented by verl is synchronous, adhering to the algorithmic
+workflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest
+model, and the model is updated after training completes. While this approach keeps training on-policy
+and stabilizes RL training, it suffers from severe efficiency issues:
+model updates must wait for the longest output in the generation phase to complete.
+While long-tail samples are being generated, the remaining GPUs sit idle, resulting in significant underutilization.
+The more severe the long-tail problem in sample generation, the lower the overall training efficiency.
+For example, in DAPO 32B training, the rollout phase accounts for approximately 70% of the total time,
+and increasing resources does not reduce the rollout duration.
+
+
+> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361
+
+### Solution
+
+We have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the
+generation and training processes, utilizing samples generated in the previous step for current training.
+It also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically
+assigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time
+during long-tail sample generation. Throughout this process, the generation parameters stay exactly one step behind
+the training parameters, i.e., training is one-step off-policy.
+
+
+> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](
+> https://arxiv.org/abs/2505.24298)
+
+Our core contributions include:
+
+1. **Parallel Generation and Training**:
+ Samples for the next batch are asynchronously generated while the current batch is being trained.
+
+2. **Resource Isolation**:
+ Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources
+ automatically assigned to training.
+
+3. **NCCL Parameter Synchronization**:
+ Employs NCCL communication primitives for seamless parameter transfer between generation and training modules.
+
+### Experimental Results
+
+- **Machine Configuration**: 2 nodes with 8 H20 GPUs each (16 GPUs in total)
+  - Generation: 4 GPUs
+  - Training: 12 GPUs
+- **Model**: Qwen2.5-Math-7B
+- **Rollout Configuration**:
+  - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens
+- **Algorithm**: DAPO
+- **Rollout Engine**: vLLM
+
+| training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean |
+|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
+| colocate sync | VLLM+FSDP2 | 749 | 321 | - | 247 | 88 | 286 | 19h18m | 0.5948 | 0.417 |
+| one-step-overlap async | VLLM+FSDP2 | 520 | - | 45 | 458 | 108 | 337 | 15h34m(+23%) | 0.6165 | 0.494 |
+| colocate sync | VLLM+Megatron | 699 | 207 | - | 162 | 119 | 344 | 18h21m | 0.605 | 0.4217 |
+| one-step-overlap async | VLLM+Megatron | 566 | - | 59 | 501 | 120 | 347 | 13h06m (+40%) | 0.6569 | 0.4038 |
+
+* colocate sync: step = gen + old_log_prob + update_actor
+* one-step-overlap async: step = max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
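+
+To make the step-time accounting concrete, here is a minimal sketch (illustrative only, not part of the recipe) that
+evaluates both formulas with the VLLM+FSDP2 timing columns above; the reported `step` values also include overhead
+that is not broken out in the table:
+
+```python
+def colocate_step_time(gen, old_log_prob, update_actor):
+    # synchronous mode: the three phases run back to back
+    return gen + old_log_prob + update_actor
+
+
+def one_step_off_step_time(wait_prev_gen, generate_sequences, old_log_prob, update_actor):
+    # async mode: rollout overlaps training, so a step lasts as long as the slower branch
+    return max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
+
+
+print(colocate_step_time(321, 88, 286))           # 695, close to the reported 749
+print(one_step_off_step_time(45, 458, 108, 337))  # 503, close to the reported 520
+```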
+
+
+
+> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg
+
+## Implementation
+
+### One Step Off Policy Async Pipeline
+
+Our **One Step Off Policy Async Pipeline** integrates into the existing training logic at minimal cost,
+without requiring any additional sample-storage management. The core mechanism uses `_async_gen_next_batch`
+for asynchronous rollout generation, while `_create_continuous_iterator` keeps the data stream continuous across
+epoch transitions.
+
+```python
+# iterator generator: yields batches across epochs so the one-step-off loop below stays simple
+def _create_continuous_iterator(self):
+ for epoch in range(self.config.trainer.total_epochs):
+ iterator = iter(self.train_dataloader)
+ for batch_dict in iterator:
+ yield epoch, batch_dict
+
+
+# read the next batch of samples, sync parameters, and launch async gen_seq
+def _async_gen_next_batch(self, continuous_iterator):
+ # read train_data
+ try:
+ epoch, batch_dict = next(continuous_iterator)
+ except StopIteration:
+ return None
+ batch = DataProto.from_single_dict(batch_dict)
+    gen_batch = batch_process(batch)
+ # sync weights from actor to rollout
+ self.sync_rollout_weights()
+ # async generation
+ gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
+ # future encapsulated
+ return GenerationBatchFuture(epoch, batch, gen_batch_output)
+
+
+continuous_iterator = self._create_continuous_iterator()
+# run rollout first to achieve one-step-off
+batch_data_future = self._async_gen_next_batch(continuous_iterator)
+
+while batch_data_future is not None:
+ # wait for the gen_seq result from the previous step
+ batch = batch_data_future.get()
+ # launch the next async call to generate sequences
+ batch_data_future = self._async_gen_next_batch(continuous_iterator)
+
+ # compute advantages
+ batch = critic.compute_values(batch)
+ batch = reference.compute_log_prob(batch)
+ batch = reward.compute_reward(batch)
+ batch = compute_advantages(batch)
+
+ # model update
+ critic_metrics = critic.update_critic(batch)
+ actor_metrics = actor.update_actor(batch)
+```
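+
+`GenerationBatchFuture` is used above but not shown. Below is a minimal sketch of what such a wrapper might look like,
+assuming `async_generate_sequences` returns a Ray object ref and that `DataProto.union` merges the rollout output back
+into the original batch; the names and details are illustrative rather than the recipe's exact implementation:
+
+```python
+import ray
+
+
+class GenerationBatchFuture:
+    """Illustrative wrapper pairing a training batch with its still-pending rollout."""
+
+    def __init__(self, epoch, batch, gen_batch_output):
+        self.epoch = epoch                        # epoch this batch was drawn from
+        self.batch = batch                        # original DataProto batch (prompts and metadata)
+        self.gen_batch_output = gen_batch_output  # pending rollout result, e.g. a Ray ObjectRef
+
+    def get(self):
+        # Block until the asynchronous rollout finishes, then merge the generated
+        # sequences back into the original batch for the downstream stages.
+        gen_batch_output = ray.get(self.gen_batch_output)
+        return self.batch.union(gen_batch_output)
+```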
+
+### Parameter Synchronization
+
+A particularly encouraging result is that our NCCL-based weight update for the rollout model performs very well:
+most of the time the latency is under 300 ms, which is negligible for RLHF.
+Although it is currently implemented only for FSDP and vLLM, we believe extending it to other backends is not complex.
+
+> **sync_rollout_weights**: The time for synchronizing parameters from the actor to the rollout is extremely fast and
+> can almost be ignored because it is implemented with NCCL.
+
+```python
+class ActorRolloutRefWorker:
+ # actor acquires the meta-info of model parameters for parameter sync
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def get_actor_weights_info(self):
+ params = self._get_actor_params()
+ ret = []
+ for key, tensor in params.items():
+ ret.append((key, tensor.size(), tensor.dtype))
+ self._weights_info = ret
+ return ret
+
+ # rollout sets the meta-info of model parameters for parameter sync
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def set_actor_weights_info(self, weights_info):
+ self._weights_info = weights_info
+
+
+class AsyncRayPPOTrainer(RayPPOTrainer):
+    def init_workers(self):
+        ...
+        # rollout obtains the meta-info of model parameters from the actor for parameter sync
+        weights_info = self.actor_wg.get_actor_weights_info()[0]
+        self.rollout_wg.set_actor_weights_info(weights_info)
+
+        # create an actor-rollout communication group for parameter sync
+        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
+        collective.create_collective_group(
+            actor_rollout_workers,
+            len(actor_rollout_workers),
+            list(range(0, len(actor_rollout_workers))),
+            backend="nccl",
+            group_name="actor_rollout",
+        )
+```
+
+```python
+# the driver process calls the actor and rollout worker groups respectively to sync parameters via NCCL
+def sync_rollout_weights(self):
+ self.actor_wg.sync_rollout_weights()
+ ray.get(self.rollout_wg.sync_rollout_weights())
+
+
+# fsdp model parameter sync
+@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+def sync_rollout_weights(self):
+ params = self._get_actor_params() if self._is_actor else None
+ if self._is_rollout:
+ inference_model = (
+ self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+ )
+ patch_vllm_moe_model_weight_loader(inference_model)
+ # Model parameters are broadcast tensor-by-tensor from actor to rollout
+ for key, shape, dtype in self._weights_info:
+ tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
+ if self._is_actor:
+ assert key in params
+ origin_data = params[key]
+ if hasattr(origin_data, "full_tensor"):
+ origin_data = origin_data.full_tensor()
+ if torch.distributed.get_rank() == 0:
+ tensor.copy_(origin_data)
+ from ray.util.collective import collective
+
+ collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
+ if self._is_rollout:
+ inference_model.load_weights([(key, tensor)])
+```
+
+## Usage
+
+### FSDP2 Configuration Example
+
+```shell
+# actor and rollout are placed separately (hybrid_engine disabled) and each is
+# given dedicated resources via the trainer.* and rollout.* settings below
+python3 -m recipe.one_step_off_policy.main_ppo \
+    --config-path=config \
+    --config-name='one_step_off_ppo_trainer.yaml' \
+    actor_rollout_ref.actor.strategy=fsdp2 \
+    actor_rollout_ref.hybrid_engine=False \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=6 \
+    rollout.nnodes=1 \
+    rollout.n_gpus_per_node=2
+```
+
+### Megatron Configuration Example
+
+```shell
+# actor and rollout are placed separately (hybrid_engine disabled) and each is
+# given dedicated resources via the trainer.* and rollout.* settings below
+python3 -m recipe.one_step_off_policy.main_ppo \
+    --config-path=config \
+    --config-name='one_step_off_ppo_megatron_trainer.yaml' \
+    actor_rollout_ref.actor.strategy=megatron \
+    actor_rollout_ref.hybrid_engine=False \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=6 \
+    rollout.nnodes=1 \
+    rollout.n_gpus_per_node=2
+```
+
+### Configuration Guidelines
+
+1. **GPU Count Relationships**
+   Maintain either of these relationships for an even batch distribution (see the sketch after these guidelines):
+   - `actor_rollout_ref.rollout.n` should be an integer divisor of
+     `trainer.n_gpus_per_node * trainer.nnodes`
+   - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by
+     `trainer.n_gpus_per_node * trainer.nnodes`
+
+   > Rationale: ensures training samples can be evenly distributed across the training GPUs when only part of the
+   > resources is used for generation.
+
+2. **Dynamic Resource Tuning**
+   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on
+   phase durations:
+ - **Ideal state**: Rollout and training phases have comparable durations
+ - **Diagnostic metrics**:
+ - Monitor `wait_prev_gen` duration
+ - Analyze `sequence_length` distribution
+ - **Adjustment strategy**:
+ - High `wait_prev_gen` + uniform sequence lengths → Increase rollout resources
+ - High `wait_prev_gen` + long-tail sequences → Optimize stopping criteria (resource increase won't help)
+   > **wait_prev_gen**: The time spent waiting for the previous rollout to finish (the part that is not fully
+   > overlapped).
+
+   **Resource Configuration Strategies:**
+ - **Resource-constrained scenario**: Optimize resource utilization by adjusting GPU allocation ratios,
+ keeping the number of nodes equal to allow training and rollout to share nodes;
+ - Configure `trainer.nnodes = rollout.nnodes` with
+ `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`. Control rollout resource
+ allocation by adjusting `n_gpus_per_node`.
+ - **Resource-abundant scenario**: Optimize performance by adjusting the number of nodes,
+ keeping the number of GPUs per node equal to enable independent scaling of training and rollout
+ parallelism.
+ - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control rollout resource allocation by
+       adjusting `trainer.nnodes` and `rollout.nnodes` to achieve optimal performance.
+ > **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The
+ > actual calculation depends on GPU capacity:
+ > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,
+ > the required node count is `max(trainer.nnodes, rollout.nnodes)`
+ > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,
+ > the required node count is `trainer.nnodes + rollout.nnodes`
+
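+The sketch below (illustrative only; the helper name and the default of 8 GPUs per node are ours, not part of the
+recipe) condenses the GPU count relationships and the node count rule above into a single check:
+
+```python
+def check_one_step_off_resources(rollout_n, train_batch_size,
+                                 trainer_nnodes, trainer_n_gpus_per_node,
+                                 rollout_nnodes, rollout_n_gpus_per_node,
+                                 physical_gpus_per_node=8):
+    n_train_gpus = trainer_nnodes * trainer_n_gpus_per_node
+    # either rollout.n divides the number of training GPUs ...
+    evenly_distributed = n_train_gpus % rollout_n == 0
+    # ... or the total number of responses divides evenly across the training GPUs
+    evenly_distributed = evenly_distributed or (rollout_n * train_batch_size) % n_train_gpus == 0
+    # the required node count depends on whether trainer and rollout GPUs fit on the same nodes
+    if trainer_n_gpus_per_node + rollout_n_gpus_per_node <= physical_gpus_per_node:
+        required_nodes = max(trainer_nnodes, rollout_nnodes)
+    else:
+        required_nodes = trainer_nnodes + rollout_nnodes
+    return evenly_distributed, required_nodes
+
+
+# Example: the 7B DAPO run in this recipe (2 nodes, 6 training GPUs + 2 rollout GPUs per node)
+print(check_one_step_off_resources(rollout_n=12, train_batch_size=512,
+                                   trainer_nnodes=2, trainer_n_gpus_per_node=6,
+                                   rollout_nnodes=2, rollout_n_gpus_per_node=2))
+# (True, 2)
+```
+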
+## Functional Support
+
+| Category           | Support Situation                                                                               |
+|--------------------|-------------------------------------------------------------------------------------------------|
+| train engine       | FSDP2<br/>Megatron                                                                                |
+| rollout engine     | vLLM                                                                                              |
+| AdvantageEstimator | GRPO<br/>GRPO_PASSK<br/>REINFORCE_PLUS_PLUS<br/>RLOO<br/>OPO<br/>REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |
+| Reward             | all                                                                                               |
diff --git a/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml b/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
new file mode 100644
index 00000000000..f5a3c6b58e6
--- /dev/null
+++ b/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
@@ -0,0 +1,14 @@
+hydra:
+ searchpath:
+ - file://verl/trainer/config
+
+defaults:
+ - ppo_megatron_trainer
+ - _self_
+
+# config for the rollout (only for resource isolation)
+rollout:
+ # Number of nodes used in the rollout
+ nnodes: 1
+ # Number of GPUs per node
+ n_gpus_per_node: 8
\ No newline at end of file
diff --git a/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml b/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
new file mode 100644
index 00000000000..05890e9b5a2
--- /dev/null
+++ b/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
@@ -0,0 +1,14 @@
+hydra:
+ searchpath:
+ - file://verl/trainer/config
+
+defaults:
+ - ppo_trainer
+ - _self_
+
+# config for the rollout (only for resource isolation)
+rollout:
+ # Number of nodes used in the rollout
+ nnodes: 1
+ # Number of GPUs per node
+ n_gpus_per_node: 8
\ No newline at end of file
diff --git a/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh
new file mode 100644
index 00000000000..7f4ca919a4c
--- /dev/null
+++ b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh
@@ -0,0 +1,139 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+project_name='DAPO'
+exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-one-step-off-4-12'
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 8))
+enable_overlong_buffer=True
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+train_prompt_bsz=512
+n_resp_per_prompt=12
+train_prompt_mini_bsz=32
+
+# Ray
+# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+NNODES=${NNODES:-2}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+
+n_gpus_rollout=2
+n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
+
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
+
+
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+val_top_p=0.7
+
+# Performance Related Parameter
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
+ref_offload=True
+actor_offload=False
+gen_tp=2
+sp_size=4
+fsdp_size=2
+
+python3 -m recipe.one_step_off_policy.main_ppo \
+ data.train_files="${TRAIN_FILE}" \
+ data.val_files="${TEST_FILE}" \
+ data.prompt_key=prompt \
+ data.truncation='left' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ critic.strategy=fsdp2 \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.hybrid_engine=False \
+ +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
+ reward_model.reward_manager=dapo \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
+ trainer.logger=['console','tensorboard'] \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.val_before_train=True \
+ trainer.test_freq=10 \
+ trainer.save_freq=-1 \
+ trainer.total_epochs=10 \
+ trainer.total_training_steps=100 \
+ trainer.default_local_dir="${CKPTS_DIR}" \
+ trainer.resume_mode=auto \
+ trainer.log_val_generations=10 \
+ trainer.nnodes="${NNODES}" \
+ trainer.n_gpus_per_node="${n_gpus_training}" \
+ rollout.nnodes="${NNODES}" \
+ rollout.n_gpus_per_node="${n_gpus_rollout}"
\ No newline at end of file
diff --git a/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh
new file mode 100644
index 00000000000..33b65b06080
--- /dev/null
+++ b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh
@@ -0,0 +1,131 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+project_name='DAPO'
+exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-colocate'
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 8))
+enable_overlong_buffer=True
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+train_prompt_bsz=512
+n_resp_per_prompt=12
+train_prompt_mini_bsz=32
+
+# Ray
+# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+NNODES=${NNODES:-2}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+val_top_p=0.7
+
+# Performance Related Parameter
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
+offload=True
+gen_tp=2
+sp_size=4
+fsdp_size=2
+
+# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361
+
+python3 -m verl.trainer.main_ppo \
+ data.train_files="${TRAIN_FILE}" \
+ data.val_files="${TEST_FILE}" \
+ data.prompt_key=prompt \
+ data.truncation='left' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ critic.strategy=fsdp2 \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
+ reward_model.reward_manager=dapo \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
+ trainer.logger=['console','tensorboard'] \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
+ trainer.nnodes="${NNODES}" \
+ trainer.val_before_train=True \
+ trainer.test_freq=10 \
+ trainer.save_freq=-1 \
+ trainer.total_epochs=10 \
+ trainer.total_training_steps=100 \
+ trainer.default_local_dir="${CKPTS_DIR}" \
+ trainer.resume_mode=auto \
+ trainer.log_val_generations=10
diff --git a/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh b/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh
new file mode 100644
index 00000000000..e246f01fbf5
--- /dev/null
+++ b/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+project_name='DAPO'
+exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-megatron-one-step-off-4-12'
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 8))
+enable_overlong_buffer=True
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+train_prompt_bsz=512
+n_resp_per_prompt=12
+train_prompt_mini_bsz=32
+
+
+# Ray
+# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+NNODES=${NNODES:-2}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+
+n_gpus_rollout=2
+n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
+
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+val_top_p=0.7
+
+# Performance Related Parameter
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
+ref_offload=True
+actor_offload=False
+gen_tp=2
+train_tp=2
+train_pp=2
+
+# TODO: support dynamic_bsz for megatron
+# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+
+python3 -m recipe.one_step_off_policy.main_ppo \
+ --config-path=config \
+ --config-name='one_step_off_ppo_megatron_trainer.yaml' \
+ data.train_files="${TRAIN_FILE}" \
+ data.val_files="${TEST_FILE}" \
+ data.prompt_key=prompt \
+ data.truncation='left' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ actor_rollout_ref.actor.strategy=megatron \
+ critic.strategy=megatron \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.hybrid_engine=False \
+ +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
+ actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
+ actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
+ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
+ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.optim.clip_grad=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
+ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
+ actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \
+ reward_model.reward_manager=dapo \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
+ trainer.logger=['console','tensorboard'] \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.val_before_train=True \
+ trainer.test_freq=10 \
+ trainer.save_freq=-1 \
+ trainer.total_epochs=10 \
+ trainer.total_training_steps=100 \
+ trainer.default_local_dir="${CKPTS_DIR}" \
+ trainer.resume_mode=auto \
+ trainer.log_val_generations=10 \
+ trainer.nnodes="${NNODES}" \
+ trainer.n_gpus_per_node="${n_gpus_training}" \
+ rollout.nnodes="${NNODES}" \
+ rollout.n_gpus_per_node="${n_gpus_rollout}"
diff --git a/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh b/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh
new file mode 100644
index 00000000000..c22da0fd440
--- /dev/null
+++ b/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh
@@ -0,0 +1,138 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+project_name='DAPO'
+exp_name='DAPO-Qwen2.5-7b-MATH-0519a1-megatron-colocate'
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 8))
+enable_overlong_buffer=True
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+train_prompt_bsz=512
+n_resp_per_prompt=16
+train_prompt_mini_bsz=32
+
+# Ray
+# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+NNODES=${NNODES:-2}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
+
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+val_top_p=0.7
+
+# Performance Related Parameter
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
+offload=True
+gen_tp=2
+train_tp=2
+train_pp=2
+
+# TODO: support dynamic_bsz for megatron
+# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+
+python3 -m verl.trainer.main_ppo \
+ --config-path=config \
+ --config-name='ppo_megatron_trainer.yaml' \
+ data.train_files="${TRAIN_FILE}" \
+ data.val_files="${TEST_FILE}" \
+ data.prompt_key=prompt \
+ data.truncation='left' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ actor_rollout_ref.actor.strategy=megatron \
+ critic.strategy=megatron \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.megatron.param_offload=${offload} \
+ actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.megatron.grad_offload=${offload} \
+ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
+ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.optim.clip_grad=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
+ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
+ actor_rollout_ref.ref.megatron.param_offload=${offload} \
+ reward_model.reward_manager=dapo \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+ +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
+ trainer.logger=['console','tensorboard'] \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes="${NNODES}" \
+ trainer.val_before_train=True \
+ trainer.test_freq=10 \
+ trainer.save_freq=-1 \
+ trainer.total_epochs=10 \
+ trainer.total_training_steps=100 \
+ trainer.default_local_dir="${CKPTS_DIR}" \
+ trainer.resume_mode=auto \
+ trainer.log_val_generations=10
diff --git a/recipe/one_step_off_policy/fsdp_workers.py b/recipe/one_step_off_policy/fsdp_workers.py
new file mode 100644
index 00000000000..c2600f7ebd7
--- /dev/null
+++ b/recipe/one_step_off_policy/fsdp_workers.py
@@ -0,0 +1,228 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+import torch
+import torch.distributed
+from omegaconf import DictConfig, OmegaConf
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from transformers import AutoConfig
+
+from verl.single_controller.base import Worker
+from verl.single_controller.base.decorator import Dispatch, register
+from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
+from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
+from verl.utils.device import (
+ get_device_name,
+ get_nccl_backend,
+ get_torch_device,
+)
+from verl.utils.fs import copy_to_local
+from verl.utils.fsdp_utils import (
+ fsdp_version,
+)
+from verl.utils.import_utils import import_external_libs
+from verl.utils.model import get_generation_config, update_model_config
+from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
+from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker
+from verl.workers.fsdp_workers import CriticWorker
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+device_name = get_device_name()
+
+__all__ = ["ActorRolloutRefWorker", "AsyncActorRolloutRefWorker", "CriticWorker", "RolloutWorker"]
+
+
+class ActorRolloutRefWorker(ARRWorker):
+ def _get_actor_params(self):
+ assert self._is_actor
+ params = self.actor_module_fsdp.state_dict()
+ from verl.utils.model import convert_weight_keys
+
+ params = convert_weight_keys(
+ params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
+ )
+ return params
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+ def sync_rollout_weights(self):
+ assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
+ assert hasattr(self, "_weights_info") and self._weights_info is not None
+
+ params = self._get_actor_params() if self._is_actor else None
+ if self._is_rollout:
+ inference_model = (
+ self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+ )
+ patch_vllm_moe_model_weight_loader(inference_model)
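+        # Stream weights tensor by tensor: every worker allocates an empty buffer from the
+        # (name, shape, dtype) metadata, actor rank 0 fills it with the full parameter, the buffer
+        # is broadcast over the shared "actor_rollout" collective group, and rollout ranks load it
+        # into the vLLM engine.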
+ for key, shape, dtype in self._weights_info:
+ tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
+ if self._is_actor:
+ assert key in params
+ origin_data = params[key]
+ if hasattr(origin_data, "full_tensor"):
+ origin_data = origin_data.full_tensor()
+ if torch.distributed.get_rank() == 0:
+ tensor.copy_(origin_data)
+ from ray.util.collective import collective
+
+ collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
+ if self._is_rollout:
+ inference_model.load_weights([(key, tensor)])
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def get_actor_weights_info(self):
+ assert self._is_actor
+ if hasattr(self, "_weights_info"):
+ return self._weights_info
+ if fsdp_version(self.actor_module_fsdp) == 1:
+ from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
+
+ FSDP.set_state_dict_type(
+ self.actor_module_fsdp,
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
+ state_dict_config=ShardedStateDictConfig(),
+ )
+ params = self._get_actor_params()
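+        # Record (name, shape, dtype) for every actor parameter; rollout workers use this
+        # metadata to pre-allocate receive buffers in sync_rollout_weights.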
+ ret = []
+ for key, tensor in params.items():
+ ret.append((key, tensor.size(), tensor.dtype))
+ self._weights_info = ret
+ return ret
+
+
+class RolloutWorker(ActorRolloutRefWorker):
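+    """Dedicated rollout worker that only hosts the vLLM inference engine.
+
+    It owns no trainable parameters: actor weights arrive over the "actor_rollout" collective
+    group via set_actor_weights_info / sync_rollout_weights.
+    """
+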
+ def __init__(self, config: DictConfig, role: str):
+ Worker.__init__(self)
+ assert role == "rollout"
+ self.config = config
+ import torch.distributed
+
+ if not torch.distributed.is_initialized():
+ rank = int(os.environ.get("RANK", 0))
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ torch.distributed.init_process_group(
+ backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
+ rank=rank,
+ world_size=world_size,
+ init_method=os.environ.get("DIST_INIT_METHOD", None),
+ )
+ # TODO(haibin.lin):
+        # As of now the type of config is DictConfig; if we assign a ProfilerConfig to config.profiler,
+        # it will actually be converted from the ProfilerConfig dataclass back to a DictConfig.
+        # We can still use ProfilerConfig for testing purposes (tests/utils/test_nvtx_profile.py)
+        # as it provides a DictConfig-like interface.
+        # The benefit of creating the dataclass config is to perform validation during __post_init__.
+ profiler_config = omega_conf_to_dataclass(config.rollout.get("profiler", {}))
+ DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config))
+ self._is_rollout = True
+ self._is_actor = False
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def init_model(self):
+ # This is used to import external_lib into the huggingface systems
+ import_external_libs(self.config.model.get("external_lib", None))
+ override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+
+ use_shm = self.config.model.get("use_shm", False)
+ local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
+ trust_remote_code = self.config.model.get("trust_remote_code", False)
+
+ self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+ self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
+
+ if self.config.model.get("custom_chat_template", None) is not None:
+ if self.processor is not None:
+ self.processor.chat_template = self.config.model.custom_chat_template
+ else:
+ self.tokenizer.chat_template = self.config.model.custom_chat_template
+
+ # override model kwargs
+ actor_model_config = AutoConfig.from_pretrained(
+ local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
+ )
+
+ # patch for kimi-vl
+ if getattr(actor_model_config, "model_type", None) == "kimi_vl":
+ actor_model_config.text_config.topk_method = "greedy"
+
+ self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
+
+ override_config_kwargs = {
+ "bos_token_id": self.tokenizer.bos_token_id,
+ "eos_token_id": self.tokenizer.eos_token_id,
+ "pad_token_id": self.tokenizer.pad_token_id,
+ }
+ override_config_kwargs.update(override_model_config)
+ update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
+ if self.rank == 0:
+ print(f"Model config after override: {actor_model_config}")
+
+ infer_tp = self.config.rollout.tensor_model_parallel_size
+ dp = self.world_size // infer_tp
+ assert self.world_size % infer_tp == 0, (
+ f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
+ )
+ rollout_device_mesh = init_device_mesh(
+ device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
+ )
+ rollout_name = self.config.rollout.name
+ assert rollout_name == "vllm"
+
+ from verl.workers.rollout.vllm_rollout import vLLMRollout
+
+ log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
+
+ from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
+
+ vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
+ rollout = vllm_rollout_cls(
+ model_path=local_path,
+ config=self.config.rollout,
+ tokenizer=self.tokenizer,
+ model_hf_config=actor_model_config,
+ device_mesh=rollout_device_mesh,
+ trust_remote_code=trust_remote_code,
+ )
+ log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
+ from .vllm_sharding_manager import VLLMShardingManager
+
+ rollout_sharding_manager = VLLMShardingManager(
+ inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
+ )
+
+ log_gpu_memory_usage("After building sharding manager", logger=logger)
+
+ self.rollout = rollout
+ self.rollout_sharding_manager = rollout_sharding_manager
+
+ @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
+ def async_generate_sequences(self, *args, **kwargs):
+ return super().generate_sequences(*args, **kwargs)
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def set_actor_weights_info(self, weights_info):
+ assert self._is_rollout
+ self._weights_info = weights_info
+
+
+class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError
diff --git a/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh b/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh
new file mode 100644
index 00000000000..09048fd0340
--- /dev/null
+++ b/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh
@@ -0,0 +1,65 @@
+set -x
+
+project_name='GRPO'
+exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'
+
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
+
+NNODES=${NNODES:-1}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+
+n_gpus_rollout=2
+n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
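+# e.g. with the default 8 GPUs per node: 2 GPUs serve vLLM rollout, the remaining 6 run FSDP2 training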
+
+
+python3 -m recipe.one_step_off_policy.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files="${TRAIN_FILE}" \
+ data.val_files="${TEST_FILE}" \
+ data.train_batch_size=1152 \
+ data.max_prompt_length=512 \
+ data.max_response_length=1024 \
+ data.filter_overlong_prompts=True \
+ data.truncation='error' \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ critic.strategy=fsdp2 \
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.hybrid_engine=False \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=192 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.rollout.load_format=safetensors \
+ actor_rollout_ref.rollout.layered_summon=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.val_before_train=True \
+ trainer.logger=['console','tensorboard'] \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.save_freq=-1 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=2 \
+ trainer.nnodes="${NNODES}" \
+ trainer.n_gpus_per_node="${n_gpus_training}" \
+ rollout.nnodes="${NNODES}" \
+    rollout.n_gpus_per_node="${n_gpus_rollout}" "$@"
\ No newline at end of file
diff --git a/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh b/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh
new file mode 100644
index 00000000000..a0d3bdb8ce8
--- /dev/null
+++ b/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh
@@ -0,0 +1,64 @@
+set -x
+
+project_name='GRPO'
+exp_name='GRPO-Qwen2.5-3b-gsm8k-fsdp2-one-step-off-2-6'
+
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen/Qwen2.5-3B-Instruct"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
+
+NNODES=${NNODES:-1}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+
+n_gpus_rollout=2
+n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
+
+python3 -m recipe.one_step_off_policy.main_ppo \
+ algorithm.adv_estimator=grpo \
+ data.train_files="${TRAIN_FILE}" \
+ data.val_files="${TEST_FILE}" \
+ data.train_batch_size=1152 \
+ data.max_prompt_length=512 \
+ data.max_response_length=1024 \
+ data.filter_overlong_prompts=True \
+ data.truncation='error' \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ critic.strategy=fsdp2 \
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.hybrid_engine=False \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=192 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.rollout.load_format=safetensors \
+ actor_rollout_ref.rollout.layered_summon=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.val_before_train=True \
+ trainer.logger=['console','tensorboard'] \
+ trainer.project_name="${project_name}" \
+ trainer.experiment_name="${exp_name}" \
+ trainer.save_freq=-1 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=2 \
+ trainer.nnodes="${NNODES}" \
+ trainer.n_gpus_per_node="${n_gpus_training}" \
+ rollout.nnodes="${NNODES}" \
+    rollout.n_gpus_per_node="${n_gpus_rollout}" "$@"
\ No newline at end of file
diff --git a/recipe/one_step_off_policy/main_ppo.py b/recipe/one_step_off_policy/main_ppo.py
new file mode 100644
index 00000000000..cd7d1468da7
--- /dev/null
+++ b/recipe/one_step_off_policy/main_ppo.py
@@ -0,0 +1,228 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Note that we don't combine the main with ray_trainer, as ray_trainer is also used by other main entrypoints.
+"""
+
+import os
+import socket
+
+import hydra
+import ray
+from omegaconf import OmegaConf
+
+from verl.trainer.constants_ppo import PPO_RAY_RUNTIME_ENV
+from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
+from verl.trainer.ppo.reward import load_reward_manager
+
+from .ray_trainer import OneStepOffRayTrainer
+
+
+@hydra.main(config_path="config", config_name="one_step_off_ppo_trainer", version_base=None)
+def main(config):
+ run_ppo(config)
+
+
+# Define a function to run the PPO-like training process
+def run_ppo(config) -> None:
+ # Check if Ray is not initialized
+ if not ray.is_initialized():
+ # Initialize Ray with a local cluster configuration
+ # Set environment variables in the runtime environment to control tokenizer parallelism,
+ # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
+ # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
+ ray.init(
+ runtime_env=PPO_RAY_RUNTIME_ENV,
+ num_cpus=config.ray_init.num_cpus,
+ )
+
+    # Create a remote instance of the TaskRunner class, then execute its `run` method
+    # remotely and wait for it to complete
+ if (
+ OmegaConf.select(config.trainer, "profile_steps") is not None
+ and len(OmegaConf.select(config.trainer, "profile_steps")) > 0
+ ):
+ nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
+ runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
+ else:
+ runner = TaskRunner.remote()
+ ray.get(runner.run.remote(config))
+
+ # [Optional] get the path of the timeline trace file from the configuration, default to None
+ # This file is used for performance analysis
+ timeline_json_file = config.ray_init.get("timeline_json_file", None)
+ if timeline_json_file:
+ ray.timeline(filename=timeline_json_file)
+
+
+@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
+class TaskRunner:
+ def run(self, config):
+ # Print the initial configuration. `resolve=True` will evaluate symbolic values.
+ from pprint import pprint
+
+ from omegaconf import OmegaConf
+
+ from verl.utils.fs import copy_to_local
+
+ print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
+
+ pprint(OmegaConf.to_container(config, resolve=True))
+
+ OmegaConf.resolve(config)
+
+ # Download the checkpoint from HDFS to the local machine.
+ # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
+ local_path = copy_to_local(
+ config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
+ )
+
+ # Instantiate the tokenizer and processor.
+ from verl.utils import hf_processor, hf_tokenizer
+
+ trust_remote_code = config.data.get("trust_remote_code", False)
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+ # Used for multimodal LLM, could be None
+ processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+ # Define worker classes based on the actor strategy.
+ if config.actor_rollout_ref.actor.strategy == "fsdp2":
+ assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+ from verl.single_controller.ray import RayWorkerGroup
+
+ from .fsdp_workers import (
+ ActorRolloutRefWorker,
+ AsyncActorRolloutRefWorker,
+ CriticWorker,
+ RolloutWorker,
+ )
+
+ actor_rollout_cls = (
+ AsyncActorRolloutRefWorker
+ if config.actor_rollout_ref.rollout.mode == "async"
+ else ActorRolloutRefWorker
+ )
+ ray_worker_group_cls = RayWorkerGroup
+
+ elif config.actor_rollout_ref.actor.strategy == "megatron":
+ assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
+ from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
+
+ from .megatron_workers import (
+ ActorRolloutRefWorker,
+ AsyncActorRolloutRefWorker,
+ CriticWorker,
+ RolloutWorker,
+ )
+
+ actor_rollout_cls = (
+ AsyncActorRolloutRefWorker
+ if config.actor_rollout_ref.rollout.mode == "async"
+ else ActorRolloutRefWorker
+ )
+ ray_worker_group_cls = NVMegatronRayWorkerGroup
+
+ else:
+ raise NotImplementedError
+
+ from .ray_trainer import ResourcePoolManager, Role
+
+ role_worker_mapping = {
+ Role.Actor: ray.remote(actor_rollout_cls),
+ Role.Rollout: ray.remote(RolloutWorker),
+ Role.Critic: ray.remote(CriticWorker),
+ }
+
+ global_pool_id = "actor_pool"
+
+ assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0"
+ assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0"
+ assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
+ assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"
+
+ actor_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes
+ rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
+
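+        # Two disjoint resource pools: the actor (plus critic/ref/reward model) trains on
+        # "actor_pool", while the rollout workers generate on a separate "rollout_pool",
+        # allowing generation to overlap with training by one step.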
+ resource_pool_spec = {
+ "actor_pool": actor_pool,
+ "rollout_pool": rollout_pool,
+ }
+ mapping = {
+ Role.Actor: "actor_pool",
+ Role.Rollout: "rollout_pool",
+ Role.Critic: "actor_pool",
+ }
+ print(f"resource_pool_spec: {resource_pool_spec}")
+ # We should adopt a multi-source reward function here:
+ # - for rule-based rm, we directly call a reward score
+ # - for model-based rm, we call a model
+ # - for code related prompt, we send to a sandbox if there are test cases
+ # finally, we combine all the rewards together
+ # The reward type depends on the tag of the data
+ if config.reward_model.enable:
+ if config.reward_model.strategy in ["fsdp2"]:
+ from verl.workers.fsdp_workers import RewardModelWorker
+ elif config.reward_model.strategy == "megatron":
+ from verl.workers.megatron_workers import RewardModelWorker
+ else:
+ raise NotImplementedError
+ role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+ mapping[Role.RewardModel] = global_pool_id
+
+ # Add a reference policy worker if KL loss or KL reward is used.
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+ role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
+ mapping[Role.RefPolicy] = global_pool_id
+
+ # Load the reward manager for training and validation.
+ reward_fn = load_reward_manager(
+ config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+ )
+ val_reward_fn = load_reward_manager(
+ config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
+ )
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+ from verl.utils.dataset.rl_dataset import collate_fn
+
+ # Create training and validation datasets.
+ train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
+ val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+ train_sampler = create_rl_sampler(config.data, train_dataset)
+
+ # Initialize the PPO trainer.
+ trainer = OneStepOffRayTrainer(
+ config=config,
+ tokenizer=tokenizer,
+ processor=processor,
+ role_worker_mapping=role_worker_mapping,
+ resource_pool_manager=resource_pool_manager,
+ ray_worker_group_cls=ray_worker_group_cls,
+ reward_fn=reward_fn,
+ val_reward_fn=val_reward_fn,
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ collate_fn=collate_fn,
+ train_sampler=train_sampler,
+ device_name=config.trainer.device,
+ )
+ # Initialize the workers of the trainer.
+ trainer.init_workers()
+ # Start the training process.
+ trainer.fit()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/recipe/one_step_off_policy/megatron_workers.py b/recipe/one_step_off_policy/megatron_workers.py
new file mode 100644
index 00000000000..f81fbd94074
--- /dev/null
+++ b/recipe/one_step_off_policy/megatron_workers.py
@@ -0,0 +1,201 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+import torch
+import torch.distributed
+from omegaconf import DictConfig
+
+from verl.single_controller.base.decorator import Dispatch, register
+from verl.utils.debug import (
+ log_gpu_memory_usage,
+)
+from verl.utils.device import get_device_name, get_torch_device
+from verl.utils.fs import copy_to_local
+from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
+from verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker
+from verl.workers.megatron_workers import CriticWorker, RewardModelWorker
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+__all__ = ["ActorRolloutRefWorker", "AsyncActorRolloutRefWorker", "CriticWorker", "RewardModelWorker", "RolloutWorker"]
+
+
+class ActorRolloutRefWorker(ARRWorker):
+ def __init__(self, config: DictConfig, role: str):
+ assert role in ["actor", "ref"]
+ tmp_role = "ref" if role == "ref" else "actor_rollout"
+ super().__init__(config, tmp_role)
+ if role == "actor":
+ self._is_rollout = False
+ self.role = role
+
+ def _get_actor_params_generator(self):
+ assert self._is_actor
+ from verl.models.mcore import get_mcore_weight_converter
+ from verl.utils.megatron_utils import per_tensor_generator
+
+ layer_name_mapping = {
+ "qkv_layer_name": "self_attention.linear_qkv.",
+ "gate_proj_layer_name": "linear_fc1.",
+ }
+ weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
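+        # per_tensor_generator streams the actor parameters one tensor at a time, converting them
+        # from the Megatron-Core layout into the parameter names/shapes expected by vLLM's load_weights.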
+ generator = per_tensor_generator(
+ self.actor.actor_module,
+ self.actor_model_config,
+ weight_converter,
+ self.tf_config,
+ layer_name_mapping,
+ )
+ return generator
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+ def sync_rollout_weights(self):
+ assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
+ assert hasattr(self, "_weights_info") and self._weights_info is not None
+
+ params_generator = self._get_actor_params_generator() if self._is_actor else None
+ if self._is_rollout:
+ inference_model = (
+ self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+ )
+ patch_vllm_moe_model_weight_loader(inference_model)
+ for key, shape, dtype in self._weights_info:
+ if self._is_actor:
+ weight_key, weight = next(params_generator)
+ assert key == weight_key
+ assert shape == weight.size()
+ assert dtype == weight.dtype
+
+ tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
+ if self._is_actor and torch.distributed.get_rank() == 0:
+ tensor.copy_(weight)
+ from ray.util.collective import collective
+
+ collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
+ if self._is_rollout:
+ inference_model.load_weights([(key, tensor)])
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def get_actor_weights_info(self):
+ assert self._is_actor
+ if hasattr(self, "_weights_info"):
+ return self._weights_info
+
+ params_generator = self._get_actor_params_generator()
+ ret = []
+ for key, tensor in params_generator:
+ ret.append((key, tensor.size(), tensor.dtype))
+
+ self._weights_info = ret
+ return ret
+
+
+class RolloutWorker(ActorRolloutRefWorker):
+ def __init__(self, config: DictConfig, role: str):
+ assert role == "rollout"
+ ARRWorker.__init__(self, config, role)
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def init_model(self):
+ if self.config.model.get("external_lib", None) is not None:
+ # This is used to import external_lib into the huggingface systems
+ import importlib
+
+ importlib.import_module(self.config.model.external_lib)
+
+ from omegaconf import OmegaConf
+
+ from verl.utils.torch_dtypes import PrecisionType
+
+ override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+ override_transformer_config = {}
+ self.param_dtype = torch.bfloat16
+ self.dtype = PrecisionType.to_dtype(self.param_dtype)
+ trust_remote_code = self.config.model.get("trust_remote_code", False)
+
+ from verl.utils.model import get_generation_config
+
+ self._init_hf_config_and_tf_config(
+ self.config.model.path,
+ self.config.model.path,
+ self.dtype,
+ override_model_config,
+ override_transformer_config,
+ trust_remote_code,
+ )
+ self.generation_config = get_generation_config(self.local_path)
+
+ from torch.distributed.device_mesh import init_device_mesh
+
+ assert self.config.rollout.name == "vllm"
+ assert self.config.rollout.mode == "sync"
+
+ from verl.workers.rollout.vllm_rollout import vLLMRollout
+
+ from .vllm_sharding_manager import VLLMShardingManager
+
+        # NOTE(sgm): If the QKV and gate_up projection layers are concatenated in the actor,
+        # we reorganize their weight format when resharding from actor to rollout.
+
+ infer_tp = self.config.rollout.tensor_model_parallel_size
+ dp = self.world_size // infer_tp
+ assert self.world_size % infer_tp == 0, (
+ f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
+ )
+ rollout_device_mesh = init_device_mesh(
+ get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
+ )
+ log_gpu_memory_usage("Before building vllm rollout", logger=None)
+
+ local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
+ from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
+
+ vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
+ rollout = vllm_rollout_cls(
+ model_path=local_path,
+ config=self.config.rollout,
+ tokenizer=self.tokenizer,
+ model_hf_config=self.hf_config,
+ device_mesh=rollout_device_mesh,
+ trust_remote_code=trust_remote_code,
+ )
+ log_gpu_memory_usage("After building vllm rollout", logger=logger)
+
+ sharding_manager = VLLMShardingManager(
+ inference_engine=rollout.inference_engine,
+ device_mesh=rollout_device_mesh,
+ )
+ log_gpu_memory_usage("After building sharding manager", logger=logger)
+
+ self.rollout, self.sharding_manager = rollout, sharding_manager
+ self.rollout.sharding_manager = sharding_manager
+
+ @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
+ def async_generate_sequences(self, *args, **kwargs):
+ return super().generate_sequences(*args, **kwargs)
+
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+ def set_actor_weights_info(self, weights_info):
+ assert self._is_rollout
+ self._weights_info = weights_info
+
+
+class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError
diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py
new file mode 100644
index 00000000000..1f7011bdf54
--- /dev/null
+++ b/recipe/one_step_off_policy/ray_trainer.py
@@ -0,0 +1,624 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This trainer supports model-agnostic model initialization with huggingface
+"""
+
+import uuid
+from pprint import pprint
+
+import numpy as np
+import ray
+import torch
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset, Sampler
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
+from verl.single_controller.ray.base import create_colocated_worker_cls
+from verl.trainer.ppo import core_algos
+from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
+from verl.trainer.ppo.metric_utils import (
+ compute_data_metrics,
+ compute_throughout_metrics,
+ compute_timing_metrics,
+)
+from verl.trainer.ppo.ray_trainer import (
+ RayPPOTrainer,
+ ResourcePoolManager,
+ Role,
+ WorkerType,
+ apply_kl_penalty,
+ compute_advantage,
+ compute_response_mask,
+)
+from verl.trainer.ppo.reward import compute_reward, compute_reward_async
+from verl.utils.debug import marked_timer
+from verl.utils.metric import (
+ reduce_metrics,
+)
+from verl.utils.tracking import ValidationGenerationsLogger
+
+
+class GenerationBatchFuture:
+ """
+ Wrapper class for encapsulating batch generation results
+ """
+
+ def __init__(self, epoch, batch, gen_batch_output):
+ """
+ :param epoch: current epoch
+ :param batch: Input batch data
+ :param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
+ """
+ self.epoch = epoch
+ self.batch = batch
+ self.gen_batch_output = gen_batch_output
+
+ def get(self):
+ """
+ Get the actual results by calling get() method on gen_batch_output
+
+ Returns:
+ tuple: (batch, gen_batch_result)
+ - batch: Original input batch data
+ - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
+ """
+ # Call get() method on gen_batch_output if available
+ if hasattr(self.gen_batch_output, "get"):
+ gen_batch_result = self.gen_batch_output.get()
+ else:
+ gen_batch_result = self.gen_batch_output
+
+ return self.epoch, self.batch, gen_batch_result
+
+
+class OneStepOffRayTrainer(RayPPOTrainer):
+    # TODO: support an individual ray_worker_group_cls for each role,
+    # i.e., support a different backend per role
+ def __init__(
+ self,
+ config,
+ tokenizer,
+ role_worker_mapping: dict[Role, WorkerType],
+ resource_pool_manager: ResourcePoolManager,
+ ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
+ processor=None,
+ reward_fn=None,
+ val_reward_fn=None,
+ train_dataset: Dataset | None = None,
+ val_dataset: Dataset | None = None,
+ collate_fn=None,
+ train_sampler: Sampler | None = None,
+ device_name="cuda",
+ ):
+ """
+ Initialize distributed PPO trainer with Ray backend.
+ Note that this trainer runs on the driver process on a single CPU/GPU node.
+
+ Args:
+ config: Configuration object containing training parameters.
+ tokenizer: Tokenizer used for encoding and decoding text.
+ role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.
+ resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.
+ ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.
+ processor: Optional data processor, used for multimodal data
+ reward_fn: Function for computing rewards during training.
+ val_reward_fn: Function for computing rewards during validation.
+ train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.
+ val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
+ collate_fn: Function to collate data samples into batches.
+ train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
+ device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda".
+ """
+
+ # Store the tokenizer for text processing
+ self.tokenizer = tokenizer
+ self.processor = processor
+ self.config = config
+ self.reward_fn = reward_fn
+ self.val_reward_fn = val_reward_fn
+
+ self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
+
+ assert not self.hybrid_engine
+
+ self.role_worker_mapping = role_worker_mapping
+ self.resource_pool_manager = resource_pool_manager
+ self.use_reference_policy = Role.RefPolicy in role_worker_mapping
+ self.use_rm = Role.RewardModel in role_worker_mapping
+ self.ray_worker_group_cls = ray_worker_group_cls
+ self.device_name = device_name
+ self.validation_generations_logger = ValidationGenerationsLogger()
+
+ # if ref_in_actor is True, the reference policy will be actor without lora applied
+ self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
+
+ # define in-reward KL control
+        # KL loss control is currently not supported
+ if config.algorithm.use_kl_in_reward:
+ self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
+
+ if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
+ self.use_critic = True
+ elif self.config.algorithm.adv_estimator in [
+ AdvantageEstimator.GRPO,
+ AdvantageEstimator.GRPO_PASSK,
+ AdvantageEstimator.REINFORCE_PLUS_PLUS,
+ # AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy
+ AdvantageEstimator.RLOO,
+ AdvantageEstimator.OPO,
+ AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
+ AdvantageEstimator.GPG,
+ ]:
+ self.use_critic = False
+ else:
+ raise NotImplementedError
+
+ self._validate_config()
+ self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
+
+ def _validate(self):
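+        # Validation generation must run on the dedicated rollout workers, so temporarily point
+        # actor_rollout_wg at rollout_wg for the parent implementation and restore it afterwards.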
+ self.actor_rollout_wg = self.rollout_wg
+ ret = super()._validate()
+ self.actor_rollout_wg = self.actor_wg
+ return ret
+
+ def init_workers(self):
+ """Initialize distributed training workers using Ray backend.
+
+ Creates:
+ 1. Ray resource pools from configuration
+ 2. Worker groups for each role (actor, critic, etc.)
+ """
+ self.resource_pool_manager.create_resource_pool()
+
+ self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
+
+ # create actor and rollout
+ for role, role_name in [(Role.Actor, "actor"), (Role.Rollout, "rollout")]:
+ resource_pool = self.resource_pool_manager.get_resource_pool(role)
+ role_cls = RayClassWithInitArgs(
+ cls=self.role_worker_mapping[role],
+ config=self.config.actor_rollout_ref,
+ role=role_name,
+ )
+ self.resource_pool_to_cls[resource_pool][role_name] = role_cls
+
+ # create critic
+ if self.use_critic:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
+ critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
+ self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
+
+ # create reference policy if needed
+ if self.use_reference_policy:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
+ ref_policy_cls = RayClassWithInitArgs(
+ self.role_worker_mapping[Role.RefPolicy],
+ config=self.config.actor_rollout_ref,
+ role="ref",
+ profile_option=self.config.trainer.npu_profile.options,
+ )
+ self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
+
+ # create a reward model if reward_fn is None
+ if self.use_rm:
+ # we create a RM here
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
+ rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
+ self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
+
+ # initialize WorkerGroup
+ # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
+ # you should not use `create_colocated_worker_cls`.
+ # Instead, directly pass different resource pool to different worker groups.
+ # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
+ all_wg = {}
+ wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
+ if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
+ wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
+ if OmegaConf.select(self.config.trainer, "profile_steps") is not None:
+ wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps")
+            assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, (
+                "worker_nsight_options must be set when profile_steps is set"
+            )
+            wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
+                OmegaConf.select(self.config.trainer, "worker_nsight_options")
+            )
+
+ for resource_pool, class_dict in self.resource_pool_to_cls.items():
+ worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
+ wg_dict = self.ray_worker_group_cls(
+ resource_pool=resource_pool,
+ ray_cls_with_init=worker_dict_cls,
+ device_name=self.device_name,
+ **wg_kwargs,
+ )
+ spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
+ all_wg.update(spawn_wg)
+
+ if self.use_critic:
+ self.critic_wg = all_wg["critic"]
+ self.critic_wg.init_model()
+
+ if self.use_reference_policy and not self.ref_in_actor:
+ self.ref_policy_wg = all_wg["ref"]
+ self.ref_policy_wg.init_model()
+
+ if self.use_rm:
+ self.rm_wg = all_wg["rm"]
+ self.rm_wg.init_model()
+
+ self.actor_wg = all_wg["actor"]
+ self.rollout_wg = all_wg["rollout"]
+ self.actor_wg.init_model()
+ self.rollout_wg.init_model()
+ self.actor_rollout_wg = self.actor_wg # to be compatible with the functions that not be modified
+ weights_info = self.actor_wg.get_actor_weights_info()[0]
+ self.rollout_wg.set_actor_weights_info(weights_info)
+ from ray.util.collective import collective
+
+ actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
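+        # Build one collective group spanning actor + rollout workers; actor workers occupy the
+        # leading ranks, so rank 0 (an actor rank) serves as the broadcast source in
+        # sync_rollout_weights.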
+ collective.create_collective_group(
+ actor_rollout_workers,
+ len(actor_rollout_workers),
+ list(range(0, len(actor_rollout_workers))),
+ backend="nccl",
+ group_name="actor_rollout",
+ )
+ self.sync_rollout_weights()
+
+ # create async rollout manager and request scheduler
+ self.async_rollout_mode = False
+ if self.config.actor_rollout_ref.rollout.mode == "async" and self._is_rollout:
+ from verl.workers.rollout.async_server import AsyncLLMServerManager
+
+ self.async_rollout_mode = True
+ self.async_rollout_manager = AsyncLLMServerManager(
+ config=self.config,
+ worker_group=self.rollout_wg,
+ )
+
+ def sync_rollout_weights(self):
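+        # Both worker-side calls are registered as non-blocking; we only wait on the rollout side
+        # so that generation starts only after the new weights are fully loaded.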
+ if not self.hybrid_engine:
+ self.actor_wg.sync_rollout_weights()
+ ray.get(self.rollout_wg.sync_rollout_weights())
+
+ def _create_continuous_iterator(self):
+ """
+        Create a continuous data iterator across epochs.
+ """
+ for epoch in range(self.config.trainer.total_epochs):
+ iterator = iter(self.train_dataloader)
+ for batch_dict in iterator:
+ yield epoch, batch_dict
+
+ def _async_gen_next_batch(self, continuous_iterator):
+ """
+        Sync the latest actor weights to the rollout workers, then launch asynchronous sequence
+        generation and return immediately, wrapping the pending result in a GenerationBatchFuture.
+ """
+ try:
+ epoch, batch_dict = next(continuous_iterator)
+ except StopIteration:
+ return None
+ except Exception as e:
+            print(f"Error in _async_gen_next_batch: {e}")
+ return None
+ batch = DataProto.from_single_dict(batch_dict)
+ # pop those keys for generation
+ batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
+ non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
+ if "multi_modal_data" in batch.non_tensor_batch:
+ non_tensor_batch_keys_to_pop.append("multi_modal_data")
+ if "raw_prompt" in batch.non_tensor_batch:
+ non_tensor_batch_keys_to_pop.append("raw_prompt")
+ if "tools_kwargs" in batch.non_tensor_batch:
+ non_tensor_batch_keys_to_pop.append("tools_kwargs")
+ if "interaction_kwargs" in batch.non_tensor_batch:
+ non_tensor_batch_keys_to_pop.append("interaction_kwargs")
+ gen_batch = batch.pop(
+ batch_keys=batch_keys_to_pop,
+ non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
+ )
+ gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+ # sync weights from actor to rollout
+ self.sync_rollout_weights()
+ # async generation
+ gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
+ return GenerationBatchFuture(epoch, batch, gen_batch_output)
+
+ def fit(self):
+ """
+ The training loop of PPO.
+        The driver process only needs to call the compute functions of the worker group through RPC
+        to construct the PPO dataflow.
+        The lightweight advantage computation is done on the driver process.
+ """
+ from omegaconf import OmegaConf
+
+ from verl.utils.tracking import Tracking
+
+ logger = Tracking(
+ project_name=self.config.trainer.project_name,
+ experiment_name=self.config.trainer.experiment_name,
+ default_backend=self.config.trainer.logger,
+ config=OmegaConf.to_container(self.config, resolve=True),
+ )
+
+ self.global_steps = 0
+
+ # load checkpoint before doing anything
+ self._load_checkpoint()
+
+ # perform validation before training
+ # currently, we only support validation using the reward_function.
+ if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+ val_metrics = self._validate()
+ assert val_metrics, f"{val_metrics=}"
+ pprint(f"Initial validation metrics: {val_metrics}")
+ logger.log(data=val_metrics, step=self.global_steps)
+ if self.config.trainer.get("val_only", False):
+ return
+
+ # add tqdm
+ progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+
+ # we start from step 1
+ self.global_steps += 1
+ last_val_metrics = None
+
+ # across epoch iterator
+ continuous_iterator = self._create_continuous_iterator()
+
+ # Start the first asynchronous generation task.
+ batch_data_future = self._async_gen_next_batch(continuous_iterator)
+
+ while batch_data_future is not None:
+ do_profile = (
+ self.global_steps in self.config.trainer.profile_steps
+ if self.config.trainer.profile_steps is not None
+ else False
+ )
+ if do_profile:
+ self.actor_wg.start_profile()
+ if not self.hybrid_engine:
+ self.rollout_wg.start_profile()
+ if self.use_reference_policy:
+ self.ref_policy_wg.start_profile()
+ if self.use_critic:
+ self.critic_wg.start_profile()
+ if self.use_rm:
+ self.rm_wg.start_profile()
+
+ metrics = {}
+ timing_raw = {}
+ is_last_step = self.global_steps >= self.total_training_steps
+
+ with marked_timer("step", timing_raw):
+ # wait for the previous batch
+ with marked_timer("wait_prev_gen", timing_raw, color="red"):
+ epoch, batch, gen_batch_output = batch_data_future.get()
+ timing_raw.update(gen_batch_output.meta_info["timing"])
+ gen_batch_output.meta_info.pop("timing", None)
+
+                # async next generation (sync weights from actor to rollout first)
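+                # This is the one-step-off overlap: the next batch is generated with the current
+                # (pre-update) weights while the actor trains on the batch just received above.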
+ with marked_timer("sync_rollout_weights", timing_raw, color="purple"):
+ if not is_last_step:
+ batch_data_future = self._async_gen_next_batch(continuous_iterator)
+
+ batch.non_tensor_batch["uid"] = np.array(
+ [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
+ )
+ # repeat to align with repeated responses in rollout
+ batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+ batch = batch.union(gen_batch_output)
+
+ batch.batch["response_mask"] = compute_response_mask(batch)
+ # Balance the number of valid tokens across DP ranks.
+ # NOTE: This usually changes the order of data in the `batch`,
+ # which won't affect the advantage calculation (since it's based on uid),
+ # but might affect the loss calculation (due to the change of mini-batching).
+ # TODO: Decouple the DP balancing and mini-batching.
+ if self.config.trainer.balance_batch:
+ self._balance_batch(batch, metrics=metrics)
+
+ # compute global_valid tokens
+ batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
+
+ with marked_timer("reward", timing_raw, color="yellow"):
+ # compute reward model score
+ if self.use_rm:
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
+ batch = batch.union(reward_tensor)
+
+ if self.config.reward_model.launch_reward_fn_async:
+ future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+ else:
+ reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
+
+ # recompute old_log_probs
+ with marked_timer("old_log_prob", timing_raw, color="blue"):
+ old_log_prob = self.actor_wg.compute_log_prob(batch)
+ entropys = old_log_prob.batch["entropys"]
+ response_masks = batch.batch["response_mask"]
+ loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+ entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+ old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+ metrics.update(old_log_prob_metrics)
+ old_log_prob.batch.pop("entropys")
+ batch = batch.union(old_log_prob)
+
+ if "rollout_log_probs" in batch.batch.keys():
+ # TODO: we may want to add diff of probs too.
+ rollout_old_log_probs = batch.batch["rollout_log_probs"]
+ actor_old_log_probs = batch.batch["old_log_probs"]
+ attention_mask = batch.batch["attention_mask"]
+ responses = batch.batch["responses"]
+ response_length = responses.size(1)
+ response_mask = attention_mask[:, -response_length:]
+
+ rollout_probs = torch.exp(rollout_old_log_probs)
+ actor_probs = torch.exp(actor_old_log_probs)
+ rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
+ rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
+ rollout_probs_diff_max = torch.max(rollout_probs_diff)
+ rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
+ rollout_probs_diff_std = torch.std(rollout_probs_diff)
+ metrics.update(
+ {
+ "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
+ "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
+ "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
+ }
+ )
+
+ if self.use_reference_policy:
+ # compute reference log_prob
+ with marked_timer("ref", timing_raw, color="olive"):
+ if not self.ref_in_actor:
+ ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+ else:
+ ref_log_prob = self.actor_wg.compute_ref_log_prob(batch)
+ batch = batch.union(ref_log_prob)
+
+ # compute values
+ if self.use_critic:
+ with marked_timer("values", timing_raw, color="cyan"):
+ values = self.critic_wg.compute_values(batch)
+ batch = batch.union(values)
+
+ with marked_timer("adv", timing_raw, color="brown"):
+ # we combine with rule-based rm
+ reward_extra_infos_dict: dict[str, list]
+ if self.config.reward_model.launch_reward_fn_async:
+ reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
+ batch.batch["token_level_scores"] = reward_tensor
+
+ if reward_extra_infos_dict:
+ batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
+
+ # compute rewards. apply_kl_penalty if available
+ if self.config.algorithm.use_kl_in_reward:
+ batch, kl_metrics = apply_kl_penalty(
+ batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
+ )
+ metrics.update(kl_metrics)
+ else:
+ batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
+
+ # compute advantages, executed on the driver process
+
+ norm_adv_by_std_in_grpo = self.config.algorithm.get(
+ "norm_adv_by_std_in_grpo", True
+ ) # GRPO adv normalization factor
+
+ batch = compute_advantage(
+ batch,
+ adv_estimator=self.config.algorithm.adv_estimator,
+ gamma=self.config.algorithm.gamma,
+ lam=self.config.algorithm.lam,
+ num_repeat=self.config.actor_rollout_ref.rollout.n,
+ norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+ config=self.config.algorithm,
+ )
+
+ # update critic
+ if self.use_critic:
+ with marked_timer("update_critic", timing_raw, color="pink"):
+ critic_output = self.critic_wg.update_critic(batch)
+ critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
+ metrics.update(critic_output_metrics)
+
+ # implement critic warmup
+ if self.config.trainer.critic_warmup <= self.global_steps:
+ # update actor
+ with marked_timer("update_actor", timing_raw, color="red"):
+ batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
+ actor_output = self.actor_wg.update_actor(batch)
+ actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
+ metrics.update(actor_output_metrics)
+
+ # Log rollout generations if enabled
+ rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+ if rollout_data_dir:
+ with marked_timer("dump_rollout_generations", timing_raw, color="green"):
+ inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
+ outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
+ scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
+ self._dump_generations(
+ inputs=inputs,
+ outputs=outputs,
+ scores=scores,
+ reward_extra_infos_dict=reward_extra_infos_dict,
+ dump_path=rollout_data_dir,
+ )
+
+ # validate
+ if (
+ self.val_reward_fn is not None
+ and self.config.trainer.test_freq > 0
+ and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
+ ):
+ with marked_timer("testing", timing_raw, color="green"):
+ val_metrics: dict = self._validate()
+ if is_last_step:
+ last_val_metrics = val_metrics
+ metrics.update(val_metrics)
+
+ if self.config.trainer.save_freq > 0 and (
+ is_last_step or self.global_steps % self.config.trainer.save_freq == 0
+ ):
+ with marked_timer("save_checkpoint", timing_raw, color="green"):
+ self._save_checkpoint()
+
+ # training metrics
+ metrics.update(
+ {
+ "training/global_step": self.global_steps,
+ "training/epoch": epoch,
+ }
+ )
+ # collect metrics
+ metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+            # TODO: implement actual tflops and theoretical tflops
+ n_gpus = self.resource_pool_manager.get_n_gpus()
+ metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+
+            # TODO: make a canonical logger that supports various backends
+ logger.log(data=metrics, step=self.global_steps)
+
+ progress_bar.update(1)
+ self.global_steps += 1
+
+ if do_profile:
+ self.actor_wg.stop_profile()
+ if not self.hybrid_engine:
+ self.rollout_wg.stop_profile()
+ if self.use_reference_policy:
+ self.ref_policy_wg.stop_profile()
+ if self.use_critic:
+ self.critic_wg.stop_profile()
+ if self.use_rm:
+ self.rm_wg.stop_profile()
+
+ if is_last_step:
+ pprint(f"Final validation metrics: {last_val_metrics}")
+ progress_bar.close()
+ return
diff --git a/recipe/one_step_off_policy/vllm_sharding_manager.py b/recipe/one_step_off_policy/vllm_sharding_manager.py
new file mode 100644
index 00000000000..c33ba585470
--- /dev/null
+++ b/recipe/one_step_off_policy/vllm_sharding_manager.py
@@ -0,0 +1,74 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+from torch.distributed.device_mesh import DeviceMesh
+
+from verl import DataProto
+from verl.protocol import all_gather_data_proto
+from verl.third_party.vllm import parallel_state as vllm_ps
+from verl.utils.debug import GPUMemoryLogger
+from verl.utils.device import get_torch_device
+from verl.utils.torch_functional import check_device_is_available
+from verl.workers.sharding_manager.base import BaseShardingManager
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+class VLLMShardingManager(BaseShardingManager):
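+    """Minimal sharding manager for the dedicated rollout workers.
+
+    It does not offload or reload model weights (those arrive via sync_rollout_weights);
+    it only manages the generation RNG state and all-gathers / chunks data across the
+    inference TP group.
+    """
+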
+ @check_device_is_available()
+ def __init__(self, inference_engine, device_mesh: DeviceMesh):
+ self.device_mesh = device_mesh
+ self.inference_engine = inference_engine
+ inference_engine.wake_up()
+ assert device_mesh is not None
+ assert inference_engine is not None
+ self.tp_size = self.device_mesh["infer_tp"].size()
+ self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
+ self.timing = {}
+ gen_dp_rank = self.device_mesh["dp"].get_local_rank()
+ get_torch_device().manual_seed(gen_dp_rank + 1000)
+ self.gen_random_states = get_torch_device().get_rng_state()
+
+ @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
+ def __enter__(self):
+ get_torch_device().set_rng_state(self.gen_random_states)
+
+ @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.gen_random_states = get_torch_device().get_rng_state()
+ self.inference_engine.reset_prefix_cache()
+
+ @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
+ def preprocess_data(self, data: DataProto) -> DataProto:
+ """All gather across tp group to make each rank has identical input."""
+ if self.tp_size == 1:
+ return data
+
+ group = vllm_ps.get_tensor_model_parallel_group().device_group
+
+ all_gather_data_proto(data=data, process_group=group)
+ return data
+
+ @GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
+ def postprocess_data(self, data: DataProto) -> DataProto:
+        """Return the chunk of data for this tp rank, since we all-gathered in preprocess."""
+ if self.tp_size == 1:
+ return data
+
+ return data.chunk(chunks=self.tp_size)[self.tp_rank]
diff --git a/tests/special_e2e/run_one_step_off_policy.sh b/tests/special_e2e/run_one_step_off_policy.sh
new file mode 100755
index 00000000000..84fdd1d8113
--- /dev/null
+++ b/tests/special_e2e/run_one_step_off_policy.sh
@@ -0,0 +1,173 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+# Test script for one_step_off_policy E2E regression testing
+# This script runs one_step_off_policy with either the FSDP2 or the Megatron backend
+# (selected via ACTOR_STRATEGY) to ensure the asynchronous training mechanism works correctly
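+#
+# Example invocations (illustrative; adjust NUM_GPUS to the available hardware):
+#   ACTOR_STRATEGY=fsdp2 bash tests/special_e2e/run_one_step_off_policy.sh
+#   NUM_GPUS=8 ACTOR_STRATEGY=megatron bash tests/special_e2e/run_one_step_off_policy.sh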
+
+NUM_GPUS=${NUM_GPUS:-8}
+ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp2"} # fsdp2 or megatron
+
+# Download the model if it is not already present locally
+MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
+MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
+huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}"
+
+# Algorithm parameters
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+# Response length parameters
+max_prompt_length=1024
+max_response_length=2048
+enable_overlong_buffer=True
+overlong_buffer_len=128
+overlong_penalty_factor=1.0
+
+# Training parameters
+loss_agg_mode="token-mean"
+train_prompt_bsz=8
+n_resp_per_prompt=3
+train_prompt_mini_bsz=4
+
+# Temperature parameters
+temperature=1.0
+top_p=1.0
+top_k=-1
+val_top_p=0.7
+
+# One-step-off-policy specific parameters
+# Allocate 2 GPUs for rollout, remaining for training
+n_gpus_rollout=2
+n_gpus_training=$((NUM_GPUS - n_gpus_rollout))
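+# e.g. with the default NUM_GPUS=8 this yields 2 rollout GPUs and 6 training GPUs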
+
+exp_name="$(basename "${MODEL_ID,,}")-one-step-off-policy-${ACTOR_STRATEGY}-minimal"
+
+echo "Running one_step_off_policy with ${ACTOR_STRATEGY} strategy"
+echo "Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}"
+
+# Common parameters for both FSDP2 and Megatron
+common_params=(
+ data.train_files="${HOME}/data/gsm8k/train.parquet"
+ data.val_files="${HOME}/data/gsm8k/test.parquet"
+ data.prompt_key=prompt
+ data.truncation='left'
+ data.max_prompt_length=${max_prompt_length}
+ data.max_response_length=${max_response_length}
+ data.train_batch_size=${train_prompt_bsz}
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt}
+ algorithm.adv_estimator=${adv_estimator}
+ algorithm.use_kl_in_reward=${use_kl_in_reward}
+ algorithm.kl_ctrl.kl_coef=${kl_coef}
+    actor_rollout_ref.hybrid_engine=False
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}
+ actor_rollout_ref.actor.clip_ratio_c=10.0
+ actor_rollout_ref.model.path="${MODEL_PATH}"
+ actor_rollout_ref.model.enable_gradient_checkpointing=True
+ actor_rollout_ref.actor.optim.lr=1e-6
+ actor_rollout_ref.actor.optim.lr_warmup_steps=-1
+ actor_rollout_ref.actor.optim.weight_decay=0.1
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
+ actor_rollout_ref.actor.entropy_coeff=0
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.80
+ actor_rollout_ref.rollout.temperature=${temperature}
+ actor_rollout_ref.rollout.top_p=${top_p}
+ actor_rollout_ref.rollout.top_k=${top_k}
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}
+ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True
+ actor_rollout_ref.rollout.val_kwargs.n=1
+    actor_rollout_ref.rollout.enable_chunked_prefill=True
+ reward_model.reward_manager=dapo
+ +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}
+ +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}
+ +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}
+ +reward_model.reward_kwargs.overlong_buffer_cfg.log=False
+ +reward_model.reward_kwargs.max_resp_len=${max_response_length}
+ trainer.logger=['console']
+ trainer.project_name='verl-test'
+ trainer.experiment_name="${exp_name}"
+ trainer.val_before_train=False
+ trainer.test_freq=-1
+ trainer.save_freq=-1
+ trainer.total_epochs=2
+ trainer.total_training_steps=2
+ trainer.resume_mode=disable
+ trainer.nnodes=1
+ trainer.n_gpus_per_node=${n_gpus_training}
+ rollout.nnodes=1
+ rollout.n_gpus_per_node=${n_gpus_rollout}
+
+)
+
+if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then
+ echo "Running with FSDP2 strategy..."
+ # FSDP2 specific parameters
+ gen_tp=2
+ sp_size=2
+ fsdp_size=2
+ ref_offload=True
+ actor_offload=False
+
+ python3 -m recipe.one_step_off_policy.main_ppo \
+ "${common_params[@]}" \
+ actor_rollout_ref.actor.strategy=fsdp2 \
+ critic.strategy=fsdp2 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.use_dynamic_bsz=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+        actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} "$@"
+
+elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then
+ echo "Running with Megatron strategy..."
+ # Megatron specific parameters
+ gen_tp=2
+ train_tp=1
+ train_pp=2
+ ref_offload=True
+ actor_offload=False
+
+ python3 -m recipe.one_step_off_policy.main_ppo \
+ --config-path=config \
+ --config-name='one_step_off_ppo_megatron_trainer.yaml' \
+ "${common_params[@]}" \
+ actor_rollout_ref.actor.strategy=megatron \
+ critic.strategy=megatron \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
+ actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
+ actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
+ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
+ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
+ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
+        actor_rollout_ref.ref.megatron.param_offload=${ref_offload} "$@"
+else
+ echo "Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'"
+ exit 1
+fi
+
+echo "One-step-off-policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy"
\ No newline at end of file
diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py
index c8988db55a5..cb5c05a4ee1 100644
--- a/tests/special_sanity/check_device_api_usage.py
+++ b/tests/special_sanity/check_device_api_usage.py
@@ -28,6 +28,7 @@
"recipe/prime/prime_ray_trainer.py", # appear in default device_name
"recipe/spin/spin_trainer.py", # appear in default device_name
"recipe/sppo/sppo_ray_trainer.py", # appear in default device_name
+ "recipe/one_step_off_policy/ray_trainer.py", # appear in default device_name
"verl/utils/profiler/nvtx_profile.py", # appear in NsightSystemsProfiler
"verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx
"verl/utils/rendezvous/ray_backend.py", # appear in cupy importance
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index c94edfd72d9..475b0a51783 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -213,7 +213,7 @@ def megatron_actor_model_provider(pre_process, post_process):
override_ddp_config=override_ddp_config,
)
- if self._is_actor and self._is_rollout:
+ if self._is_actor or self._is_rollout:
actor_module = make_model(wrap_with_ddp=True)
print(f"actor_module: {len(actor_module)}")
if self.config.actor.load_weight:
@@ -402,7 +402,7 @@ def init_model(self):
self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
)
else:
- override_transformer_config = None
+ override_transformer_config = {}
self.param_dtype = torch.bfloat16
log_gpu_memory_usage("Before init actor model and optimizer", logger=logger)
self.dtype = PrecisionType.to_dtype(self.param_dtype)