
Commit 4ee6721

SuperCBoseyosey authored and committed

[sglang] fix: Bug in megatron+sglang TP16 update_weights. (verl-project#2336)
### What does this PR do?

We observe the following when using Megatron + Sglang + TP16:

<img width="1236" alt="image" src="https://github.com/user-attachments/assets/875d83e6-325a-41c4-b778-81b457b508a1" />

After investigation, we found that this was caused by the **cudaipc** mechanism not supporting cross-machine access. We have fixed this bug.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
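The idea behind the fix: a CUDA IPC handle is only valid between processes on the same machine, so when a TP-16 group spans two nodes, handing raw CUDA tensors to the engine fails. The patch instead serializes each rank's weight shard to plain bytes and gathers those bytes to the TP group's first rank with `dist.gather_object`. A minimal single-process sketch of that flow, using `pickle` and an illustrative stand-in for the collective (these names are not verl's API):

```python
import pickle

def gather_object_sim(obj_per_rank, dst=0):
    """Stand-in for torch.distributed.gather_object over a TP group:
    every rank contributes one picklable object; only dst ends up
    with the full list."""
    return list(obj_per_rank)

tp_size = 4
# Fake per-rank weight shards (the real code serializes CUDA tensors).
shards = [[rank * 10 + i for i in range(3)] for rank in range(tp_size)]

# Step 1: each rank serializes its shard to bytes, so nothing
# machine-local (like a CUDA IPC handle) has to cross a node boundary.
serialized = [pickle.dumps(s) for s in shards]

# Step 2: gather all serialized shards to the TP group's first rank.
gathered_on_rank0 = gather_object_sim(serialized, dst=0)

# Step 3: the first rank deserializes and hands the full set to the engine.
restored = [pickle.loads(b) for b in gathered_on_rank0]
print(restored[1])  # [10, 11, 12]
```

In the actual patch, serialization is done by SGLang's `MultiprocessingSerializer`, and the gathered payloads are passed to `update_weights_from_tensor` wrapped in a `LocalSerializedTensor`.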
1 parent ae4f634 · commit 4ee6721

File tree

1 file changed (+22 -6 lines)


verl/workers/sharding_manager/megatron_sglang.py

Lines changed: 22 additions & 6 deletions
```diff
@@ -21,14 +21,21 @@
 import logging
 import os
 
+import torch.distributed as dist
 from omegaconf import DictConfig
 from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.utils import MultiprocessingSerializer
 from torch import nn
 from torch.distributed.device_mesh import DeviceMesh
 
 from verl.protocol import DataProto, all_gather_data_proto
 from verl.utils.device import get_torch_device
-from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator
+from verl.utils.megatron_utils import (
+    load_megatron_model_to_gpu,
+    offload_megatron_model_to_cpu,
+    per_tensor_generator,
+)
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
 
 from .base import BaseShardingManager
@@ -125,24 +132,33 @@ def __exit__(self, exc_type, exc_value, traceback):
     async def update_weights(self, params):
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
             await self.inference_engine.resume_memory_occupation()
-
-        # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
-        # named_tensors = [(k, v) for k, v in params.items()]
         named_tensors = params
         load_format = None
         for tensor_index, (name, tensor) in enumerate(named_tensors):
+            serialized_tensor = MultiprocessingSerializer.serialize(tensor.detach())
+
+            if self.device_mesh["tp"].get_local_rank() == 0:
+                gathered_serialized_tensors = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])]
+            else:
+                gathered_serialized_tensors = None
+            dist.gather_object(
+                obj=serialized_tensor,
+                object_gather_list=gathered_serialized_tensors,
+                dst=self.device_mesh["tp"].mesh.tolist()[0],
+                group=self.device_mesh["tp"].get_group(),
+            )
+
             if self.device_mesh["tp"].get_local_rank() == 0:
                 await self.inference_engine.update_weights_from_tensor(
                     named_tensors=[
                         (
                             name,
-                            tensor.detach(),
+                            LocalSerializedTensor(values=gathered_serialized_tensors),
                         )
                     ],
                     load_format=load_format,
                     flush_cache=False,
                 )
-
         if self.device_mesh["tp"].get_local_rank() == 0:
             await self.inference_engine.flush_cache()
 
```
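One subtlety in the gather call: `dist.gather_object` addresses its `dst` by global rank, while `get_local_rank()` is local to the TP process group, which is why the patch derives the destination from `mesh.tolist()[0]` rather than hard-coding `0`. A small sketch with a hypothetical sub-mesh (the rank numbers here are illustrative, not taken from the PR):

```python
# Hypothetical TP sub-mesh: if a job ran two TP groups across nodes,
# the second group's device mesh could hold global ranks 16..31.
tp_mesh = list(range(16, 32))

# dist.gather_object expects a global rank, so the destination is the
# first entry of the sub-mesh, not the group-local rank 0.
dst_global_rank = tp_mesh[0]
local_rank_of_dst = tp_mesh.index(dst_global_rank)
print(dst_global_rank, local_rank_of_dst)  # 16 0
```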
