Commit 72219f7

[sglang, megatron, perf] feat: speed up megatron sglang weight update by 10x (verl-project#2418)

### What does this PR do?

Optimize the performance of the sglang + megatron weight update, following the bucketing implementation of [`THUDM/slime`](https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452).

| model | bucket size MB | boost |
| ---- | ----- | ---- |
| Moonlight16B @ 8xH20 | 512MB | 175s -> 18s |
| DeepseekV3 671B @ 512xH20 | 512MB | ONGOING |

Related to issues verl-project#2419, sgl-project/sglang#6762, zhaochenyang20/Awesome-ML-SYS-Tutorial#169. Similar fix for FSDP: verl-project#2499.

> We are from the Large Model Post-Training Team of 📕 Xiaohongshu's AI Platform Technology Department, dedicated to developing high-performance, easily-scalable distributed post-training engines.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
1 parent cd0f039 commit 72219f7

File tree

6 files changed (+165, −11 lines)


tests/special_e2e/run_ppo_trainer_megatron.sh

Lines changed: 1 addition & 0 deletions

```diff
@@ -175,6 +175,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
     actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+    actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
     actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \
```
Lines changed: 57 additions & 0 deletions (new test file)

```python
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets

_TENSOR_1MB = torch.zeros(512, 512)  # 512 * 512 * 4 bytes (float32) = 1 MiB
_BYTES_1MB = 1 << 20


@pytest.mark.parametrize(
    "named_tensors, bucket_size_mb, gt_groups",
    [
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            0.5 * _BYTES_1MB,
            [["a"], ["b"]],
        ),
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            1 * _BYTES_1MB,
            [["a"], ["b"]],
        ),
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            1.5 * _BYTES_1MB,
            [["a"], ["b"]],
        ),
        (
            [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)],
            2 * _BYTES_1MB,
            [["a", "b"]],
        ),
    ],
)
def test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]):
    named_tensors_iter = iter(named_tensors)
    groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb))
    assert len(groups) == len(gt_groups)
    for group, gt_group in zip(groups, gt_groups, strict=True):
        assert len(group) == len(gt_group)
        for (name, _), gt_name in zip(group, gt_group, strict=True):
            assert name == gt_name
```

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -130,6 +130,7 @@ actor_rollout_ref:
       custom_async_server:
         path: null
         name: null
+    update_weights_bucket_megabytes: 2048
     trace:
       backend: null
       token2text: false
```

verl/trainer/config/rollout/rollout.yaml

Lines changed: 15 additions & 0 deletions

```diff
@@ -179,6 +179,21 @@ agent:
   # Class name of the custom async server class (e.g. AsyncvLLMServer)
   name: null

+# Specifies the tensor bucket size (in megabytes) for batch weight updates during rollout operations.
+# This parameter controls the maximum payload size for a single weight update request.
+#
+# https://github.com/volcengine/verl/pull/2281
+#
+# Note:
+# - Currently only supported in SGLang rollout implementations
+# - Larger values may improve throughput but increase memory overhead
+# - Default value (2GB) is optimized for typical GPU memory configurations
+# - For the best performance of `rebuild_cuda_tensor`, it is recommended to:
+#   1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.
+#   2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
+#      when using Tensor Parallelism (TP) >= 8.
+update_weights_bucket_megabytes: 2048
+
 # trace rollout data
 trace:
```
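As a unit sanity check: the sharding manager (see the megatron_sglang.py diff below in this commit) converts this setting to bytes with a 20-bit left shift, so "megabytes" here means MiB, and the default of 2048 corresponds to exactly 2 GiB. A minimal illustration:

```python
# Convert the config value to bytes the same way the sharding manager does:
# `megabytes << 20` is megabytes * 1024 * 1024 (MiB, not decimal MB).
update_weights_bucket_megabytes = 2048  # the default from rollout.yaml
update_weights_bucket_bytes = int(update_weights_bucket_megabytes) << 20

print(update_weights_bucket_bytes)           # 2147483648
print(update_weights_bucket_bytes == 2**31)  # True, i.e. 2 GiB
```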

verl/workers/rollout/sglang_rollout/utils.py

Lines changed: 41 additions & 1 deletion

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.

 import pickle
-from typing import Any, Optional
+from typing import Any, Iterator, Optional

 import numpy as np
 import torch
@@ -66,3 +66,43 @@ def broadcast_pyobj(
     serialized_data = bytes(tensor_data.cpu().numpy())
     data = pickle.loads(serialized_data)
     return data
+
+
+def get_named_tensor_buckets(
+    iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int
+) -> Iterator[list[tuple[str, torch.Tensor]]]:
+    """
+    Group named tensors into buckets no larger than a specified size in bytes.
+
+    Args:
+        iterable: An iterator of tuples containing tensor names and tensors.
+        bucket_bytes: The maximum size of each bucket in bytes.
+
+    Yields:
+        Lists of tuples, where each tuple contains a tensor name and its corresponding tensor.
+
+    Example:
+        >>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))]
+        >>> for bucket in get_named_tensor_buckets(tensors, bucket_bytes=10 * (1 << 20)):
+        ...     print(bucket)
+        [('tensor1', tensor(...))]
+        [('tensor2', tensor(...))]
+
+    """
+    if bucket_bytes <= 0:
+        raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}")
+
+    current_bucket = []
+    current_size = 0
+    for name, tensor in iterable:
+        tensor_size = tensor.element_size() * tensor.numel()
+        if current_size + tensor_size > bucket_bytes:
+            if current_bucket:
+                yield current_bucket
+            current_bucket = [(name, tensor)]
+            current_size = tensor_size
+        else:
+            current_bucket.append((name, tensor))
+            current_size += tensor_size
+
+    if current_bucket:
+        yield current_bucket
```
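To see the greedy bucketing behave end to end without a verl or torch install, here is a standalone sketch that mirrors the loop above; `FakeTensor` and `bucket_named_tensors` are illustrative stand-ins, not verl APIs:

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor, exposing only the two methods the bucketing loop calls."""

    n: int             # number of elements
    itemsize: int = 4  # bytes per element (float32)

    def element_size(self) -> int:
        return self.itemsize

    def numel(self) -> int:
        return self.n


def bucket_named_tensors(named_tensors, bucket_bytes):
    # Greedy bucketing, mirroring get_named_tensor_buckets in the diff above:
    # keep appending tensors until the next one would overflow the byte budget,
    # then yield the bucket and start a new one.
    current_bucket, current_size = [], 0
    for name, tensor in named_tensors:
        tensor_size = tensor.element_size() * tensor.numel()
        if current_size + tensor_size > bucket_bytes:
            if current_bucket:
                yield current_bucket
            current_bucket = [(name, tensor)]
            current_size = tensor_size
        else:
            current_bucket.append((name, tensor))
            current_size += tensor_size
    if current_bucket:
        yield current_bucket


# Three 1 MiB tensors with a 2 MiB budget: the first two share a bucket.
params = [(f"w{i}", FakeTensor(512 * 512)) for i in range(3)]
buckets = list(bucket_named_tensors(params, 2 * (1 << 20)))
print([[name for name, _ in b] for b in buckets])  # [['w0', 'w1'], ['w2']]
```

Note that a tensor larger than `bucket_bytes` still gets its own bucket: the check only splits before appending, so oversized tensors are sent alone rather than dropped.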

verl/workers/sharding_manager/megatron_sglang.py

Lines changed: 50 additions & 10 deletions

```diff
@@ -37,6 +37,7 @@
     per_tensor_generator,
 )
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
+from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets

 from .base import BaseShardingManager

@@ -130,37 +131,76 @@ def __exit__(self, exc_type, exc_value, traceback):
         loop.run_until_complete(self.sleep())

     async def update_weights(self, params):
+        """
+        Update model weights using tensor buckets, similar to THUDM/slime's implementation.
+
+        Notes:
+            - For the best performance of `rebuild_cuda_tensor`, it is recommended to:
+                1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.
+                2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
+                   when using Tensor Parallelism (TP >= 8).
+            - See reference implementations in SLIME:
+                - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452
+                - Runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39
+        """
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
             await self.inference_engine.resume_memory_occupation()
         named_tensors = params
         load_format = None
-        for tensor_index, (name, tensor) in enumerate(named_tensors):
-            serialized_tensor = MultiprocessingSerializer.serialize(tensor.detach())
+
+        update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20
+        for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes):
+            # On each rank, serialize a batch of (name, tensor) tuples.
+            # named_tensors_batch will be a list like:
+            # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...]
+            named_tensors_batch = [
+                (name, MultiprocessingSerializer.serialize(tensor.detach())) for name, tensor in batch
+            ]

             if self.device_mesh["tp"].get_local_rank() == 0:
-                gathered_serialized_tensors = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
+                # On rank 0, prepare a list to hold the gathered batches from all ranks.
+                gathered_serialized_batches = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
             else:
-                gathered_serialized_tensors = None
+                gathered_serialized_batches = None
+
+            # Gather the named_tensors_batch from all ranks to rank 0.
+            # After this, on rank 0, gathered_serialized_batches will be a list of lists:
+            # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ],  # batch from TP rank 0
+            #   [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ],  # batch from TP rank 1
+            #   ... ]
+            # On other ranks, gathered_serialized_batches will be None.
             dist.gather_object(
-                obj=serialized_tensor,
-                object_gather_list=gathered_serialized_tensors,
+                obj=named_tensors_batch,
+                object_gather_list=gathered_serialized_batches,
                 dst=self.device_mesh["tp"].mesh.tolist()[0],
                 group=self.device_mesh["tp"].get_group(),
             )

             if self.device_mesh["tp"].get_local_rank() == 0:
+                # Use zip(*) to "transpose" the data structure.
+                # This groups the serialized parts for each individual tensor across all TP ranks.
+                # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]]
+                # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ]
+                logical_tensors = zip(*gathered_serialized_batches, strict=False)
                 await self.inference_engine.update_weights_from_tensor(
                     named_tensors=[
+                        # 'tensor_group' represents a single logical tensor's data from all ranks.
                         (
-                            name,
-                            LocalSerializedTensor(values=gathered_serialized_tensors),
+                            tensor_group[0][0],  # Get the name from the first rank's data.
+                            LocalSerializedTensor(
+                                # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank.
+                                values=[rank_part[1] for rank_part in tensor_group]
+                            ),
                         )
+                        for tensor_group in logical_tensors
+                        # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) )
                     ],
                     load_format=load_format,
                     flush_cache=False,
                 )
-            if self.device_mesh["tp"].get_local_rank() == 0:
-                await self.inference_engine.flush_cache()
+
+        if self.device_mesh["tp"].get_local_rank() == 0:
+            await self.inference_engine.flush_cache()

     async def release_memory(self):
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
```
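The `zip(*...)` transpose that reassembles per-tensor shard groups out of per-rank batches can be illustrated with plain strings standing in for the serialized shards (the `rank0`/`rank1` names are illustrative, not verl identifiers):

```python
# Two TP ranks, each holding a serialized batch of the same two named tensors.
rank0 = [("n0", "t0_tp0"), ("n1", "t1_tp0")]
rank1 = [("n0", "t0_tp1"), ("n1", "t1_tp1")]
gathered = [rank0, rank1]  # what rank 0 holds after dist.gather_object

# zip(*...) transposes per-rank batches into per-tensor groups across ranks.
logical_tensors = list(zip(*gathered))
print(logical_tensors[0])  # (('n0', 't0_tp0'), ('n0', 't0_tp1'))

# Reassemble each logical tensor: take the name from the first rank's entry
# and collect every rank's shard, as the update_weights_from_tensor call does.
named = [(group[0][0], [part[1] for part in group]) for group in logical_tensors]
print(named)  # [('n0', ['t0_tp0', 't0_tp1']), ('n1', ['t1_tp0', 't1_tp1'])]
```

This is why batching the gather is safe: every rank serializes the same tensor names in the same order, so transposing the gathered lists lines up the TP shards of each logical tensor.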
