use mbridge
ISEEKYAN committed Jun 17, 2025
commit 99b41fedd9fe8a8f2c05b463e32be636ff30f74c
21 changes: 21 additions & 0 deletions verl/models/mcore/mbridge.py
@@ -0,0 +1,21 @@
try:
    from mbridge import AutoBridge
    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model
except ImportError:
    import subprocess
    import sys

    print("mbridge package not found. This package is required for model bridging functionality.")
    print("Install mbridge with `pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps`")

    def install_mbridge():
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/ISEEKYAN/mbridge.git", "--no-deps"])
        except subprocess.CalledProcessError:
            print("Failed to install mbridge")
            raise

    install_mbridge()
    # Re-import explicitly after installation so every name in __all__ is guaranteed to exist.
    from mbridge import AutoBridge
    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model

__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"]
10 changes: 10 additions & 0 deletions verl/single_controller/base/megatron/worker.py
@@ -47,6 +47,7 @@ def _init_hf_config_and_tf_config(
        override_model_config,
        override_transformer_config,
        trust_remote_code=False,
        use_mbridge=False,
    ):
        from transformers import AutoConfig

@@ -94,6 +95,15 @@ def add_optimization_config_to_tf_config(tf_config):
                    setattr(tf_config, k, v)

        add_optimization_config_to_tf_config(tf_config)
        if use_mbridge:
            from verl.models.mcore.mbridge import AutoBridge

            # Build the bridge from the HF config and let it derive the Megatron TransformerConfig.
            bridge = AutoBridge.from_config(hf_config)
            bridge.set_extra_args(**override_transformer_config)
            tf_config = bridge.config
            self.bridge = bridge
        else:
            self.bridge = None

        print(f"TF config: {tf_config}")
        self.hf_config = hf_config
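The worker presumably passes self.bridge on to the checkpoint manager, whose __init__ gains a bridge=None parameter later in this diff; the actual call site is not in the excerpts shown here, so the following is only an assumed sketch of that handoff:

# Assumed wiring (call site not shown in this excerpt); MegatronCheckpointManager accepts bridge=None below.
checkpoint_manager = MegatronCheckpointManager(
    ...,                 # existing arguments unchanged
    bridge=self.bridge,  # None unless use_mbridge is enabled
)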
4 changes: 4 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -94,6 +94,7 @@ actor_rollout_ref:
      dist_checkpointing_path: null
      seed: 42
      override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
      use_mbridge: False
Collaborator
Actually, shouldn't use_dist_checkpointing and mbridge be an either-or relation? Maybe we should use naming like io_methods.loading_backend/saving_backend to choose between huggingface/dist_checkpointing/mbridge?

Also, we need to consider how this combines with the checkpoint configuration. Maybe we should merge these directly into checkpoint?

Collaborator
@ccclyu @dataproblems, could you give some advice on the API design?

How should use_dist_checkpointing and use_mbridge work together so that they integrate better? My original thinking:

checkpoint:
    pre_load:    # first-time load
        format: [hf, dist_ckpt]    # hf is the default and uses mbridge
    load:
        format: [hf, dist_ckpt]
    save:
        format: [hf, dist_ckpt]

But maybe this will break some APIs.


I think the current way is OK in the config, since there can be a relationship between the load and save operations (the actor saves the model and the rollout loads it, in the case where the two are not colocated). However, we would need validation when the config is read to make sure the load and save options are compatible with each other.

Implementation-wise, I would add an abstraction that pulls the checkpoint-saving logic out of the checkpoint manager and the workers; that way the checkpoint manager and worker code relies on a stable interface, and you can offer more options while modifying less code. Is that something you were looking for, or am I missing the point here?
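A minimal sketch of what that read-time validation could look like, assuming the checkpoint.load/checkpoint.save layout proposed above; the function and field names here are hypothetical and not part of this PR:

# Hypothetical read-time validation for the proposed checkpoint config (not part of this PR).
SUPPORTED_FORMATS = {"hf", "dist_ckpt"}

def validate_checkpoint_formats(checkpoint_cfg: dict) -> None:
    load_fmt = checkpoint_cfg.get("load", {}).get("format", "hf")
    save_fmt = checkpoint_cfg.get("save", {}).get("format", "hf")
    for name, fmt in (("load", load_fmt), ("save", save_fmt)):
        if fmt not in SUPPORTED_FORMATS:
            raise ValueError(f"checkpoint.{name}.format must be one of {sorted(SUPPORTED_FORMATS)}, got {fmt!r}")
    # When the actor saves and the rollout loads (non-colocated), the two formats must agree.
    if load_fmt != save_fmt:
        raise ValueError(f"Incompatible checkpoint formats: load={load_fmt!r}, save={save_fmt!r}")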

Collaborator
Thanks. Your latter part makes sense to me; it's a refactoring point, but here I'd like to focus on the API design.

use_mbridge is a broader functional option that includes model initialization, so it should work as in @ISEEKYAN's implementation. The question, then, is whether use_dist_checkpointing should migrate into the checkpoint config as a first-time loading option. Since that API migration should not be part of this PR's changes, we will separate the feature development from the interface refactor. Is that OK?

cc @ISEEKYAN @dataproblems @ccclyu

Collaborator Author
It looks good to me.
More detail about mbridge: it will cover model initialization, parameter resharding, saving/loading in HF format, forward passes with sequence packing and fused kernels (to be added), and other potential improvements on the Megatron side, as NVIDIA's solution for using Megatron in RL frameworks.
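For readers skimming the diff, here is a rough end-to-end sketch of the flow this PR wires up. The call names (AutoBridge.from_config, set_extra_args, config, save_weights, load_weights) are taken from the diff itself, while local_model_path, override_transformer_config, megatron_model, and model_ckpt_path are placeholders; this is illustrative, not an authoritative mbridge API reference.

# Illustrative sketch only; mirrors the worker and checkpoint-manager changes in this diff.
from transformers import AutoConfig

from verl.models.mcore.mbridge import AutoBridge

local_model_path = "/path/to/hf_model"                  # placeholder
override_transformer_config = {}                        # e.g. {"num_layers_in_first_pipeline_stage": 1}

hf_config = AutoConfig.from_pretrained(local_model_path)
bridge = AutoBridge.from_config(hf_config)              # model initialization on the Megatron side
bridge.set_extra_args(**override_transformer_config)    # forward Megatron overrides
tf_config = bridge.config                               # Megatron TransformerConfig used by the worker

# HF-format save/load of the sharded Megatron model, as in MegatronCheckpointManager
# (megatron_model and model_ckpt_path stand in for values the manager already holds):
# bridge.save_weights(megatron_model, model_ckpt_path)
# bridge.load_weights(megatron_model, model_ckpt_path)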

Collaborator
The current config LGTM. Long term, if we migrate to mbridge, will use_dist_checkpointing be deprecated so that only the HF format is loaded?

Collaborator Author
Personally, I prefer to use the HF format for the entire lifetime of training. But supporting dist_checkpointing or other formats like bytecheckpoint would make things more flexible when a user has a private pre-trained model. So the config might look like:

checkpoint:
    pre_load:    # first-time load
        format: [hf, dist_ckpt, bytecheckpoint]   # hf is the default and uses mbridge
    load_save:
        format: [hf, dist_ckpt, bytecheckpoint]

We would deprecate use_dist_checkpointing but keep it for a while and remind users to switch to the new way, and we would update the example scripts accordingly.
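A minimal sketch of that deprecation path, assuming the pre_load/format layout proposed above; resolve_first_load_format and the config fields are hypothetical, not part of this PR:

# Hypothetical deprecation shim (not in this PR): honor the legacy flag while warning about it.
import warnings

def resolve_first_load_format(megatron_cfg: dict, checkpoint_cfg: dict) -> str:
    fmt = checkpoint_cfg.get("pre_load", {}).get("format")
    if megatron_cfg.get("use_dist_checkpointing", False):
        warnings.warn(
            "use_dist_checkpointing is deprecated; set checkpoint.pre_load.format=dist_ckpt instead.",
            DeprecationWarning,
        )
        fmt = fmt or "dist_ckpt"
    return fmt or "hf"  # hf is the default and is handled by mbridge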

    profile: # profile the actor model in `update_policy`
      use_profile: False # set to True when you want to profile the actor model
      profile_ranks: null # list; you can specify the ranks to profile
@@ -124,6 +125,7 @@ actor_rollout_ref:
      dist_checkpointing_path: null
      seed: ${actor_rollout_ref.actor.megatron.seed}
      override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
      use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
    profile:
      use_profile: False
      profile_ranks: null
@@ -245,6 +247,7 @@ critic:
    dist_checkpointing_path: null
    seed: ${actor_rollout_ref.actor.megatron.seed}
    override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
  load_weight: True
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
@@ -284,6 +287,7 @@ reward_model:
    dist_checkpointing_path: null
    seed: ${actor_rollout_ref.actor.megatron.seed}
    override_transformer_config: {}
    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
53 changes: 33 additions & 20 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -75,6 +75,7 @@ def __init__(
        use_distributed_optimizer: bool,
        use_checkpoint_opt_param_scheduler: bool = False,
        checkpoint_contents: DictConfig = None,
        bridge=None,
        **kwargs,
    ):
        super().__init__(
@@ -97,7 +98,7 @@ def __init__(
        self.model_path = self.config.model.path
        self.use_distributed_optimizer = use_distributed_optimizer
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler

        self.bridge = bridge
        self.rank = torch.distributed.get_rank()

        self.weight_saver = get_weight_saver(self.arch)
@@ -217,7 +218,11 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
        if local_path is None:
            return

        if self.should_load_model:
        if self.should_load_model and self.bridge is not None and not self.is_value_model:
            model_path = get_model_checkpoint_path(local_path)
            self.bridge.load_weights(self.model, model_path)
            log_with_rank(f"Loaded HF model checkpoint from {model_path} with bridge", rank=self.rank, logger=logger)
        elif self.should_load_model:
            model_path = get_model_checkpoint_path(local_path)
            ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
            state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False)
@@ -260,7 +265,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
        local_path = self.local_mkdir(local_path)

        # Save Model
        if self.should_save_model and mpu.get_data_parallel_rank() == 0:
        saved = False
        if self.should_save_model and self.bridge is not None and not self.is_value_model:
            log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
            model_ckpt_path = get_model_checkpoint_path(local_path)
            self.bridge.save_weights(self.model, model_ckpt_path)
            log_with_rank(f"Saved bridge checkpoint to {model_ckpt_path}", rank=self.rank, logger=logger)
            saved = True
        elif self.should_save_model and mpu.get_data_parallel_rank() == 0:
            state_dicts = []

            for vpp_rank, model in enumerate(self.model):
@@ -274,23 +286,24 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
            torch.save(state_dicts, os.path.join(ckpt_name))

            log_with_rank(f"Saved checkpoint to {model_ckpt_path}", rank=self.rank, logger=logger)
            if self.rank == 0:
                self.processing_class.save_pretrained(hf_config_and_tokenizer_path)
                self.hf_config.save_pretrained(hf_config_and_tokenizer_path)
                if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:
                    try:
                        generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)
                        generation_config.save_pretrained(hf_config_and_tokenizer_path)
                    except Exception:
                        # if the generation config isn't available, we don't save it
                        pass
                if hdfs_path is not None:
                    log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger)
                    from verl.utils import hdfs_io

                    hdfs_io.makedirs(hdfs_path, exist_ok=True)
                    hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
                    hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)
            saved = True
        if self.rank == 0 and saved:
            self.processing_class.save_pretrained(hf_config_and_tokenizer_path)
            self.hf_config.save_pretrained(hf_config_and_tokenizer_path)
            if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:
                try:
                    generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)
                    generation_config.save_pretrained(hf_config_and_tokenizer_path)
                except Exception:
                    # if the generation config isn't available, we don't save it
                    pass
            if hdfs_path is not None:
                log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger)
                from verl.utils import hdfs_io

                hdfs_io.makedirs(hdfs_path, exist_ok=True)
                hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
                hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)

        if self.should_save_hf_model:
            # wait for everyone to dump to local
11 changes: 11 additions & 0 deletions verl/utils/model.py
@@ -348,6 +348,17 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path):
    return architectures, model, state_dict, is_value_model


def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"):
    """Resolve config.model.path to a local path, downloading from HDFS into the cache if necessary."""
    local_cache_path = os.path.expanduser(local_cache_path)
    if config.model.path.startswith("hdfs:"):
        from verl.utils.fs import copy_to_local

        local_model_path = copy_to_local(src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False))
    else:
        local_model_path = config.model.path
    return local_model_path
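
A brief usage sketch for this helper (illustrative; config is assumed to be the usual model config carrying model.path and model.use_shm):

# Illustrative only: resolve the model path once, then reuse it for HF-format loading.
from transformers import AutoConfig

local_model_path = get_hf_model_path(config)               # local dir; HDFS paths are copied to ~/.cache/verl/rlhf
hf_config = AutoConfig.from_pretrained(local_model_path)   # e.g. to feed AutoBridge.from_config(hf_config)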


def load_megatron_model_weights(config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf"):
    """Load weights for verl customized model."""
    architectures, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path)