Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
391a1fc
one step off async training recipe
imh966 Jun 27, 2025
338c2a9
simplify trainer
imh966 Jun 30, 2025
071ddc2
fix resource pool config and simplify the trainer yaml file
imh966 Jul 1, 2025
71569f5
separate actor and rollout class
imh966 Jul 3, 2025
78ef6f2
update name of recipe and add license
imh966 Jul 3, 2025
e274747
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 7, 2025
b9a9618
one_step_off_policy megatron
ArronHZG Jul 7, 2025
8dc0034
use fsdp2 and clear useless code
lalala-2 Jul 7, 2025
5ea1c00
fix config
lalala-2 Jul 7, 2025
69d58c4
fix
lalala-2 Jul 7, 2025
6cdaf2e
one_step_off_policy dapo_7b 2 node
ArronHZG Jul 7, 2025
36ed4f6
recipe/one_step_off_policy
ArronHZG Jul 8, 2025
a1966ef
opt gen_next_batch
lalala-2 Jul 8, 2025
40df88f
Merge branch 'recipe/async_training_megatron' of https://github.com/i…
lalala-2 Jul 8, 2025
5d52efa
4_12_megatron
ArronHZG Jul 8, 2025
59f6be9
4_12_megatron
ArronHZG Jul 8, 2025
dfabe15
megatron config
ArronHZG Jul 8, 2025
40e8816
megatron config
ArronHZG Jul 8, 2025
fc76d4f
fix megatron
lalala-2 Jul 8, 2025
dedc436
Merge branch 'recipe/async_training_megatron' of https://github.com/i…
lalala-2 Jul 8, 2025
344581f
megatron config
ArronHZG Jul 8, 2025
0091f52
megatron config
ArronHZG Jul 9, 2025
283f7fd
megatron config
ArronHZG Jul 9, 2025
6871a29
cross epoch
ArronHZG Jul 9, 2025
1b96322
ruff format
ArronHZG Jul 9, 2025
652f91f
# Copyright 2025 Meituan Ltd. and/or its affiliates
ArronHZG Jul 9, 2025
b36918c
add Copyright
ArronHZG Jul 9, 2025
84b712d
optim sh
ArronHZG Jul 9, 2025
4685463
python3
ArronHZG Jul 9, 2025
7f3d1db
update recipe
ArronHZG Jul 10, 2025
592f393
add doc
ArronHZG Jul 10, 2025
2fb1cd9
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 10, 2025
dff8f56
update date
ArronHZG Jul 11, 2025
648cb44
update date
ArronHZG Jul 11, 2025
c2395f7
config
ArronHZG Jul 11, 2025
165c1b2
Revert "fix config"
lalala-2 Jul 11, 2025
aaa356e
fix error
lalala-2 Jul 11, 2025
03f1dec
update is_last_step
ArronHZG Jul 11, 2025
e2007ef
one_step_off_policy
ArronHZG Jul 11, 2025
204d624
update readme
ArronHZG Jul 14, 2025
19fac39
e2e_one_step_off_policy
ArronHZG Jul 14, 2025
c1b86ec
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
492ff98
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
1e7aa47
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
8ab0834
add e2e test for one_step_off_policy
ArronHZG Jul 14, 2025
22dc212
format
ArronHZG Jul 14, 2025
dcbfb0c
ruff check
ArronHZG Jul 14, 2025
1e8cee3
add megatron test
ArronHZG Jul 14, 2025
27c9816
Merge pull request #2 from imh966/recipe/async_training_e2e_test
ArronHZG Jul 14, 2025
727320b
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 14, 2025
8727916
rm spmd
ArronHZG Jul 14, 2025
42ddeed
CI check fix some error
ArronHZG Jul 14, 2025
5ffd8b4
merge main
ArronHZG Jul 15, 2025
1c9b6eb
change author
ArronHZG Jul 15, 2025
8772b14
update e2e_one_step_off_policy CI rule
ArronHZG Jul 15, 2025
c8468e6
update comments
ArronHZG Jul 15, 2025
d8dd8b0
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 15, 2025
659b108
update ruff
ArronHZG Jul 15, 2025
9b5646a
Fix pre-commit error: sort imports in async_main_ppo.py
openhands-agent Jul 15, 2025
1ed49c7
rollout.nnodes
ArronHZG Jul 16, 2025
754cfae
update code and doc by comments
ArronHZG Jul 16, 2025
8df1c1b
ruff
ArronHZG Jul 16, 2025
1837fc7
update code and doc by comments
ArronHZG Jul 16, 2025
c56467f
update docs
ArronHZG Jul 16, 2025
174d94a
Merge branch 'recipe/async_training' of https://github.com/imh966/ver…
ArronHZG Jul 16, 2025
e3db358
Merge branch 'recipe/async_training' into recipe/async_training_rollo…
ArronHZG Jul 16, 2025
8e5b714
Merge pull request #3 from imh966/recipe/async_training_rollout_nodes
ArronHZG Jul 16, 2025
40b2ebe
Merge branch 'volcengine:main' into recipe/async_training
ArronHZG Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
opt gen_next_batch
  • Loading branch information
