Merged
Changes from 1 commit
68 commits
391a1fc
one step off async training recipe
imh966 Jun 27, 2025
338c2a9
simplify trainer
imh966 Jun 30, 2025
071ddc2
fix resource pool config and simplify the trainer yaml file
imh966 Jul 1, 2025
71569f5
separate actor and rollout class
imh966 Jul 3, 2025
78ef6f2
update name of recipe and add license
imh966 Jul 3, 2025
e274747
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 7, 2025
b9a9618
one_step_off_policy megatron
ArronHZG Jul 7, 2025
8dc0034
use fsdp2 and clear useless code
lalala-2 Jul 7, 2025
5ea1c00
fix config
lalala-2 Jul 7, 2025
69d58c4
fix
lalala-2 Jul 7, 2025
6cdaf2e
one_step_off_policy dapo_7b 2 node
ArronHZG Jul 7, 2025
36ed4f6
recipe/one_step_off_policy
ArronHZG Jul 8, 2025
a1966ef
opt gen_next_batch
lalala-2 Jul 8, 2025
40df88f
Merge branch 'recipe/async_training_megatron' of https://github.com/i…
lalala-2 Jul 8, 2025
5d52efa
4_12_megatron
ArronHZG Jul 8, 2025
59f6be9
4_12_megatron
ArronHZG Jul 8, 2025
dfabe15
megatron config
ArronHZG Jul 8, 2025
40e8816
megatron config
ArronHZG Jul 8, 2025
fc76d4f
fix megatron
lalala-2 Jul 8, 2025
dedc436
Merge branch 'recipe/async_training_megatron' of https://github.com/i…
lalala-2 Jul 8, 2025
344581f
megatron config
ArronHZG Jul 8, 2025
0091f52
megatron config
ArronHZG Jul 9, 2025
283f7fd
megatron config
ArronHZG Jul 9, 2025
6871a29
cross epoch
ArronHZG Jul 9, 2025
1b96322
ruff format
ArronHZG Jul 9, 2025
652f91f
# Copyright 2025 Meituan Ltd. and/or its affiliates
ArronHZG Jul 9, 2025
b36918c
add Copyright
ArronHZG Jul 9, 2025
84b712d
optim sh
ArronHZG Jul 9, 2025
4685463
python3
ArronHZG Jul 9, 2025
7f3d1db
update recipe
ArronHZG Jul 10, 2025
592f393
add doc
ArronHZG Jul 10, 2025
2fb1cd9
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 10, 2025
dff8f56
update date
ArronHZG Jul 11, 2025
648cb44
update date
ArronHZG Jul 11, 2025
c2395f7
config
ArronHZG Jul 11, 2025
165c1b2
Revert "fix config"
lalala-2 Jul 11, 2025
aaa356e
fix error
lalala-2 Jul 11, 2025
03f1dec
update is_last_step
ArronHZG Jul 11, 2025
e2007ef
one_step_off_policy
ArronHZG Jul 11, 2025
204d624
update readme
ArronHZG Jul 14, 2025
19fac39
e2e_one_step_off_policy
ArronHZG Jul 14, 2025
c1b86ec
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
492ff98
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
1e7aa47
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
8ab0834
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
22dc212
format
ArronHZG Jul 14, 2025
dcbfb0c
ruff check
ArronHZG Jul 14, 2025
1e8cee3
add megatron test
ArronHZG Jul 14, 2025
27c9816
Merge pull request #2 from imh966/recipe/async_training_e2e_test
ArronHZG Jul 14, 2025
727320b
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 14, 2025
8727916
rm spmd
ArronHZG Jul 14, 2025
42ddeed
CI check fix some error
ArronHZG Jul 14, 2025
5ffd8b4
merge main
ArronHZG Jul 15, 2025
1c9b6eb
change author
ArronHZG Jul 15, 2025
8772b14
update e2e_one_step_off_policy CI rule
ArronHZG Jul 15, 2025
c8468e6
update comments
ArronHZG Jul 15, 2025
d8dd8b0
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 15, 2025
659b108
update ruff
ArronHZG Jul 15, 2025
9b5646a
Fix pre-commit error: sort imports in async_main_ppo.py
openhands-agent Jul 15, 2025
1ed49c7
rollout.nnodes
ArronHZG Jul 16, 2025
754cfae
update code and doc by comments
ArronHZG Jul 16, 2025
8df1c1b
ruff
ArronHZG Jul 16, 2025
1837fc7
update code and doc by comments
ArronHZG Jul 16, 2025
c56467f
update docs
ArronHZG Jul 16, 2025
174d94a
Merge branch 'recipe/async_training' of https://github.com/imh966/ver…
ArronHZG Jul 16, 2025
e3db358
Merge branch 'recipe/async_training' into recipe/async_training_rollo…
ArronHZG Jul 16, 2025
8e5b714
Merge pull request #3 from imh966/recipe/async_training_rollout_nodes
ArronHZG Jul 16, 2025
40b2ebe
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 16, 2025
update name of recipe and add license
imh966 committed Jul 3, 2025
commit 78ef6f2aa40a7ac2b445266054905e7233a86527
@@ -1,4 +1,5 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,5 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,6 +1,7 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,7 +41,15 @@
compute_throughout_metrics,
compute_timing_metrics,
)
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, apply_kl_penalty, compute_advantage, compute_response_mask
from verl.trainer.ppo.ray_trainer import (
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.debug import marked_timer
from verl.utils.metric import (
@@ -137,20 +146,32 @@ def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
n_gpus_actor = n_gpus if config.actor_rollout_ref.hybrid_engine else n_gpus - config.actor_rollout_ref.rollout.n_gpus
n_gpus_actor = (
n_gpus if config.actor_rollout_ref.hybrid_engine else n_gpus - config.actor_rollout_ref.rollout.n_gpus
)
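# (Illustrative sketch, not part of this commit: with trainer.nnodes=2 and
#  trainer.n_gpus_per_node=8, n_gpus is 16; if hybrid_engine is disabled and
#  rollout.n_gpus=8, the rollout replicas keep 8 GPUs and n_gpus_actor = 16 - 8 = 8
#  remain for training.)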
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
assert n_gpus_actor % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0, (
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus_actor % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size)
== 0
), (
f"n_gpus_actor ({n_gpus_actor}) must be divisible by model_parallel_size ({model_parallel_size}) times context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus_actor // (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size)
megatron_dp = n_gpus_actor // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus_actor

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size ({minimal_bsz})"
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size ({minimal_bsz})"
)
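# (Illustrative sketch, not part of this commit, continuing the numbers above:
#  n_gpus_actor=8 with megatron tensor_model_parallel_size=2,
#  pipeline_model_parallel_size=2, and context_parallel_size=1 gives
#  model_parallel_size=4 and megatron_dp = 8 // 4 = 2; with
#  ppo_micro_batch_size_per_gpu=2, minimal_bsz = 2 * 2 = 4. Then
#  data.train_batch_size=32 and rollout.n=4 give real_train_batch_size=128,
#  and 128 % 4 == 0, so the assertion passes.)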

# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
@@ -168,10 +189,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
param_per_gpu = f"{param}_per_gpu"

if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
)

if mbs is not None and mbs_per_gpu is not None:
raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'" + "is supported (the former is deprecated).")
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'"
+ "is supported (the former is deprecated)."
)

if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
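# (Illustrative sketch, not part of this commit: per the checks above,
#  check_mutually_exclusive(None, 2, "actor_rollout_ref.actor") passes because
#  exactly one of the pair is set; check_mutually_exclusive(4, 2, ...) raises,
#  since the deprecated non-per-GPU form must be removed; and
#  check_mutually_exclusive(None, None, ...) raises, since at least one is required.)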
@@ -198,11 +224,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):

if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic")
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)

# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model")
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)

# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
@@ -213,7 +243,11 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size
% config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus_actor

assert config.actor_rollout_ref.actor.loss_agg_mode in [
@@ -235,24 +269,44 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus_actor

# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy == "fsdp" and (config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1):
assert config.actor_rollout_ref.model.use_remove_padding, "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
if config.actor_rollout_ref.actor.strategy == "fsdp" and (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)

if self.use_critic and config.critic.strategy == "fsdp":
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, "When using sequence parallelism for critic, you must enable `use_remove_padding`."
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)

if config.data.get("val_batch_size", None) is not None:
print("WARNING: val_batch_size is deprecated." + " Validation datasets are sent to inference engines as a whole batch," + " which will schedule the memory themselves.")
print(
"WARNING: val_batch_size is deprecated."
+ " Validation datasets are sent to inference engines as a whole batch,"
+ " which will schedule the memory themselves."
)

# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample"
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)

# check multi_turn with tool config
if config.actor_rollout_ref.rollout.multi_turn.enable:
assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None, "tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, due to no role-playing support"
assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool"
assert (
config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None
or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None
), (
"tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, due to no role-playing support"
)
assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], (
"only GRPO is tested for multi-turn with tool"
)

print("[validate_config] All configuration checks passed successfully!")

@@ -292,7 +346,9 @@ def init_workers(self):
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref")
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref"
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

# create a reward model if reward_fn is None
@@ -313,12 +369,21 @@ def init_workers(self):
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
if OmegaConf.select(self.config.trainer, "profile_steps") is not None:
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps")
assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, "worker_nsight_options must be set when profile_steps is set"
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(OmegaConf.select(self.config.trainer, "worker_nsight_options"))
assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, (
"worker_nsight_options must be set when profile_steps is set"
)
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
OmegaConf.select(self.config.trainer, "worker_nsight_options")
)

for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
device_name=self.device_name,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)

@@ -344,7 +409,13 @@ def init_workers(self):
from ray.util.collective import collective

actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
collective.create_collective_group(actor_rollout_workers, len(actor_rollout_workers), list(range(0, len(actor_rollout_workers))), backend="nccl", group_name="actor_rollout")
collective.create_collective_group(
actor_rollout_workers,
len(actor_rollout_workers),
list(range(0, len(actor_rollout_workers))),
backend="nccl",
group_name="actor_rollout",
)
self.sync_rollout_weights()
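# (Illustrative note, not part of this commit: the "actor_rollout" NCCL group spans
#  both actor and rollout workers so sync_rollout_weights can push updated actor
#  parameters to the rollout replicas; assuming ray.util.collective semantics, that
#  amounts to collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
#  for each parameter tensor.)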

# create async rollout manager and request scheduler
@@ -453,7 +524,11 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):

gen_next_batch(iterator, self.sync_rollout_weights)
while batch_data is not None:
do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
do_profile = (
self.global_steps in self.config.trainer.profile_steps
if self.config.trainer.profile_steps is not None
else False
)
if do_profile:
self.actor_wg.start_profile()
if not self.hybrid_engine:
@@ -472,7 +547,9 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):
with marked_timer("step", timing_raw):
# wait for the previous batch and generate next batch
with marked_timer("gen", timing_raw, color="red"):
batch, gen_batch_output, gen_batch = gen_next_batch(iterator, self.sync_rollout_weights, marked_timer("sync_rollout_weights", timing_raw))
batch, gen_batch_output, gen_batch = gen_next_batch(
iterator, self.sync_rollout_weights, marked_timer("sync_rollout_weights", timing_raw)
)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
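# (Illustrative note, not part of this commit: this is the one-step-off pipeline in
#  action: the batch consumed here was generated during the previous training step,
#  and gen_next_batch has already synced weights and queued generation for the next
#  step, so rollout for step t+1 overlaps training of step t.)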

@@ -492,7 +569,9 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):

del gen_baseline_batch, gen_baseline_output

batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
@@ -583,14 +662,18 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):

# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

# compute advantages, executed on the driver process

norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor
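# (Illustrative note, standard GRPO convention rather than anything verified in
#  this commit: with norm_adv_by_std_in_grpo=True, the advantage of response i
#  within its prompt group is (r_i - mean(r_group)) / (std(r_group) + eps); with
#  False, only the group mean is subtracted.)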

batch = compute_advantage(
batch,
Expand Down Expand Up @@ -636,14 +719,20 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):
)

# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)

if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()

@@ -3,7 +3,7 @@ set -x
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m recipe.async.async_main_ppo \
python3 -m recipe.one_step_off_policy.async_main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
@@ -1,3 +1,18 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
