
Conversation

Contributor

@imh966 imh966 commented Jun 27, 2025

What does this PR do?

This PR provides a simple implementation of one-step-off async training with the FSDP and vLLM backends.

We conducted three experiments with the qwen2.5_3b model on 8 A100 GPUs:

  1. baseline: all models are colocated
  2. standalone rollout: the rollout model runs on 4 GPUs and the other models run on the remaining 4 GPUs
  3. one step off: the same model placement as the second experiment, but with one-step-off async training

The pictures below demonstrate the results of these experiments:

In these experiments, the baseline has the highest throughput, but we believe this is only because we have not yet found the best configuration for one-step-off async training.

The exciting point is that our NCCL-based weight updates for the rollout model perform very well. The latency is shown below:

Most of the time, the latency is under 300 ms, which is negligible for RLHF. Although it is only implemented for FSDP and vLLM now, we think it would not be complex to extend to other backends.
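
For readers curious what this update path can look like, here is a minimal sketch of broadcasting updated actor weights to standalone rollout workers over a process group spanning both clusters. The group and helper names are illustrative assumptions, not the actual recipe code:

```python
import torch
import torch.distributed as dist

def sync_rollout_weights(state_dict: dict, sync_group, src_rank: int = 0):
    """Broadcast weights from the actor to rollout workers (sketch).

    Every rank in sync_group calls this with same-shaped tensors in the
    same order: on the actor side they hold fresh weights (the source),
    on the rollout side they are destination buffers that are afterwards
    loaded into the inference engine. With NCCL the transfer stays
    GPU-to-GPU, which is why the latency can be a few hundred milliseconds.
    """
    for _name, tensor in state_dict.items():
        dist.broadcast(tensor, src=src_rank, group=sync_group)
```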

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

To use this feature, the hybrid_engine option must be disabled to separate the actor model and the rollout model onto different GPU clusters. A rollout.n_gpus option has been added to the config file to indicate how many GPUs the rollout model occupies. The script below is an example of training qwen2.5_3b with 8 GPUs.

```shell
python3 -m recipe.async.async_main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.shuffle=False \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=3e-6 \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.n_gpus=4 \
    actor_rollout_ref.rollout.load_format=safetensors \
    actor_rollout_ref.rollout.layered_summon=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.val_before_train=True \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2.5_3b_grpo_async_one_step_off' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=-1 \
    trainer.total_epochs=15 $@
```

High-Level Design

Demonstrate the high-level design if this PR is complex.

Specific Changes

  1. NCCL-based weight updates for the rollout model.
  2. One-step-off async trainer.
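
To make the trainer side concrete, here is a rough sketch of the one-step-off control flow with Ray-style async handles. The worker-group and method names mirror snippets quoted later in this thread, but the loop itself is a simplified assumption, not the actual trainer:

```python
import ray

# Simplified sketch: generation for batch i+1 overlaps with training on
# batch i, so the rollout GPUs never idle while the actor trains. Rollouts
# are therefore produced by a policy that is one optimizer step stale.

def fit_one_step_off(trainer, dataloader):
    batch_iter = iter(dataloader)

    # Push initial weights, then kick off generation for the first batch.
    trainer.actor_wg.sync_rollout_weights()
    ray.get(trainer.rollout_wg.sync_rollout_weights())
    gen_future = trainer.rollout_wg.generate_sequences(next(batch_iter))

    for prompts in batch_iter:
        samples = ray.get(gen_future)  # finished rollouts for batch i
        # Sync weights (updated through batch i-1) and start batch i+1 ...
        trainer.actor_wg.sync_rollout_weights()
        ray.get(trainer.rollout_wg.sync_rollout_weights())
        gen_future = trainer.rollout_wg.generate_sequences(prompts)
        # ... while training on batch i proceeds concurrently.
        trainer.train_on_batch(samples)  # hypothetical helper
```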

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Collaborator

@PeterSH6 PeterSH6 left a comment

Nice work

Collaborator

Can you simplify the code in this file? There's too much redundancy

Contributor Author

> Can you simplify the code in this file? There's too much redundancy

OK, I will try.

Contributor Author

@PeterSH6 I've removed some redundant code, but I'm not sure whether it's enough.

@imh966 imh966 force-pushed the recipe/async_training branch from 7a6c847 to 071ddc2 on July 1, 2025 09:21
Collaborator

@ccclyu ccclyu left a comment

Great work! Have you tried testing on multiple nodes and observed any throughput delta?

Collaborator

@eric-haibin-lin eric-haibin-lin left a comment

Thanks for the contribution! Please add a README.md describing the scope of this recipe, for instance indicating the support status of features available in the original ray trainer, such as vlm/multi-turn. Please also copy the usage documentation to docs/advance/, and include the convergence curves in these docs.


# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
    assert config.critic.strategy in ["fsdp", "fsdp2"]
Collaborator

Since we're deprecating fsdp, we can limit this recipe to fsdp2 only, and make sure it is tested with fsdp2.

if not self.hybrid_engine:
    self.actor_wg.sync_rollout_weights()
    ray.get(self.rollout_wg.sync_rollout_weights())
    # param_ref = self.actor_wg.sync_rollout_weights_v2(None)
Collaborator

Please remove the unused code.

@eric-haibin-lin eric-haibin-lin self-assigned this Jul 4, 2025
@CLAassistant

CLAassistant commented Jul 7, 2025

CLA assistant check
All committers have signed the CLA.

@lzxdjb

lzxdjb commented Jul 8, 2025

Hello! Thank you so much for implementing an asynchronous RLHF framework. I have a couple of questions:
1. Have you compared the time consumption of your implementation against standard verl?
2. I notice that the weights are synchronized from the actor before every generation, which seems a little redundant. The parameters really only need to be updated once the actor is updated. (Although this may mix in some stale batches, AReaL showed that a decoupled PPO algorithm can largely preserve training quality.)
Looking forward to your reply~
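
For context on the decoupled PPO idea mentioned above, here is a rough per-token loss sketch of the formulation as we understand it from AReaL; the tensor names and exact clipping scheme are illustrative assumptions, not verl code:

```python
import torch

def decoupled_ppo_loss(logp_new, logp_prox, logp_behave, advantages, eps=0.2):
    """Per-token decoupled PPO loss (illustrative sketch).

    Standard PPO clips pi_new / pi_old. The decoupled form instead clips
    against a recent "proximal" policy and separately reweights by
    pi_prox / pi_behave to correct for rollouts generated by a stale
    behavior policy.
    """
    # Staleness-correction weight; detached so it only rescales the loss.
    behave_weight = torch.exp(logp_prox - logp_behave).detach()
    # Clipped surrogate against the proximal policy, as in vanilla PPO.
    ratio = torch.exp(logp_new - logp_prox)
    surrogate = torch.minimum(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages,
    )
    return -(behave_weight * surrogate).mean()
```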

@BounharAbdelaziz
Contributor

@lzxdjb I agree. Mistral AI does the same without even recomputing the KV cache (see the Magistral paper: https://arxiv.org/pdf/2506.10910).

    if self.config.trainer.profile_steps is not None
    else False
)
if do_profile:
Collaborator

In the next PR, could you reuse the functions _start_profiling and _stop_profiling from the parent class? Thanks.
https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py#L1042-L1063

@eric-haibin-lin eric-haibin-lin merged commit 503ea75 into volcengine:main Jul 17, 2025
61 of 64 checks passed
@eric-haibin-lin
Collaborator

The snapshot of this recipe development branch is pushed to https://github.com/volcengine/verl/tree/recipe/one_step_off_async. Thanks team for the great work!

HelloWorld686 pushed a commit to HelloWorld686/verl that referenced this pull request Jul 17, 2025
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jul 25, 2025
oseyosey pushed a commit to oseyosey/verl that referenced this pull request Jul 28, 2025
Juniper1021 pushed a commit to Juniper1021/verl that referenced this pull request Aug 7, 2025
sguo35 pushed a commit to safety-research/verl that referenced this pull request Aug 11, 2025
whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025
vermouth1992 pushed a commit that referenced this pull request Oct 17, 2025
### What does this PR do?

To implement a purely asynchronous training workflow, we further split
the training process into a Trainer and a Rollouter based on the
existing one-step-off policy code, with samples transmitted via a
message queue.

We will continue to integrate partial rollout to mitigate the impact of
long-tail training.
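
As a mental model for this split, here is a tiny hypothetical sketch of a Rollouter producing sample batches into a queue that a Trainer consumes; the function names and the in-process queue are illustrative assumptions, not the code in this commit.

```python
import queue

# In-process stand-in for the message queue between Rollouter and Trainer;
# a real deployment would use a distributed queue (e.g. a Ray actor).
sample_queue: queue.Queue = queue.Queue(maxsize=4)  # bound limits staleness

def rollouter_loop(generate_batch, stop_event):
    # Rollouter: keep generating and enqueueing sample batches.
    while not stop_event.is_set():
        sample_queue.put(generate_batch())  # blocks when the queue is full

def trainer_loop(train_on_batch, num_steps):
    # Trainer: consume batches at its own pace, decoupled from generation.
    for _ in range(num_steps):
        train_on_batch(sample_queue.get())  # blocks until samples arrive
```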

Related: #2231, #2200

---------

Co-authored-by: meituan-search <[email protected]>
Co-authored-by: wangshulin02 <[email protected]>
Co-authored-by: arron <[email protected]>
Co-authored-by: wangshulin02 <[email protected]>
Co-authored-by: hadoop-ai-search <[email protected]>
Co-authored-by: sl-1314 <[email protected]>
Co-authored-by: arron <[email protected]>
Co-authored-by: arron <[email protected]>
masoudhashemi pushed a commit to masoudhashemi/verl that referenced this pull request Oct 19, 2025
### What does this PR do?

To implement a purely asynchronous training workflow, we further split
the training process into a Trainer and a Rollouter based on the
existing one-step-off policy code, with samples transmitted via a
message queue.

We will continue to integrate partial rollout to mitigate the impact of
long-tail training.

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review. volcengine#2231
volcengine#2200

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: meituan-search <[email protected]>
Co-authored-by: wangshulin02 <[email protected]>
Co-authored-by: arron <[email protected]>
Co-authored-by: wangshulin02 <[email protected]>
Co-authored-by: hadoop-ai-search <[email protected]>
Co-authored-by: sl-1314 <[email protected]>
Co-authored-by: arron <[email protected]>
Co-authored-by: arron <[email protected]>
wangboxiong320 pushed a commit to wangboxiong320/verl that referenced this pull request Nov 1, 2025
### What does this PR do?

To implement a purely asynchronous training workflow, we further split
the training process into a Trainer and a Rollouter based on the
existing one-step-off policy code, with samples transmitted via a
message queue.

We will continue to integrate partial rollout to mitigate the impact of
long-tail training.

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review. volcengine#2231
volcengine#2200

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: meituan-search <[email protected]>
Co-authored-by: wangshulin02 <[email protected]>
Co-authored-by: arron <[email protected]>
Co-authored-by: hadoop-ai-search <[email protected]>
Co-authored-by: sl-1314 <[email protected]>
NenoL2001 pushed a commit to NenoL2001/verl that referenced this pull request Nov 3, 2025
AlexJJ009 pushed a commit to AlexJJ009/verl that referenced this pull request Nov 5, 2025