lalala-2 committed Jul 8, 2025
commit a1966ef4eff278521a65ff2648fe8bb366e6b7d5
90 changes: 49 additions & 41 deletions recipe/one_step_off_policy/async_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,31 @@
from verl.utils.tracking import ValidationGenerationsLogger


class GenerationBatchFuture:
"""
Wrapper class for encapsulating batch generation results
"""
def __init__(self, batch, gen_batch_output):
self.batch = batch # Input batch data
self.gen_batch_output = gen_batch_output # Generated sequences from the main model (DataProtoFuture)

def get(self):
"""
Get the actual results by calling get() method on gen_batch_output

Returns:
tuple: (batch, gen_batch_result)
- batch: Original input batch data
- gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself
"""
# Call get() method on gen_batch_output if available
if hasattr(self.gen_batch_output, 'get'):
gen_batch_result = self.gen_batch_output.get()
else:
gen_batch_result = self.gen_batch_output

return self.batch, gen_batch_result

class AsyncRayPPOTrainer(RayPPOTrainer):
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
Expand Down Expand Up @@ -121,7 +146,7 @@ def __init__(
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
AdvantageEstimator.REMAX,
# AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
Expand Down Expand Up @@ -404,20 +429,14 @@ def fit(self):
last_val_metrics = None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we try to avoid using nested function definitions? For instance, move this to
def _create_continuous_iterator(self) and def _async_gen_next_batch(self, continuous_iterator)

for epoch in range(self.config.trainer.total_epochs):
batch_data = None
batch_data_future = None
iterator = iter(self.train_dataloader)

def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):
nonlocal batch_data
ret = (None, None, None)
# waiting for the output of previous rollout step
if batch_data is not None:
ret = (batch_data[0], batch_data[1].get(), batch_data[2])

def asys_gen_next_batch(iterator):
try:
batch_dict = next(iterator)
except StopIteration:
return ret
return None

batch = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
Expand All @@ -436,22 +455,23 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)

if weight_sync_ctx is None:
from contextlib import nullcontext

weight_sync_ctx = nullcontext()

with weight_sync_ctx:
weight_sync_func()
# sync weights from actor to rollout
self.sync_rollout_weights()

# async generation
gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
batch_data = (batch, gen_batch_output, gen_batch)
return ret

gen_next_batch(iterator, self.sync_rollout_weights)
while batch_data is not None:
do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
return GenerationBatchFuture(batch, gen_batch_output)

# first call asys_gen_next_batch before train
batch_data_future = asys_gen_next_batch(iterator)

while batch_data_future is not None:
do_profile = (
self.global_steps in self.config.trainer.profile_steps
if self.config.trainer.profile_steps is not None
else False
)
if do_profile:
self.actor_wg.start_profile()
if not self.hybrid_engine:
Expand All @@ -468,27 +488,15 @@ def gen_next_batch(iterator, weight_sync_func, weight_sync_ctx=None):
is_last_step = self.global_steps >= self.total_training_steps

with marked_timer("step", timing_raw):
# wait for the previous batch and generate next batch
with marked_timer("gen", timing_raw, color="red"):
batch, gen_batch_output, gen_batch = gen_next_batch(iterator, self.sync_rollout_weights, marked_timer("sync_rollout_weights", timing_raw))
# wait for the previous batch
with marked_timer("wait_prev_gen", timing_raw, color="red"):
batch, gen_batch_output = batch_data_future.get()
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.rollout_wg.generate_sequences(gen_baseline_batch)

batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

batch.batch["reward_baselines"] = reward_baseline_tensor

del gen_baseline_batch, gen_baseline_output

# asys next generation (with syns weights from actor to rollout)
with marked_timer("sync_rollout_weights", timing_raw, color="purple"):
batch_data_future = asys_gen_next_batch(iterator)

batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
Expand Down