diff --git a/.github/workflows/e2e_one_step_off_policy.yml b/.github/workflows/e2e_one_step_off_policy.yml new file mode 100644 index 00000000000..2ac76f869c7 --- /dev/null +++ b/.github/workflows/e2e_one_step_off_policy.yml @@ -0,0 +1,144 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + +name: e2e_one_step_off_policy + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. + push: + branches: + - main + - v0.* + paths: + - "**/*.py" + - "!**/*.md" + - "!**/*.sh" + # Other entrypoints + - "!examples/*trainer*" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + - "!recipe/**" + - "recipe/one_step_off_policy" + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + - "!**/*.md" + - "!**/*.sh" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Home + - "recipe/one_step_off_policy" + # Entrypoints + - ".github/workflows/e2e_one_step_off_policy.yml" + - "examples/data_preprocess/gsm8k.py" + - "tests/special_e2e/run_one_step_off_policy.sh" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. 
+permissions: + contents: read + +jobs: + # Test FSDP2 strategy + e2e_one_step_off_policy_fsdp2: + runs-on: [L20x8] + timeout-minutes: 50 # Increase timeout for async training + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + ACTOR_STRATEGY: "fsdp2" + container: + image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test,gpu] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running the E2E test with one_step_off_policy algorithm (FSDP2) + run: | + ray stop --force + bash tests/special_e2e/run_one_step_off_policy.sh + + # Test Megatron strategy + e2e_one_step_off_policy_megatron: + runs-on: [L20x8] + timeout-minutes: 50 # Increase timeout for async training + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + ACTOR_STRATEGY: "megatron" + container: + image: verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test,gpu] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py + - name: Running the E2E test with one_step_off_policy algorithm (Megatron) + run: | + ray stop --force + bash tests/special_e2e/run_one_step_off_policy.sh + diff --git a/docs/advance/one_step_off.md b/docs/advance/one_step_off.md new file mode 100644 index 00000000000..43edb199947 --- /dev/null +++ b/docs/advance/one_step_off.md @@ -0,0 +1,297 @@ +# Recipe: One Step Off Policy Async Trainer + +**Author:** `https://github.com/meituan-search>` + +Last updated: 07/16/2025. + +## Introduction + +### Background + +The current reinforcement learning training process implemented by verl is synchronous, adhering to the algorithmic +workflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest +model, and the model is updated after training completes. While this approach aligns with off-policy reinforcement +learning and stabilizes RL training, but it suffers from severe efficiency issues. +Model updates must wait for the longest output in the generation phase to complete. +During the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization. +The more severe the long-tail problem in sample generation, the lower the overall training efficiency. +For example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time, +and increasing resources does not reduce the Rollout duration. + +![DAPO 32B Math Performance]( +https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png) +> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361 + +### Solution + +We have implemented the **One Step Off Async Trainer** to help alleviate this issue. 
This approach parallelizes the generation and training processes, using samples generated in the previous step for
+the current training step. It also partitions resources explicitly, allocating dedicated resources to generation and
+automatically assigning the remainder to training. By reducing the resources allocated to the generation phase, we
+mitigate GPU idle time during long-tail sample generation. Throughout this process, the generation and training
+parameters remain one step off-policy.
+
+![One Step Off Policy Diagram](
+https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png)
+> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](
+> https://arxiv.org/abs/2505.24298)
+
+Our core contributions include:
+
+1. **Parallel Generation and Training**:
+   Samples for the next batch are generated asynchronously while the current batch is being trained.
+
+2. **Resource Isolation**:
+   Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with the remaining resources
+   automatically assigned to training.
+
+3. **NCCL Parameter Synchronization**:
+   Employs NCCL communication primitives for seamless parameter transfer between the generation and training modules.
+
+### Experimental Results
+
+- **Machine Configuration**: 2 nodes with 8 H20 GPUs each (16 GPUs in total)
+  - Generation: 4 GPUs
+  - Training: 12 GPUs
+- **Model**: Qwen2.5-Math-7B
+- **Rollout Configuration**:
+  - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens
+- **Algorithm**: DAPO
+- **Rollout Engine**: vLLM
+
+| training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean |
+|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
+| colocate sync | VLLM+FSDP2 | 749 | 321 | - | 247 | 88 | 286 | 19h18m | 0.5948 | 0.417 |
+| one-step-overlap async | VLLM+FSDP2 | 520 | - | 45 | 458 | 108 | 337 | 15h34m (+23%) | 0.6165 | 0.494 |
+| colocate sync | VLLM+Megatron | 699 | 207 | - | 162 | 119 | 344 | 18h21m | 0.605 | 0.4217 |
+| one-step-overlap async | VLLM+Megatron | 566 | - | 59 | 501 | 120 | 347 | 13h06m (+40%) | 0.6569 | 0.4038 |
+
+* colocate sync: step = gen + old_log_prob + update_actor
+* one-step-overlap async: step = max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
+
+![One Step Off Megatron Performance](
+https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)
+
+> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg
+
+## Implementation
+
+### One Step Off Policy Async Pipeline
+
+The **One Step Off Policy Async Pipeline** integrates into the existing training logic at minimal cost and requires no
+additional sample storage management. The core mechanism is `_async_gen_next_batch`, which launches the next rollout
+asynchronously, together with `_create_continuous_iterator`, which keeps the data stream continuous across epoch
+transitions.
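+
+The training loop below passes each in-flight rollout around as a `GenerationBatchFuture`. Its exact definition lives
+in the recipe code; the sketch here is only a minimal illustration of the idea (the field names, the `get()` behavior,
+and the use of `DataProto.union` are assumptions, not the actual implementation):
+
+```python
+class GenerationBatchFuture:
+    """Pairs a training batch with the asynchronous rollout launched for it."""
+
+    def __init__(self, epoch, batch, gen_batch_output):
+        self.epoch = epoch                        # epoch the batch was drawn from
+        self.batch = batch                        # original prompt batch (a DataProto)
+        self.gen_batch_output = gen_batch_output  # handle to the in-flight rollout call
+
+    def get(self):
+        # Block until the rollout launched in the previous step has finished,
+        # then merge the generated sequences back into the original batch.
+        gen_batch_result = self.gen_batch_output.get()  # assumed blocking accessor on the async handle
+        return self.batch.union(gen_batch_result)       # assumed merge of the two DataProto objects
+```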
+ +```python +# iterator generator, simplify one-step integration of the training process +def _create_continuous_iterator(self): + for epoch in range(self.config.trainer.total_epochs): + iterator = iter(self.train_dataloader) + for batch_dict in iterator: + yield epoch, batch_dict + + +# read next batch samples, parameters sync and launch asyn gen_seq +def _async_gen_next_batch(self, continuous_iterator): + # read train_data + try: + epoch, batch_dict = next(continuous_iterator) + except StopIteration: + return None + batch = DataProto.from_single_dict(batch_dict) + gen_batch = batch_pocess(batch) + # sync weights from actor to rollout + self.sync_rollout_weights() + # async generation + gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch) + # future encapsulated + return GenerationBatchFuture(epoch, batch, gen_batch_output) + + +continuous_iterator = self._create_continuous_iterator() +# run rollout first to achieve one-step-off +batch_data_future = self._async_gen_next_batch(continuous_iterator) + +while batch_data_future is not None: + # wait for the gen_seq result from the previous step + batch = batch_data_future.get() + # launch the next async call to generate sequences + batch_data_future = self._async_gen_next_batch(continuous_iterator) + + # compute advantages + batch = critic.compute_values(batch) + batch = reference.compute_log_prob(batch) + batch = reward.compute_reward(batch) + batch = compute_advantages(batch) + + # model update + critic_metrics = critic.update_critic(batch) + actor_metrics = actor.update_actor(batch) +``` + +### Parameter Synchronization + +The exciting point is that our nccl based weights updating for rollout model has great performance. +At most of time, the latency is under 300ms, which is negligible for RLHF. +Although it is only implemented with fsdp and vllm now, we think it is not complex to extend it to the other backend. + +> **sync_rollout_weights**:The time for synchronizing parameters from actor to rollout is extremely fast and can almost +> be ignored because it is implemented with nccl. + +```python +class ActorRolloutRefWorker: + # actor acquires the meta-info of model parameters for parameter sync + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + params = self._get_actor_params() + ret = [] + for key, tensor in params.items(): + ret.append((key, tensor.size(), tensor.dtype)) + self._weights_info = ret + return ret + + # rollout sets the meta-info of model parameters for parameter sync + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + self._weights_info = weights_info + + +class AsyncRayPPOTrainer(RayPPOTrainer): + def init_workers(self): + + +... 
+# rollout obtains the meta-info of model parameters from the actor for parameter sync +weights_info = self.actor_wg.get_actor_weights_info()[0] +self.rollout_wg.set_actor_weights_info(weights_info) + +# Create an actor-rollout communication group for parameter sync +actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers +collective.create_collective_group( + actor_rollout_workers, + len(actor_rollout_workers), + list(range(0, len(actor_rollout_workers))), + backend="nccl", + group_name="actor_rollout" +) +``` + +```python +# drive process call the actor and rollout respectively to sync parameters by nccl +def sync_rollout_weights(self): + self.actor_wg.sync_rollout_weights() + ray.get(self.rollout_wg.sync_rollout_weights()) + + +# fsdp model parameter sync +@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) +def sync_rollout_weights(self): + params = self._get_actor_params() if self._is_actor else None + if self._is_rollout: + inference_model = ( + self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + ) + patch_vllm_moe_model_weight_loader(inference_model) + # Model parameters are broadcast tensor-by-tensor from actor to rollout + for key, shape, dtype in self._weights_info: + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor: + assert key in params + origin_data = params[key] + if hasattr(origin_data, "full_tensor"): + origin_data = origin_data.full_tensor() + if torch.distributed.get_rank() == 0: + tensor.copy_(origin_data) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) +``` + +## Usage + +### FSDP2 Configuration Example + +```shell +python3 -m recipe.one_step_off_policy.async_main_ppo \ + --config-path=config \ + --config-name='one_step_off_ppo_trainer.yaml' \ + actor_rollout_ref.actor.strategy=fsdp2 \ + # actor and rollout are placed separately + actor_rollout_ref.hybrid_engine=False \ + # actor and rollout resource + trainer.nnodes=1 \ + trainer.n_gpus_per_node=6 \ + rollout.nnodes=1 \ + rollout.n_gpus_per_node=2 +``` + +### Megatron Configuration Example + +```shell +python3 -m recipe.one_step_off_policy.async_main_ppo \ + --config-path=config \ + --config-name='one_step_off_ppo_megatron_trainer.yaml' \ + actor_rollout_ref.actor.strategy=megatron \ + # actor and rollout are placed separately + actor_rollout_ref.hybrid_engine=False \ + # actor and rollout resource + trainer.nnodes=1 \ + trainer.n_gpus_per_node=6 \ + rollout.nnodes=1 \ + rollout.n_gpus_per_node=2 +``` + +### Configuration Guidelines + +1. **Card Number Relationships** + Maintain either of these relationships for optimal batch distribution: + - `actor_rollout_ref.rollout.n` should be an integer divisor of: + `trainer.n_gpus_per_node * trainer.nnodes` + - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by: + `trainer.n_gpus_per_node * trainer.nnodes` + + > Rationale: Ensures training samples can be evenly distributed across training GPUs when using partial resources for + generation. + +2. 
**Dynamic Resource Tuning**
+   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on the
+   phase durations:
+   - **Ideal state**: the rollout and training phases take comparable time
+   - **Diagnostic metrics**:
+     - Monitor the `wait_prev_gen` duration
+     - Analyze the `sequence_length` distribution
+   - **Adjustment strategy**:
+     - High `wait_prev_gen` + uniform sequence lengths → increase rollout resources
+     - High `wait_prev_gen` + long-tail sequences → optimize stopping criteria (adding resources won't help)
+   > **wait_prev_gen**: The time spent waiting for the previous rollout to finish (the part that is not fully
+   > overlapped with training).
+
+   **Resource Configuration Strategies:**
+   - **Resource-constrained scenario**: optimize utilization by adjusting the GPU allocation ratio, keeping the number
+     of nodes equal so that training and rollout can share nodes.
+     - Configure `trainer.nnodes = rollout.nnodes` with
+       `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`, and control the rollout resource
+       allocation by adjusting `n_gpus_per_node`.
+   - **Resource-abundant scenario**: optimize performance by adjusting the number of nodes, keeping the number of GPUs
+     per node equal so that training and rollout parallelism can be scaled independently.
+     - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control the rollout resource allocation by
+       adjusting `trainer.nnodes` and `rollout.nnodes`.
+   > **Note**: The total number of nodes required is not simply `trainer.nnodes + rollout.nnodes`; it depends on the
+   > per-node GPU capacity:
+   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,
+   >   the required node count is `max(trainer.nnodes, rollout.nnodes)`
+   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,
+   >   the required node count is `trainer.nnodes + rollout.nnodes`
+
+## Functional Support
+
+| Category | Supported |
+|--------------------|-------------------------------------------------------------------------------------------------|
+| train engine | FSDP2<br/>
Megatron | +| rollout engine | vLLM | +| AdvantageEstimator | GRPO
GRPO_PASSK
REINFORCE_PLUS_PLUS
RLOO
OPO
REINFORCE_PLUS_PLUS_BASELINE
GPG | +| Reward | all | diff --git a/docs/index.rst b/docs/index.rst index 980066a7fbc..a486586945c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -116,6 +116,7 @@ verl is fast with: advance/dpo_extension examples/sandbox_fusion_example advance/rollout_trace.rst + advance/one_step_off .. toctree:: :maxdepth: 1 diff --git a/recipe/one_step_off_policy/README.md b/recipe/one_step_off_policy/README.md new file mode 100644 index 00000000000..2eef6997f8c --- /dev/null +++ b/recipe/one_step_off_policy/README.md @@ -0,0 +1,297 @@ +# Recipe: One Step Off Policy Async Trainer + +**Author:** `https://github.com/meituan-search>` + +Last updated: 07/16/2025. + +## Introduction + +### Background + +The current reinforcement learning training process implemented by verl is synchronous, adhering to the algorithmic +workflows of established methods like PPO, GRPO, and DAPO. In each step, training samples are generated by the latest +model, and the model is updated after training completes. While this approach aligns with off-policy reinforcement +learning and stabilizes RL training, but it suffers from severe efficiency issues. +Model updates must wait for the longest output in the generation phase to complete. +During the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization. +The more severe the long-tail problem in sample generation, the lower the overall training efficiency. +For example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time, +and increasing resources does not reduce the Rollout duration. + +![DAPO 32B Math Performance]( +https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png) +> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361 + +### Solution + +We have implemented the **One Step Off Async Trainer** to help alleviate this issue. This approach parallelizes the +generation and training processes, utilizing samples generated in the previous step for current training. +It also involves appropriately partitioning resources, allocating dedicated resources for generation while automatically +assigning the remainder to training. By reducing resources allocated to the generation phase, we mitigate GPU idle time +during long-tail sample generation. Throughout this process, generation and training parameters maintain a one-step off +policy. + +![One Step Off Policy Diagram]( +https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png) +> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning]( +> https://arxiv.org/abs/2505.24298) + +Our core contributions include: + +1. **Parallel Generation and Training**: + Samples for the next batch are asynchronously generated while the current batch is being trained. + +2. **Resource Isolation**: + Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources + automatically assigned to training. + +3. **NCCL Parameter Synchronization**: + Employs NCCL communication primitives for seamless parameter transfer between generation and training modules. 
+
+### Experimental Results
+
+- **Machine Configuration**: 2 nodes with 8 H20 GPUs each (16 GPUs in total)
+  - Generation: 4 GPUs
+  - Training: 12 GPUs
+- **Model**: Qwen2.5-Math-7B
+- **Rollout Configuration**:
+  - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens
+- **Algorithm**: DAPO
+- **Rollout Engine**: vLLM
+
+| training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean |
+|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
+| colocate sync | VLLM+FSDP2 | 749 | 321 | - | 247 | 88 | 286 | 19h18m | 0.5948 | 0.417 |
+| one-step-overlap async | VLLM+FSDP2 | 520 | - | 45 | 458 | 108 | 337 | 15h34m (+23%) | 0.6165 | 0.494 |
+| colocate sync | VLLM+Megatron | 699 | 207 | - | 162 | 119 | 344 | 18h21m | 0.605 | 0.4217 |
+| one-step-overlap async | VLLM+Megatron | 566 | - | 59 | 501 | 120 | 347 | 13h06m (+40%) | 0.6569 | 0.4038 |
+
+* colocate sync: step = gen + old_log_prob + update_actor
+* one-step-overlap async: step = max(wait_prev_gen + generate_sequences, old_log_prob + update_actor)
+
+![One Step Off Megatron Performance](
+https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)
+
+> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg
+
+## Implementation
+
+### One Step Off Policy Async Pipeline
+
+The **One Step Off Policy Async Pipeline** integrates into the existing training logic at minimal cost and requires no
+additional sample storage management. The core mechanism is `_async_gen_next_batch`, which launches the next rollout
+asynchronously, together with `_create_continuous_iterator`, which keeps the data stream continuous across epoch
+transitions.
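+
+The training loop below passes each in-flight rollout around as a `GenerationBatchFuture`. Its exact definition lives
+in the recipe code; the sketch here is only a minimal illustration of the idea (the field names, the `get()` behavior,
+and the use of `DataProto.union` are assumptions, not the actual implementation):
+
+```python
+class GenerationBatchFuture:
+    """Pairs a training batch with the asynchronous rollout launched for it."""
+
+    def __init__(self, epoch, batch, gen_batch_output):
+        self.epoch = epoch                        # epoch the batch was drawn from
+        self.batch = batch                        # original prompt batch (a DataProto)
+        self.gen_batch_output = gen_batch_output  # handle to the in-flight rollout call
+
+    def get(self):
+        # Block until the rollout launched in the previous step has finished,
+        # then merge the generated sequences back into the original batch.
+        gen_batch_result = self.gen_batch_output.get()  # assumed blocking accessor on the async handle
+        return self.batch.union(gen_batch_result)       # assumed merge of the two DataProto objects
+```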
+ +```python +# iterator generator, simplify one-step integration of the training process +def _create_continuous_iterator(self): + for epoch in range(self.config.trainer.total_epochs): + iterator = iter(self.train_dataloader) + for batch_dict in iterator: + yield epoch, batch_dict + + +# read next batch samples, parameters sync and launch asyn gen_seq +def _async_gen_next_batch(self, continuous_iterator): + # read train_data + try: + epoch, batch_dict = next(continuous_iterator) + except StopIteration: + return None + batch = DataProto.from_single_dict(batch_dict) + gen_batch = batch_pocess(batch) + # sync weights from actor to rollout + self.sync_rollout_weights() + # async generation + gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch) + # future encapsulated + return GenerationBatchFuture(epoch, batch, gen_batch_output) + + +continuous_iterator = self._create_continuous_iterator() +# run rollout first to achieve one-step-off +batch_data_future = self._async_gen_next_batch(continuous_iterator) + +while batch_data_future is not None: + # wait for the gen_seq result from the previous step + batch = batch_data_future.get() + # launch the next async call to generate sequences + batch_data_future = self._async_gen_next_batch(continuous_iterator) + + # compute advantages + batch = critic.compute_values(batch) + batch = reference.compute_log_prob(batch) + batch = reward.compute_reward(batch) + batch = compute_advantages(batch) + + # model update + critic_metrics = critic.update_critic(batch) + actor_metrics = actor.update_actor(batch) +``` + +### Parameter Synchronization + +The exciting point is that our nccl based weights updating for rollout model has great performance. +At most of time, the latency is under 300ms, which is negligible for RLHF. +Although it is only implemented with fsdp and vllm now, we think it is not complex to extend it to the other backend. + +> **sync_rollout_weights**:The time for synchronizing parameters from actor to rollout is extremely fast and can almost +> be ignored because it is implemented with nccl. + +```python +class ActorRolloutRefWorker: + # actor acquires the meta-info of model parameters for parameter sync + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + params = self._get_actor_params() + ret = [] + for key, tensor in params.items(): + ret.append((key, tensor.size(), tensor.dtype)) + self._weights_info = ret + return ret + + # rollout sets the meta-info of model parameters for parameter sync + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + self._weights_info = weights_info + + +class AsyncRayPPOTrainer(RayPPOTrainer): + def init_workers(self): + + +... 
+# rollout obtains the meta-info of model parameters from the actor for parameter sync +weights_info = self.actor_wg.get_actor_weights_info()[0] +self.rollout_wg.set_actor_weights_info(weights_info) + +# Create an actor-rollout communication group for parameter sync +actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers +collective.create_collective_group( + actor_rollout_workers, + len(actor_rollout_workers), + list(range(0, len(actor_rollout_workers))), + backend="nccl", + group_name="actor_rollout" +) +``` + +```python +# drive process call the actor and rollout respectively to sync parameters by nccl +def sync_rollout_weights(self): + self.actor_wg.sync_rollout_weights() + ray.get(self.rollout_wg.sync_rollout_weights()) + + +# fsdp model parameter sync +@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) +def sync_rollout_weights(self): + params = self._get_actor_params() if self._is_actor else None + if self._is_rollout: + inference_model = ( + self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + ) + patch_vllm_moe_model_weight_loader(inference_model) + # Model parameters are broadcast tensor-by-tensor from actor to rollout + for key, shape, dtype in self._weights_info: + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor: + assert key in params + origin_data = params[key] + if hasattr(origin_data, "full_tensor"): + origin_data = origin_data.full_tensor() + if torch.distributed.get_rank() == 0: + tensor.copy_(origin_data) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) +``` + +## Usage + +### FSDP2 Configuration Example + +```shell +python3 -m recipe.one_step_off_policy.async_main_ppo \ + --config-path=config \ + --config-name='one_step_off_ppo_trainer.yaml' \ + actor_rollout_ref.actor.strategy=fsdp2 \ + # actor and rollout are placed separately + actor_rollout_ref.hybrid_engine=False \ + # actor and rollout resource + trainer.nnodes=1 \ + trainer.n_gpus_per_node=6 \ + rollout.nnodes=1 \ + rollout.n_gpus_per_node=2 +``` + +### Megatron Configuration Example + +```shell +python3 -m recipe.one_step_off_policy.async_main_ppo \ + --config-path=config \ + --config-name='one_step_off_ppo_megatron_trainer.yaml' \ + actor_rollout_ref.actor.strategy=megatron \ + # actor and rollout are placed separately + actor_rollout_ref.hybrid_engine=False \ + # actor and rollout resource + trainer.nnodes=1 \ + trainer.n_gpus_per_node=6 \ + rollout.nnodes=1 \ + rollout.n_gpus_per_node=2 +``` + +### Configuration Guidelines + +1. **Card Number Relationships** + Maintain either of these relationships for optimal batch distribution: + - `actor_rollout_ref.rollout.n` should be an integer divisor of: + `trainer.n_gpus_per_node * trainer.nnodes` + - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by: + `trainer.n_gpus_per_node * trainer.nnodes` + + > Rationale: Ensures training samples can be evenly distributed across training GPUs when using partial resources for + generation. + +2. 
**Dynamic Resource Tuning**
+   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on the
+   phase durations:
+   - **Ideal state**: the rollout and training phases take comparable time
+   - **Diagnostic metrics**:
+     - Monitor the `wait_prev_gen` duration
+     - Analyze the `sequence_length` distribution
+   - **Adjustment strategy**:
+     - High `wait_prev_gen` + uniform sequence lengths → increase rollout resources
+     - High `wait_prev_gen` + long-tail sequences → optimize stopping criteria (adding resources won't help)
+   > **wait_prev_gen**: The time spent waiting for the previous rollout to finish (the part that is not fully
+   > overlapped with training).
+
+   **Resource Configuration Strategies:**
+   - **Resource-constrained scenario**: optimize utilization by adjusting the GPU allocation ratio, keeping the number
+     of nodes equal so that training and rollout can share nodes.
+     - Configure `trainer.nnodes = rollout.nnodes` with
+       `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`, and control the rollout resource
+       allocation by adjusting `n_gpus_per_node`.
+   - **Resource-abundant scenario**: optimize performance by adjusting the number of nodes, keeping the number of GPUs
+     per node equal so that training and rollout parallelism can be scaled independently.
+     - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control the rollout resource allocation by
+       adjusting `trainer.nnodes` and `rollout.nnodes`.
+   > **Note**: The total number of nodes required is not simply `trainer.nnodes + rollout.nnodes`; it depends on the
+   > per-node GPU capacity:
+   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,
+   >   the required node count is `max(trainer.nnodes, rollout.nnodes)`
+   > - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,
+   >   the required node count is `trainer.nnodes + rollout.nnodes`
+
+## Functional Support
+
+| Category | Supported |
+|--------------------|-------------------------------------------------------------------------------------------------|
+| train engine | FSDP2<br/>
Megatron | +| rollout engine | vLLM | +| AdvantageEstimator | GRPO
GRPO_PASSK
REINFORCE_PLUS_PLUS
RLOO
OPO
REINFORCE_PLUS_PLUS_BASELINE
GPG | +| Reward | all | diff --git a/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml b/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml new file mode 100644 index 00000000000..f5a3c6b58e6 --- /dev/null +++ b/recipe/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml @@ -0,0 +1,14 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +# config for the rollout (only for resource isolation) +rollout: + # Number of nodes used in the rollout + nnodes: 1 + # Number of GPUs per node + n_gpus_per_node: 8 \ No newline at end of file diff --git a/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml b/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml new file mode 100644 index 00000000000..05890e9b5a2 --- /dev/null +++ b/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml @@ -0,0 +1,14 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +# config for the rollout (only for resource isolation) +rollout: + # Number of nodes used in the rollout + nnodes: 1 + # Number of GPUs per node + n_gpus_per_node: 8 \ No newline at end of file diff --git a/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh new file mode 100644 index 00000000000..7f4ca919a4c --- /dev/null +++ b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_4_12.sh @@ -0,0 +1,139 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-one-step-off-4-12' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=12 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=2 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=2 +sp_size=4 +fsdp_size=2 + +python3 -m recipe.one_step_off_policy.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + 
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" \ No newline at end of file diff --git a/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh new file mode 100644 index 00000000000..33b65b06080 --- /dev/null +++ b/recipe/one_step_off_policy/dapo_7b_math_fsdp2_colocate.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-colocate' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=12 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=2 +sp_size=4 +fsdp_size=2 + +# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361 + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + 
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh b/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh new file mode 100644 index 00000000000..e246f01fbf5 --- /dev/null +++ b/recipe/one_step_off_policy/dapo_7b_math_megatron_4_12.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-megatron-one-step-off-4-12' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=12 +train_prompt_mini_bsz=32 + + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=2 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=2 +train_tp=2 +train_pp=2 + +# TODO: support dynamic_bsz for megatron +# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ +# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ +# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + +python3 -m recipe.one_step_off_policy.main_ppo \ + --config-path=config \ + --config-name='one_step_off_ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=megatron \ + critic.strategy=megatron \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + 
actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" diff --git a/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh b/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh new file mode 100644 index 00000000000..c22da0fd440 --- /dev/null +++ b/recipe/one_step_off_policy/dapo_7b_math_megatron_colocate.sh @@ -0,0 +1,138 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0519a1-megatron-colocate' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=2 +train_tp=2 +train_pp=2 + +# TODO: support dynamic_bsz for megatron +# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ +# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ +# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=megatron \ + critic.strategy=megatron \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + 
max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/recipe/one_step_off_policy/fsdp_workers.py b/recipe/one_step_off_policy/fsdp_workers.py new file mode 100644 index 00000000000..c2600f7ebd7 --- /dev/null +++ b/recipe/one_step_off_policy/fsdp_workers.py @@ -0,0 +1,228 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
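+
+# This module extends verl's FSDP workers for the one-step-off-policy recipe:
+# - ActorRolloutRefWorker adds NCCL-based actor-to-rollout weight synchronization
+#   (get_actor_weights_info / sync_rollout_weights).
+# - RolloutWorker is a standalone rollout-only worker (vLLM) for resource-isolated generation.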
+ +import logging +import os + +import torch +import torch.distributed +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import AutoConfig + +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass +from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage +from verl.utils.device import ( + get_device_name, + get_nccl_backend, + get_torch_device, +) +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + fsdp_version, +) +from verl.utils.import_utils import import_external_libs +from verl.utils.model import get_generation_config, update_model_config +from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader +from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker +from verl.workers.fsdp_workers import CriticWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + +__all__ = ["ActorRolloutRefWorker", "AsyncActorRolloutRefWorker", "CriticWorker", "RolloutWorker"] + + +class ActorRolloutRefWorker(ARRWorker): + def _get_actor_params(self): + assert self._is_actor + params = self.actor_module_fsdp.state_dict() + from verl.utils.model import convert_weight_keys + + params = convert_weight_keys( + params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + return params + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + + params = self._get_actor_params() if self._is_actor else None + if self._is_rollout: + inference_model = ( + self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + ) + patch_vllm_moe_model_weight_loader(inference_model) + for key, shape, dtype in self._weights_info: + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor: + assert key in params + origin_data = params[key] + if hasattr(origin_data, "full_tensor"): + origin_data = origin_data.full_tensor() + if torch.distributed.get_rank() == 0: + tensor.copy_(origin_data) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + assert self._is_actor + if hasattr(self, "_weights_info"): + return self._weights_info + if fsdp_version(self.actor_module_fsdp) == 1: + from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType + + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) + params = self._get_actor_params() + ret = [] + for key, tensor in params.items(): + ret.append((key, tensor.size(), tensor.dtype)) + self._weights_info = ret + return ret + + +class RolloutWorker(ActorRolloutRefWorker): + def __init__(self, config: DictConfig, role: str): + Worker.__init__(self) + assert role == "rollout" + self.config = config + import 
torch.distributed + + if not torch.distributed.is_initialized(): + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + rank=rank, + world_size=world_size, + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + # TODO(haibin.lin): + # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, + # it will actually convert the ProfilerConfig dataclass back to a DictConfig. + # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) + # as they provides DictConfig-like interface + # The benefit of creating the dataclass config is to perform validation during __post_init__ + profiler_config = omega_conf_to_dataclass(config.rollout.get("profiler", {})) + DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + self._is_rollout = True + self._is_actor = False + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + + use_shm = self.config.model.get("use_shm", False) + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) + trust_remote_code = self.config.model.get("trust_remote_code", False) + + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + # override model kwargs + actor_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" + ) + + # patch for kimi-vl + if getattr(actor_model_config, "model_type", None) == "kimi_vl": + actor_model_config.text_config.topk_method = "greedy" + + self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) + + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config) + update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) + if self.rank == 0: + print(f"Model config after override: {actor_model_config}") + + infer_tp = self.config.rollout.tensor_model_parallel_size + dp = self.world_size // infer_tp + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] + ) + rollout_name = self.config.rollout.name + assert rollout_name == "vllm" + + from verl.workers.rollout.vllm_rollout import vLLMRollout + + log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) + + from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout + + vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + rollout = vllm_rollout_cls( + 
model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + ) + log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) + from .vllm_sharding_manager import VLLMShardingManager + + rollout_sharding_manager = VLLMShardingManager( + inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh + ) + + log_gpu_memory_usage("After building sharding manager", logger=logger) + + self.rollout = rollout + self.rollout_sharding_manager = rollout_sharding_manager + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) + def async_generate_sequences(self, *args, **kwargs): + return super().generate_sequences(*args, **kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + assert self._is_rollout + self._weights_info = weights_info + + +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + def __init__(self, *args, **kwargs): + raise NotImplementedError diff --git a/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh b/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh new file mode 100644 index 00000000000..09048fd0340 --- /dev/null +++ b/recipe/one_step_off_policy/grpo_0.6b_gsm8k_fsdp2_2_6.sh @@ -0,0 +1,65 @@ +set -x + +project_name='GRPO' +exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6' + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"} + +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=2 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + + +python3 -m recipe.one_step_off_policy.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=1152 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=192 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + 
trainer.val_before_train=True \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=2 \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" $@ \ No newline at end of file diff --git a/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh b/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh new file mode 100644 index 00000000000..a0d3bdb8ce8 --- /dev/null +++ b/recipe/one_step_off_policy/grpo_3b_gsm8k_fsdp2_2_6.sh @@ -0,0 +1,64 @@ +set -x + +project_name='GRPO' +exp_name='GRPO-Qwen2.5-3b-gsm8k-fsdp2-one-step-off-2-6' + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen/Qwen2.5-3B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"} + +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=2 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +python3 -m recipe.one_step_off_policy.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=1152 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=192 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.val_before_train=True \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=2 \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" $@ \ No newline at end of file diff --git a/recipe/one_step_off_policy/main_ppo.py b/recipe/one_step_off_policy/main_ppo.py new file mode 100644 index 00000000000..cd7d1468da7 --- /dev/null +++ b/recipe/one_step_off_policy/main_ppo.py @@ -0,0 +1,228 @@ +# Copyright 2024 Bytedance
Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.constants_ppo import PPO_RAY_RUNTIME_ENV +from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler +from verl.trainer.ppo.reward import load_reward_manager + +from .ray_trainer import OneStepOffRayTrainer + + +@hydra.main(config_path="config", config_name="one_step_off_ppo_trainer", version_base=None) +def main(config): + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config) -> None: + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + ray.init( + runtime_env=PPO_RAY_RUNTIME_ENV, + num_cpus=config.ray_init.num_cpus, + ) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + OmegaConf.select(config.trainer, "profile_steps") is not None + and len(OmegaConf.select(config.trainer, "profile_steps")) > 0 + ): + nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_init.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) + + OmegaConf.resolve(config) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. 
+ from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Define worker classes based on the actor strategy. + if config.actor_rollout_ref.actor.strategy == "fsdp2": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + + from .fsdp_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + CriticWorker, + RolloutWorker, + ) + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + + from .megatron_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + CriticWorker, + RolloutWorker, + ) + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from .ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.Actor: ray.remote(actor_rollout_cls), + Role.Rollout: ray.remote(RolloutWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "actor_pool" + + assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" + assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" + assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" + assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" + + actor_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes + rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes + + resource_pool_spec = { + "actor_pool": actor_pool, + "rollout_pool": rollout_pool, + } + mapping = { + Role.Actor: "actor_pool", + Role.Rollout: "rollout_pool", + Role.Critic: "actor_pool", + } + print(f"resource_pool_spec: {resource_pool_spec}") + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in ["fsdp2"]: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # Add a reference policy worker if KL loss or KL reward is used. + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + # Load the reward manager for training and validation. 
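+        # Descriptive note: num_examine below controls how many decoded samples the reward
+        # manager prints for inspection -- 0 for the training reward_fn, 1 for val_reward_fn.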
+ reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = OneStepOffRayTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + device_name=config.trainer.device, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + # Start the training process. + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/recipe/one_step_off_policy/megatron_workers.py b/recipe/one_step_off_policy/megatron_workers.py new file mode 100644 index 00000000000..f81fbd94074 --- /dev/null +++ b/recipe/one_step_off_policy/megatron_workers.py @@ -0,0 +1,201 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
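+# Megatron counterpart of fsdp_workers.py in this recipe (descriptive note): actor weights are
+# streamed tensor-by-tensor via per_tensor_generator and broadcast from actor rank 0 to the
+# rollout workers over the same "actor_rollout" ray collective group.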
+ +import logging +import os + +import torch +import torch.distributed +from omegaconf import DictConfig + +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils.debug import ( + log_gpu_memory_usage, +) +from verl.utils.device import get_device_name, get_torch_device +from verl.utils.fs import copy_to_local +from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader +from verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker +from verl.workers.megatron_workers import CriticWorker, RewardModelWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +__all__ = ["ActorRolloutRefWorker", "AsyncActorRolloutRefWorker", "CriticWorker", "RewardModelWorker", "RolloutWorker"] + + +class ActorRolloutRefWorker(ARRWorker): + def __init__(self, config: DictConfig, role: str): + assert role in ["actor", "ref"] + tmp_role = "ref" if role == "ref" else "actor_rollout" + super().__init__(config, tmp_role) + if role == "actor": + self._is_rollout = False + self.role = role + + def _get_actor_params_generator(self): + assert self._is_actor + from verl.models.mcore import get_mcore_weight_converter + from verl.utils.megatron_utils import per_tensor_generator + + layer_name_mapping = { + "qkv_layer_name": "self_attention.linear_qkv.", + "gate_proj_layer_name": "linear_fc1.", + } + weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) + generator = per_tensor_generator( + self.actor.actor_module, + self.actor_model_config, + weight_converter, + self.tf_config, + layer_name_mapping, + ) + return generator + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + + params_generator = self._get_actor_params_generator() if self._is_actor else None + if self._is_rollout: + inference_model = ( + self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + ) + patch_vllm_moe_model_weight_loader(inference_model) + for key, shape, dtype in self._weights_info: + if self._is_actor: + weight_key, weight = next(params_generator) + assert key == weight_key + assert shape == weight.size() + assert dtype == weight.dtype + + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor and torch.distributed.get_rank() == 0: + tensor.copy_(weight) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + assert self._is_actor + if hasattr(self, "_weights_info"): + return self._weights_info + + params_generator = self._get_actor_params_generator() + ret = [] + for key, tensor in params_generator: + ret.append((key, tensor.size(), tensor.dtype)) + + self._weights_info = ret + return ret + + +class RolloutWorker(ActorRolloutRefWorker): + def __init__(self, config: DictConfig, role: str): + assert role == "rollout" + ARRWorker.__init__(self, config, role) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + 
importlib.import_module(self.config.model.external_lib) + + from omegaconf import OmegaConf + + from verl.utils.torch_dtypes import PrecisionType + + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_transformer_config = {} + self.param_dtype = torch.bfloat16 + self.dtype = PrecisionType.to_dtype(self.param_dtype) + trust_remote_code = self.config.model.get("trust_remote_code", False) + + from verl.utils.model import get_generation_config + + self._init_hf_config_and_tf_config( + self.config.model.path, + self.config.model.path, + self.dtype, + override_model_config, + override_transformer_config, + trust_remote_code, + ) + self.generation_config = get_generation_config(self.local_path) + + from torch.distributed.device_mesh import init_device_mesh + + assert self.config.rollout.name == "vllm" + assert self.config.rollout.mode == "sync" + + from verl.workers.rollout.vllm_rollout import vLLMRollout + + from .vllm_sharding_manager import VLLMShardingManager + + # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor, + # we will reorganize their weight format when resharding from actor to rollout. + + infer_tp = self.config.rollout.tensor_model_parallel_size + dp = self.world_size // infer_tp + assert self.world_size % infer_tp == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" + ) + rollout_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] + ) + log_gpu_memory_usage("Before building vllm rollout", logger=None) + + local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) + from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout + + vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + rollout = vllm_rollout_cls( + model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.hf_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + ) + log_gpu_memory_usage("After building vllm rollout", logger=logger) + + sharding_manager = VLLMShardingManager( + inference_engine=rollout.inference_engine, + device_mesh=rollout_device_mesh, + ) + log_gpu_memory_usage("After building sharding manager", logger=logger) + + self.rollout, self.sharding_manager = rollout, sharding_manager + self.rollout.sharding_manager = sharding_manager + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) + def async_generate_sequences(self, *args, **kwargs): + return super().generate_sequences(*args, **kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + assert self._is_rollout + self._weights_info = weights_info + + +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + def __init__(self, *args, **kwargs): + raise NotImplementedError diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py new file mode 100644 index 00000000000..1f7011bdf54 --- /dev/null +++ b/recipe/one_step_off_policy/ray_trainer.py @@ -0,0 +1,624 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This trainer supports model-agnostic model initialization with huggingface +""" + +import uuid +from pprint import pprint + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf +from torch.utils.data import Dataset, Sampler +from tqdm import tqdm + +from verl import DataProto +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + RayPPOTrainer, + ResourcePoolManager, + Role, + WorkerType, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.utils.debug import marked_timer +from verl.utils.metric import ( + reduce_metrics, +) +from verl.utils.tracking import ValidationGenerationsLogger + + +class GenerationBatchFuture: + """ + Wrapper class for encapsulating batch generation results + """ + + def __init__(self, epoch, batch, gen_batch_output): + """ + :param epoch: current epoch + :param batch: Input batch data + :param gen_batch_output: Generated sequences from the main model (DataProtoFuture) + """ + self.epoch = epoch + self.batch = batch + self.gen_batch_output = gen_batch_output + + def get(self): + """ + Get the actual results by calling the get() method on gen_batch_output + + Returns: + tuple: (epoch, batch, gen_batch_result) + - epoch: Current epoch of this batch + - batch: Original input batch data + - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself + """ + # Call get() method on gen_batch_output if available + if hasattr(self.gen_batch_output, "get"): + gen_batch_result = self.gen_batch_output.get() + else: + gen_batch_result = self.gen_batch_output + + return self.epoch, self.batch, gen_batch_result + + +class OneStepOffRayTrainer(RayPPOTrainer): + # TODO: support an individual ray_worker_group_cls for each role, + # i.e., support a different backend for each role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Dataset | None = None, + val_dataset: Dataset | None = None, + collate_fn=None, + train_sampler: Sampler | None = None, + device_name="cuda", + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups.
Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda". + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + + assert not self.hybrid_engine + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = Role.RefPolicy in role_worker_mapping + self.use_rm = Role.RewardModel in role_worker_mapping + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name + self.validation_generations_logger = ValidationGenerationsLogger() + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: + self.use_critic = True + elif self.config.algorithm.adv_estimator in [ + AdvantageEstimator.GRPO, + AdvantageEstimator.GRPO_PASSK, + AdvantageEstimator.REINFORCE_PLUS_PLUS, + # AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy + AdvantageEstimator.RLOO, + AdvantageEstimator.OPO, + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, + AdvantageEstimator.GPG, + ]: + self.use_critic = False + else: + raise NotImplementedError + + self._validate_config() + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _validate(self): + self.actor_rollout_wg = self.rollout_wg + ret = super()._validate() + self.actor_rollout_wg = self.actor_wg + return ret + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) 
+ """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + for role, role_name in [(Role.Actor, "actor"), (Role.Rollout, "rollout")]: + resource_pool = self.resource_pool_manager.get_resource_pool(role) + role_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[role], + config=self.config.actor_rollout_ref, + role=role_name, + ) + self.resource_pool_to_cls[resource_pool][role_name] = role_cls + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
+ all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.trainer, "profile_steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") + assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( + "worker_nsight_options must be set when profile_steps is set" + ) + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.trainer, "worker_nsight_options") + ) + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + self.actor_wg = all_wg["actor"] + self.rollout_wg = all_wg["rollout"] + self.actor_wg.init_model() + self.rollout_wg.init_model() + self.actor_rollout_wg = self.actor_wg # to be compatible with the functions that not be modified + weights_info = self.actor_wg.get_actor_weights_info()[0] + self.rollout_wg.set_actor_weights_info(weights_info) + from ray.util.collective import collective + + actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers + collective.create_collective_group( + actor_rollout_workers, + len(actor_rollout_workers), + list(range(0, len(actor_rollout_workers))), + backend="nccl", + group_name="actor_rollout", + ) + self.sync_rollout_weights() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async" and self._is_rollout: + from verl.workers.rollout.async_server import AsyncLLMServerManager + + self.async_rollout_mode = True + self.async_rollout_manager = AsyncLLMServerManager( + config=self.config, + worker_group=self.rollout_wg, + ) + + def sync_rollout_weights(self): + if not self.hybrid_engine: + self.actor_wg.sync_rollout_weights() + ray.get(self.rollout_wg.sync_rollout_weights()) + + def _create_continuous_iterator(self): + """ + Create a continuous data iterator across epoch + """ + for epoch in range(self.config.trainer.total_epochs): + iterator = iter(self.train_dataloader) + for batch_dict in iterator: + yield epoch, batch_dict + + def _async_gen_next_batch(self, continuous_iterator): + """ + Call parameter synchronization and asynchronous sequence generation. 
+ """ + try: + epoch, batch_dict = next(continuous_iterator) + except StopIteration: + return None + except Exception as e: + print(f"Error in async_gen_next_batch: {e}") + return None + batch = DataProto.from_single_dict(batch_dict) + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + # sync weights from actor to rollout + self.sync_rollout_weights() + # async generation + gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch) + return GenerationBatchFuture(epoch, batch, gen_batch_output) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + # across epoch iterator + continuous_iterator = self._create_continuous_iterator() + + # Start the first asynchronous generation task. 
batch_data_future = self._async_gen_next_batch(continuous_iterator) + + while batch_data_future is not None: + do_profile = ( + self.global_steps in self.config.trainer.profile_steps + if self.config.trainer.profile_steps is not None + else False + ) + if do_profile: + self.actor_wg.start_profile() + if not self.hybrid_engine: + self.rollout_wg.start_profile() + if self.use_reference_policy: + self.ref_policy_wg.start_profile() + if self.use_critic: + self.critic_wg.start_profile() + if self.use_rm: + self.rm_wg.start_profile() + + metrics = {} + timing_raw = {} + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # wait for the previous batch + with marked_timer("wait_prev_gen", timing_raw, color="red"): + epoch, batch, gen_batch_output = batch_data_future.get() + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + # asynchronously generate the next batch (after syncing weights from actor to rollout) + with marked_timer("sync_rollout_weights", timing_raw, color="purple"): + if not is_last_step: + batch_data_future = self._async_gen_next_batch(continuous_iterator) + + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too.
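+                    # Diagnostic metric (descriptive note): compare per-token probabilities returned
+                    # by the vLLM rollout with those recomputed by the training engine; a persistently
+                    # large gap can indicate stale rollout weights or a precision mismatch between engines.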
+ rollout_old_log_probs = batch.batch["rollout_log_probs"] + actor_old_log_probs = batch.batch["old_log_probs"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + response_length = responses.size(1) + response_mask = attention_mask[:, -response_length:] + + rollout_probs = torch.exp(rollout_old_log_probs) + actor_probs = torch.exp(actor_old_log_probs) + rollout_probs_diff = torch.abs(rollout_probs - actor_probs) + rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff_max = torch.max(rollout_probs_diff) + rollout_probs_diff_mean = torch.mean(rollout_probs_diff) + rollout_probs_diff_std = torch.std(rollout_probs_diff) + metrics.update( + { + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), + } + ) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. 
apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + self._dump_generations( + inputs=inputs, + outputs=outputs, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if do_profile: + self.actor_wg.stop_profile() + if not self.hybrid_engine: + self.rollout_wg.stop_profile() + if self.use_reference_policy: + 
self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return diff --git a/recipe/one_step_off_policy/vllm_sharding_manager.py b/recipe/one_step_off_policy/vllm_sharding_manager.py new file mode 100644 index 00000000000..c33ba585470 --- /dev/null +++ b/recipe/one_step_off_policy/vllm_sharding_manager.py @@ -0,0 +1,74 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from torch.distributed.device_mesh import DeviceMesh + +from verl import DataProto +from verl.protocol import all_gather_data_proto +from verl.third_party.vllm import parallel_state as vllm_ps +from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_torch_device +from verl.utils.torch_functional import check_device_is_available +from verl.workers.sharding_manager.base import BaseShardingManager + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class VLLMShardingManager(BaseShardingManager): + @check_device_is_available() + def __init__(self, inference_engine, device_mesh: DeviceMesh): + self.device_mesh = device_mesh + self.inference_engine = inference_engine + inference_engine.wake_up() + assert device_mesh is not None + assert inference_engine is not None + self.tp_size = self.device_mesh["infer_tp"].size() + self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() + self.timing = {} + gen_dp_rank = self.device_mesh["dp"].get_local_rank() + get_torch_device().manual_seed(gen_dp_rank + 1000) + self.gen_random_states = get_torch_device().get_rng_state() + + @GPUMemoryLogger(role="vllm sharding_manager", logger=logger) + def __enter__(self): + get_torch_device().set_rng_state(self.gen_random_states) + + @GPUMemoryLogger(role="vllm sharding_manager", logger=logger) + def __exit__(self, exc_type, exc_value, traceback): + self.gen_random_states = get_torch_device().get_rng_state() + self.inference_engine.reset_prefix_cache() + + @GPUMemoryLogger(role="vllm sharding_manager", logger=logger) + def preprocess_data(self, data: DataProto) -> DataProto: + """All gather across tp group to make each rank has identical input.""" + if self.tp_size == 1: + return data + + group = vllm_ps.get_tensor_model_parallel_group().device_group + + all_gather_data_proto(data=data, process_group=group) + return data + + @GPUMemoryLogger(role="vllm sharding_manager", logger=logger) + def postprocess_data(self, data: DataProto) -> DataProto: + """Get chunk data of this tp rank since we do all gather in preprocess.""" + if self.tp_size == 1: + return data + + return data.chunk(chunks=self.tp_size)[self.tp_rank] diff --git a/tests/special_e2e/run_one_step_off_policy.sh b/tests/special_e2e/run_one_step_off_policy.sh new file mode 100755 index 
00000000000..84fdd1d8113 --- /dev/null +++ b/tests/special_e2e/run_one_step_off_policy.sh @@ -0,0 +1,173 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# Test script for one_step_off_policy E2E regression testing +# This script runs one_step_off_policy with both FSDP2 and Megatron backends +# to ensure the asynchronous training mechanism works correctly + +NUM_GPUS=${NUM_GPUS:-8} +ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp2"} # fsdp2 or megatron + +# Download model if not exists +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=True +overlong_buffer_len=128 +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" +train_prompt_bsz=8 +n_resp_per_prompt=3 +train_prompt_mini_bsz=4 + +# Temperature parameters +temperature=1.0 +top_p=1.0 +top_k=-1 +val_top_p=0.7 + +# One-step-off-policy specific parameters +# Allocate 2 GPUs for rollout, remaining for training +n_gpus_rollout=2 +n_gpus_training=$((NUM_GPUS - n_gpus_rollout)) + +exp_name="$(basename "${MODEL_ID,,}")-one-step-off-policy-${ACTOR_STRATEGY}-minimal" + +echo "Running one_step_off_policy with ${ACTOR_STRATEGY} strategy" +echo "Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}" + +# Common parameters for both FSDP2 and Megatron +common_params=( + data.train_files="${HOME}/data/gsm8k/train.parquet" + data.val_files="${HOME}/data/gsm8k/test.parquet" + data.prompt_key=prompt + data.truncation='left' + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.train_batch_size=${train_prompt_bsz} + actor_rollout_ref.rollout.n=${n_resp_per_prompt} + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.model.path="${MODEL_PATH}" + actor_rollout_ref.model.enable_gradient_checkpointing=True + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=-1 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 + actor_rollout_ref.rollout.temperature=${temperature} + actor_rollout_ref.rollout.top_p=${top_p} + actor_rollout_ref.rollout.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + reward_model.reward_manager=dapo + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + 
+    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}
+    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}
+    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False
+    +reward_model.reward_kwargs.max_resp_len=${max_response_length}
+    trainer.logger=['console']
+    trainer.project_name='verl-test'
+    trainer.experiment_name="${exp_name}"
+    trainer.val_before_train=False
+    trainer.test_freq=-1
+    trainer.save_freq=-1
+    trainer.total_epochs=2
+    trainer.total_training_steps=2
+    trainer.resume_mode=disable
+    trainer.nnodes=1
+    trainer.n_gpus_per_node=${n_gpus_training}
+    rollout.nnodes=1
+    rollout.n_gpus_per_node=${n_gpus_rollout}
+
+)
+
+if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then
+    echo "Running with FSDP2 strategy..."
+    # FSDP2 specific parameters
+    gen_tp=2
+    sp_size=2
+    fsdp_size=2
+    ref_offload=True
+    actor_offload=False
+
+    python3 -m recipe.one_step_off_policy.main_ppo \
+        "${common_params[@]}" \
+        actor_rollout_ref.actor.strategy=fsdp2 \
+        critic.strategy=fsdp2 \
+        actor_rollout_ref.actor.grad_clip=1.0 \
+        actor_rollout_ref.model.use_remove_padding=True \
+        actor_rollout_ref.actor.use_dynamic_bsz=True \
+        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
+        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
+        actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
+        actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
+        actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+        actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
+        actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+        actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} "$@"
+
+elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then
+    echo "Running with Megatron strategy..."
+    # Megatron specific parameters
+    gen_tp=2
+    train_tp=1
+    train_pp=2
+    ref_offload=True
+    actor_offload=False
+
+    python3 -m recipe.one_step_off_policy.main_ppo \
+        --config-path=config \
+        --config-name='one_step_off_ppo_megatron_trainer.yaml' \
+        "${common_params[@]}" \
+        actor_rollout_ref.actor.strategy=megatron \
+        critic.strategy=megatron \
+        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
+        actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
+        actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
+        actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
+        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
+        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
+        actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+        actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
+        actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
+        actor_rollout_ref.ref.megatron.param_offload=${ref_offload} "$@"
+else
+    echo "Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'"
+    exit 1
+fi
+
+echo "One-step-off-policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy"
\ No newline at end of file
diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py
index c8988db55a5..cb5c05a4ee1 100644
--- a/tests/special_sanity/check_device_api_usage.py
+++ b/tests/special_sanity/check_device_api_usage.py
@@ -28,6 +28,7 @@
     "recipe/prime/prime_ray_trainer.py",  # appear in default device_name
     "recipe/spin/spin_trainer.py",  # appear in default device_name
     "recipe/sppo/sppo_ray_trainer.py",  # appear in default device_name
+    "recipe/one_step_off_policy/ray_trainer.py",  # appear in default device_name
     "verl/utils/profiler/nvtx_profile.py",  # appear in NsightSystemsProfiler
     "verl/utils/kernel/linear_cross_entropy.py",  # appear in nvidia nvtx
     "verl/utils/rendezvous/ray_backend.py",  # appear in cupy importance
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index c94edfd72d9..475b0a51783 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -213,7 +213,7 @@ def megatron_actor_model_provider(pre_process, post_process):
                 override_ddp_config=override_ddp_config,
             )
 
-            if self._is_actor and self._is_rollout:
+            if self._is_actor or self._is_rollout:
                 actor_module = make_model(wrap_with_ddp=True)
                 print(f"actor_module: {len(actor_module)}")
                 if self.config.actor.load_weight:
@@ -402,7 +402,7 @@ def init_model(self):
                 self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True
             )
         else:
-            override_transformer_config = None
+            override_transformer_config = {}
         self.param_dtype = torch.bfloat16
         log_gpu_memory_usage("Before init actor model and optimizer", logger=logger)
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
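The two `verl/workers/megatron_workers.py` changes are small but easy to miss. Because the recipe disables the hybrid engine and runs rollout in its own worker group, a Megatron worker may hold only the actor role or only the rollout role, so the model must be built when either role is present, hence `and` becomes `or`. Defaulting `override_transformer_config` to `{}` instead of `None` keeps downstream code that merges the override into a transformer config from failing. A minimal sketch of that second point, using a hypothetical `merge_overrides` helper rather than verl's actual call site:

```python
# Minimal sketch of why the `{}` default is safer than `None`. The helper below
# is hypothetical and only mirrors the shape of the downstream usage, where the
# override dict is merged into a transformer config.
def merge_overrides(base: dict, override_transformer_config: dict) -> dict:
    merged = dict(base)
    merged.update(override_transformer_config)  # raises TypeError if given None
    return merged

print(merge_overrides({"num_layers": 24}, {}))
# -> {'num_layers': 24}

try:
    merge_overrides({"num_layers": 24}, None)
except TypeError as exc:
    print(f"old None default would fail downstream: {exc}")
```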