Commit 433544f
[megatron] feat: use mbridge as megatron adaptor (#2064)
### What does this PR do?

MBridge provides a seamless bridge between Hugging Face models and Megatron-Core's optimized implementation for efficient distributed training and inference. It also offers the tools and processes needed to integrate Reinforcement Learning (RL) with Megatron. See https://github.com/ISEEKYAN/mbridge

mbridge is developed and maintained by NVIDIA and provides:

- modeling HF models with Megatron
- loading/saving HF-format weights with no memory overhead
- online export of parameters to the rollout engine via a per-tensor generator
- RL-specific optimizations and friendly APIs on the Megatron side, plus early access to some Megatron features

With mbridge, the direct improvements are:

- a clean interface for Megatron
- no offline dist_ckpt conversion needed
- no offline model merger needed

### Test

Tested with GSM8k on qwen2-7B-instruct.

<img width="486" alt="image" src="https://github.com/user-attachments/assets/dd271e8a-9167-470f-8b0c-dde2bcfe1800" />

### High-Level Design

Adds an option `actor_rollout_ref.actor.megatron.use_mbridge` (default `False`); set it to `True` to enable. When enabled, model instantiation, initial weight loading, checkpoint save/load, and the per-tensor generator are taken over by mbridge.

### Specific Changes

> List the specific changes.

### API

> Demonstrate how the API changes if any.

### Usage Example

Add this line to the script:

```
actor_rollout_ref.actor.megatron.use_mbridge=True \
```

### Checklist Before Submitting

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title `description` if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] New CI unit test(s) are added to cover the code path.
- [ ] Rely on existing unit tests on CI that covers the code path.
1 parent 0ea96a2 commit 433544f

File tree

11 files changed: +263 / -133 lines

.github/workflows/e2e_ppo_trainer_megatron.yml (1 addition, 1 deletion)

```diff
@@ -320,7 +320,7 @@ jobs:
           ADV_ESTIMATOR=grpo USE_DUMMY_MODEL=True DUMMY_MODEL_CONFIG_PATH=tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json \
           PPO_MAX_TOKEN_LEN=512 FWD_MAX_TOKEN_LEN=512 \
           MAX_PROMPT_LENGTH=256 MAX_RESPONSE_LENGTH=256 \
-          MODEL_ID=Qwen/Qwen1.5-MoE-A2.7B-Chat \
+          MODEL_ID=Qwen/Qwen1.5-MoE-A2.7B-Chat USE_MBRIDGE=True \
           COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=4 COMMON_ETP=1 INFER_TP=8 \
           USE_DIST_CKPT=True ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh
       - name: clean up
```

setup.py (2 additions, 0 deletions)

```diff
@@ -56,6 +56,7 @@
     "torch==2.6.0",
 ]
 TRL_REQUIRES = ["trl<=0.9.6"]
+MCORE_REQUIRES = ["mbridge"]
 
 extras_require = {
     "test": TEST_REQUIRES,
@@ -66,6 +67,7 @@
     "vllm": VLLM_REQUIRES,
     "sglang": SGLANG_REQUIRES,
     "trl": TRL_REQUIRES,
+    "mcore": MCORE_REQUIRES,
 }
 
```

tests/special_e2e/run_ppo_trainer_megatron.sh (2 additions, 0 deletions)

```diff
@@ -102,6 +102,7 @@ CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
 CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
 CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
 RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+USE_MBRIDGE=${USE_MBRIDGE:-False}
 
 LR_WARMUP_STEPS=${LR_WARMUP_STEPS:-null}
 
@@ -182,6 +183,7 @@ for ENGINE in "${ENGINES[@]}"; do
         actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
         actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
         actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
+        actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \
         critic.optim.lr=2e-5 \
         critic.optim.lr_warmup_steps=$LR_WARMUP_STEPS \
         critic.model.path="${MODEL_PATH}" \
```

verl/models/mcore/mbridge.py (new file, 23 additions)

```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    from mbridge import AutoBridge
    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model
except ImportError:
    print("mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`")
    raise

__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"]
```
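The guarded import in this new file follows a common optional-dependency pattern: attempt the import, print an actionable install hint, then re-raise so the failure is still loud. A generic, runnable sketch of the same pattern (the helper name `optional_import` is mine, not a verl or mbridge API):

```python
import importlib


def optional_import(module_name: str, install_hint: str):
    """Import a module, printing an install hint before re-raising on failure.

    Mirrors the try/except ImportError guard used in verl/models/mcore/mbridge.py,
    generalized to any optional dependency.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # Surface the remedy before propagating the original error.
        print(f"{module_name} package not found. {install_hint}")
        raise
```

Re-raising (rather than swallowing the error) keeps the traceback intact while still telling the user which extra to install.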

verl/single_controller/base/megatron/worker.py (10 additions, 0 deletions)

```diff
@@ -47,6 +47,7 @@ def _init_hf_config_and_tf_config(
     override_model_config,
     override_transformer_config,
     trust_remote_code=False,
+    use_mbridge=False,
 ):
     from transformers import AutoConfig
 
@@ -105,6 +106,15 @@ def add_optimization_config_to_tf_config(tf_config):
             setattr(tf_config, k, v)
 
     add_optimization_config_to_tf_config(tf_config)
+    if use_mbridge:
+        from verl.models.mcore.mbridge import AutoBridge
+
+        bridge = AutoBridge.from_config(hf_config)
+        bridge.set_extra_args(**override_transformer_config)
+        tf_config = bridge.config
+        self.bridge = bridge
+    else:
+        self.bridge = None
 
     print(f"TF config: {tf_config}")
     self.hf_config = hf_config
```
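The control flow added to `_init_hf_config_and_tf_config` can be illustrated with a runnable stub: when the flag is on, the bridge derives the transformer config from the HF config (with extra args applied); otherwise the existing config passes through untouched. `_StubBridge` and `init_tf_config` below are hypothetical stand-ins for `mbridge.AutoBridge` and the worker method, not real verl/mbridge APIs:

```python
class _StubBridge:
    """Hypothetical stand-in for mbridge.AutoBridge, just enough to show the flow."""

    def __init__(self, hf_config):
        # The bridge derives a Megatron transformer config from the HF config.
        self.config = dict(hf_config)

    @classmethod
    def from_config(cls, hf_config):
        return cls(hf_config)

    def set_extra_args(self, **kwargs):
        # Extra args play the role of override_transformer_config.
        self.config.update(kwargs)


def init_tf_config(hf_config, tf_config, use_mbridge, override_transformer_config):
    """Sketch of the branch this hunk adds: bridge takes over tf_config when enabled."""
    if use_mbridge:
        bridge = _StubBridge.from_config(hf_config)
        bridge.set_extra_args(**override_transformer_config)
        return bridge.config, bridge
    return tf_config, None


tf_config, bridge = init_tf_config(
    hf_config={"num_layers": 2},
    tf_config={"num_layers": 2},
    use_mbridge=True,
    override_transformer_config={"pipeline_dtype": "bf16"},
)
```

Keeping `self.bridge = None` in the disabled branch means downstream code can use the bridge's presence as the feature flag, which is exactly what the checkpoint manager changes below rely on.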

verl/trainer/config/ppo_megatron_trainer.yaml (5 additions, 1 deletion)

```diff
@@ -102,7 +102,8 @@ actor_rollout_ref:
       dist_checkpointing_path: null
       seed: 42
       override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
-      profile: # profile the actor model in `update_policy`
+      use_mbridge: False
+      profile: # profile the actor model in `update_policy`
         use_profile: False # open it when you want to profile the actor model
         profile_ranks: null # list, you can specify the ranks to profile
         step_start: -1 # start step in update_policy
@@ -138,6 +139,7 @@ actor_rollout_ref:
       dist_checkpointing_path: null
       seed: ${actor_rollout_ref.actor.megatron.seed}
       override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
+      use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
       profile:
         use_profile: False
         profile_ranks: null
@@ -311,6 +313,7 @@ critic:
     dist_checkpointing_path: null
     seed: ${actor_rollout_ref.actor.megatron.seed}
     override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
+    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
   load_weight: True
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
@@ -355,6 +358,7 @@ reward_model:
     dist_checkpointing_path: null
    seed: ${actor_rollout_ref.actor.megatron.seed}
     override_transformer_config: {}
+    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
   model:
     input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
     path: ~/models/FsfairX-LLaMA3-RM-v0.1
```
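Because the ref, critic, and reward_model entries all interpolate the actor's value, a user only needs a single override to flip the feature everywhere. A minimal override fragment (assuming OmegaConf-style `${...}` interpolation, as the existing entries in this YAML suggest; nesting depth is illustrative):

```yaml
# Enabling mbridge in one place; the other sections pick it up via
# ${actor_rollout_ref.actor.megatron.use_mbridge} interpolation.
actor_rollout_ref:
  actor:
    megatron:
      use_mbridge: True
```

This matches the single-line `actor_rollout_ref.actor.megatron.use_mbridge=True \` override shown in the PR's usage example.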

verl/utils/checkpoint/megatron_checkpoint_manager.py (38 additions, 24 deletions)

```diff
@@ -116,6 +116,8 @@ def __init__(
         optimizer_scheduler,
         use_distributed_optimizer: bool,
         use_checkpoint_opt_param_scheduler: bool = False,
+        use_dist_checkpointing: bool = True,
+        bridge=None,
         **kwargs,
     ):
         super().__init__(
@@ -139,8 +141,10 @@ def __init__(
         self.model_path = self.config.model.path
         self.use_distributed_optimizer = use_distributed_optimizer
         self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
-
+        self.bridge = bridge
         self.rank = torch.distributed.get_rank()
+        self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model
+        self.use_hf_checkpoint = not self.use_dist_checkpointing
 
         self.weight_saver = get_weight_saver(self.arch)
 
@@ -303,7 +307,7 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
             ckpt_dir=dist_checkpoint_path,
         )
 
-        if self.should_load_model:
+        if self.should_load_model and self.use_dist_checkpointing:
             assert "model" in state_dict or any(
                 f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model))
             ), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
@@ -316,6 +320,10 @@
                 mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
                 self.model[vpp_rank].load_state_dict(model_state_dict)
             log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger)
+        elif self.should_load_model and self.use_hf_checkpoint:
+            hf_model_path = get_hf_model_checkpoint_path(local_path)
+            self.bridge.load_weights(self.model, hf_model_path)
+            log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger)
 
         if self.should_load_optimizer:
             assert "optimizer" in state_dict, (
@@ -370,29 +378,35 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
         local_path = local_mkdir_safe(local_path)
         dist_checkpoint_path = get_dist_checkpoint_path(local_path)
 
-        # Generate state dict for saving
-        state_dict = self.generate_state_dict()
-        log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger)
-        for vpp_rank, model in enumerate(self.model):
-            if len(self.model) > 1:
-                model_i_keys = state_dict[f"model{vpp_rank}"].keys()
-                log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger)
-            else:
-                log_with_rank(
-                    f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger
-                )
-
-        # Start Async save if enabled
-        async_save_request = save_dist_checkpointing(
-            sharded_state_dict=state_dict,
-            ckpt_path=dist_checkpoint_path,
-            async_save=self.checkpoint_config.async_save,
-        )
+        if self.use_dist_checkpointing:
+            # Generate state dict for saving
+            state_dict = self.generate_state_dict()
+            log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger)
+            for vpp_rank, model in enumerate(self.model):
+                if len(self.model) > 1:
+                    model_i_keys = state_dict[f"model{vpp_rank}"].keys()
+                    log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger)
+                else:
+                    log_with_rank(
+                        f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger
+                    )
+            # Start Async save if enabled
+            async_save_request = save_dist_checkpointing(
+                sharded_state_dict=state_dict,
+                ckpt_path=dist_checkpoint_path,
+                async_save=self.checkpoint_config.async_save,
+            )
 
-        # Synchronize all async save requests
-        if not self.checkpoint_config.async_save:
-            assert async_save_request is None, "Async save request should be None when not using async save."
-            torch.distributed.barrier()
+            # Synchronize all async save requests
+            if not self.checkpoint_config.async_save:
+                assert async_save_request is None, "Async save request should be None when not using async save."
+                torch.distributed.barrier()
+        else:
+            assert self.use_hf_checkpoint, "use_hf_checkpoint should be True when not using dist checkpointing"
+            log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
+            hf_ckpt_path = get_hf_model_checkpoint_path(local_path)
+            self.bridge.save_weights(self.model, hf_ckpt_path)
+            log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger)
 
         if self.should_save_model:
             # Only rank 0 saves the hf config and tokenizer to huggingface path
```
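The checkpoint-format decision added in `__init__` reduces to a small predicate: HF-format save/load via the bridge is used only when a bridge exists, the model is not a value model, and dist checkpointing was not explicitly requested. A runnable sketch (the function name is mine; the original tests truthiness of `self.bridge`, mirrored here as a `None` check):

```python
def resolve_checkpoint_format(use_dist_checkpointing: bool, bridge, is_value_model: bool) -> str:
    """Mirror the fallback logic in MegatronCheckpointManager.__init__:
    use_dist_checkpointing = use_dist_checkpointing or not bridge or is_value_model,
    use_hf_checkpoint = not use_dist_checkpointing."""
    use_dist = use_dist_checkpointing or bridge is None or is_value_model
    return "dist_ckpt" if use_dist else "hf"
```

Note that dist checkpointing wins whenever any of the three conditions holds, so the bridge-backed HF path is strictly opt-in and value models always fall back to dist checkpoints.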

verl/utils/model.py (13 additions, 0 deletions)

```diff
@@ -443,6 +443,19 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
     return architectures, model, state_dict, is_value_model
 
 
+def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"):
+    local_cache_path = os.path.expanduser(local_cache_path)
+    if config.model.path.startswith("hdfs:"):
+        from verl.utils.fs import copy_to_local
+
+        local_model_path = copy_to_local(
+            src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False)
+        )
+    else:
+        local_model_path = config.model.path
+    return local_model_path
+
+
 def load_megatron_model_weights(
     config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf"
 ):
```
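The new `get_hf_model_path` helper resolves HDFS paths to a local cache directory and passes local paths through unchanged. A simplified, self-contained sketch of that resolution (the cache-path join below merely stands in for verl's `copy_to_local`, which also performs the actual download):

```python
import os


def resolve_model_path(model_path: str, cache_dir: str = "~/.cache/verl/rlhf") -> str:
    """Simplified sketch of get_hf_model_path's path resolution.

    HDFS-backed models are materialized under a local cache directory
    (copy_to_local is stubbed out as a path join here); anything else is
    assumed to already be a local path and is returned as-is.
    """
    if model_path.startswith("hdfs:"):
        cache_dir = os.path.expanduser(cache_dir)
        # Stand-in for copy_to_local(src=model_path, cache_dir=cache_dir, ...)
        return os.path.join(cache_dir, os.path.basename(model_path))
    return model_path
```

Deferring the `copy_to_local` import to the HDFS branch, as the real helper does, keeps the common local-path case free of the filesystem dependency.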
