Merge branch 'main' into use_mbridge
ISEEKYAN committed Jun 25, 2025
commit 8527ee1247a4b5ed9ddc42710fd8786a8d548e81
134 changes: 82 additions & 52 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -116,7 +116,7 @@ def __init__(
optimizer_scheduler,
use_distributed_optimizer: bool,
use_checkpoint_opt_param_scheduler: bool = False,
checkpoint_contents: DictConfig = None,
use_dist_checkpointing: bool = True,
bridge=None,
**kwargs,
):
@@ -143,6 +143,7 @@ def __init__(
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
self.bridge = bridge
self.rank = torch.distributed.get_rank()
self.use_dist_checkpointing = use_dist_checkpointing

self.weight_saver = get_weight_saver(self.arch)
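This hunk swaps the removed checkpoint_contents argument for a use_dist_checkpointing flag stored on the instance next to bridge. Below is a minimal, self-contained sketch (an assumption, not verl code) of how the flag combines with the bridge in the branch conditions that appear in load_checkpoint and save_checkpoint further down this diff; CkptManagerFlags is invented for illustration.

# Stand-alone illustration; `CkptManagerFlags` is made up for this sketch.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CkptManagerFlags:
    bridge: Optional[Any] = None          # mbridge instance, or None
    is_value_model: bool = False
    use_dist_checkpointing: bool = True   # the new constructor flag

    def uses_megatron_dist_ckpt(self) -> bool:
        # HF-format weights via the bridge are only used when a bridge exists,
        # the model is not a value model, and dist checkpointing is not forced on.
        return (not self.bridge) or self.is_value_model or self.use_dist_checkpointing


print(CkptManagerFlags().uses_megatron_dist_ckpt())                                               # True
print(CkptManagerFlags(bridge=object(), use_dist_checkpointing=False).uses_megatron_dist_ckpt())  # False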

@@ -301,18 +302,21 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
ckpt_dir=dist_checkpoint_path,
)

if self.should_load_model and self.bridge is not None and not self.is_value_model:
model_path = get_model_checkpoint_path(local_path)
self.bridge.load_weights(self.model, model_path)
log_with_rank(f"Loaded HF model checkpoint from {model_path} with bridge", rank=self.rank, logger=logger)
elif self.should_load_model:
model_path = get_model_checkpoint_path(local_path)
ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False)
assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}"
for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)):
model.load_state_dict(state_dict)
log_with_rank(f"Loaded sharded model checkpoint from {model_path}", rank=self.rank, logger=logger)
if self.should_load_model and (not self.bridge or self.is_value_model or self.use_dist_checkpointing):
assert "model" in state_dict or any(f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model))), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
for vpp_rank, model in enumerate(self.model):
if len(self.model) == 1:
model_state_dict = state_dict["model"]
else:
assert f"model{vpp_rank}" in state_dict, f"model{vpp_rank} not found in state_dict"
model_state_dict = state_dict[f"model{vpp_rank}"]
mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)
self.model[vpp_rank].load_state_dict(model_state_dict)
log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger)
else:
hf_model_path = get_hf_model_checkpoint_path(local_path)
self.bridge.load_weights(self.model, hf_model_path)
log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger)

if self.should_load_optimizer:
assert "optimizer" in state_dict, f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}."
@@ -351,48 +355,74 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
local_path = local_mkdir_safe(local_path)
dist_checkpoint_path = get_dist_checkpoint_path(local_path)

# Save Model
saved = False
if self.should_save_model and self.bridge is not None and not self.is_value_model:

if not self.bridge or self.is_value_model or self.use_dist_checkpointing:
# Generate state dict for saving
state_dict = self.generate_state_dict()
log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger)
for vpp_rank, model in enumerate(self.model):
if len(self.model) > 1:
model_i_keys = state_dict[f"model{vpp_rank}"].keys()
log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger)
else:
log_with_rank(f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger)

# Start Async save if enabled
async_save_request = save_dist_checkpointing(
sharded_state_dict=state_dict,
ckpt_path=dist_checkpoint_path,
async_save=self.checkpoint_config.async_save,
)

# Synchronize all async save requests
if not self.checkpoint_config.async_save:
assert async_save_request is None, "Async save request should be None when not using async save."
torch.distributed.barrier()
else:
log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger)
model_ckpt_path = get_model_checkpoint_path(local_path)
self.bridge.save_weights(self.model, model_ckpt_path)
log_with_rank(f"Saved bridge checkpoint to {model_ckpt_path}", rank=self.rank, logger=logger)
saved = True
elif self.should_save_model and mpu.get_data_parallel_rank() == 0:
state_dicts = []

# Start Async save if enabled
async_save_request = save_dist_checkpointing(
sharded_state_dict=state_dict,
ckpt_path=dist_checkpoint_path,
async_save=self.checkpoint_config.async_save,
)
hf_ckpt_path = get_hf_model_checkpoint_path(local_path)
self.bridge.save_weights(self.model, hf_ckpt_path)
log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger)

# Synchronize all async save requests
if not self.checkpoint_config.async_save:
assert async_save_request is None, "Async save request should be None when not using async save."
torch.distributed.barrier()

log_with_rank(f"Saved checkpoint to {model_ckpt_path}", rank=self.rank, logger=logger)
saved = True
if self.rank == 0 and saved:
self.processing_class.save_pretrained(hf_config_and_tokenizer_path)
self.hf_config.save_pretrained(hf_config_and_tokenizer_path)
if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:
try:
generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)
generation_config.save_pretrained(hf_config_and_tokenizer_path)
except Exception:
# if the generation config isn't available, we don't save it
pass
if hdfs_path is not None:
log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger)
from verl.utils import hdfs_io

hdfs_io.makedirs(hdfs_path, exist_ok=True)
hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)
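Because old and new lines are interleaved in this condensed view, here is a self-contained restatement (an assumption, with the verl helpers replaced by stand-ins) of the new save-model flow: dist-checkpoint the generated state dict, optionally asynchronously, or hand the weights to the bridge in HF format.

# Stand-ins only; the real code uses generate_state_dict(), save_dist_checkpointing()
# and bridge.save_weights() as shown in the hunk above.
from typing import Optional


def fake_save_dist_checkpointing(sharded_state_dict, ckpt_path, async_save) -> Optional[object]:
    print(f"dist-ckpt save to {ckpt_path} (async={async_save})")
    return object() if async_save else None   # a synchronous save returns no pending request


def save_model(bridge, is_value_model, use_dist_checkpointing, async_save,
               dist_checkpoint_path="ckpt/dist", hf_ckpt_path="ckpt/huggingface"):
    if not bridge or is_value_model or use_dist_checkpointing:
        state_dict = {"model": {}}             # stands in for self.generate_state_dict()
        request = fake_save_dist_checkpointing(state_dict, dist_checkpoint_path, async_save)
        if not async_save:
            assert request is None             # then every rank hits torch.distributed.barrier()
    else:
        print(f"bridge.save_weights -> {hf_ckpt_path}")   # HF-format weights via mbridge


save_model(bridge=None, is_value_model=False, use_dist_checkpointing=True, async_save=False)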
if self.should_save_model:
# Only rank 0 saves the hf config and tokenizer to huggingface path
# No matter whether we save hf model or not
if self.rank == 0:
# Save tokenizer
hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)
self.processing_class.save_pretrained(hf_config_tokenizer_path)
# Save huggingface config
self.hf_config.save_pretrained(hf_config_tokenizer_path)
if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path:
try:
generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)
generation_config.save_pretrained(hf_config_tokenizer_path)
except Exception:
# if the generation config isn't available, we don't save it
pass
log_with_rank(f"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}", rank=self.rank, logger=logger, log_only_rank_0=True)

if self.should_save_extra:
if self.rank == 0:
# Save transformer config
print(self.transformer_config)
transformer_config_dict = asdict(self.transformer_config)
to_convert_types = {torch.dtype: str, AttnBackend: str}
ignore_types = [Callable]
pop_keys = []
for key, value in transformer_config_dict.items():
if type(value) in to_convert_types:
transformer_config_dict[key] = to_convert_types[type(value)](value)
if type(value) in ignore_types:
pop_keys.append(key)
if callable(value):
pop_keys.append(key)
for key in pop_keys:
transformer_config_dict.pop(key)
transformer_config_path = get_transformer_config_checkpoint_path(local_path)
with open(transformer_config_path, "w") as f:
json.dump(transformer_config_dict, f, indent=2)
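The transformer config is a dataclass whose fields are not all JSON-serializable, so torch.dtype (and, in the real code, AttnBackend) values are stringified and callable fields are dropped before dumping. A self-contained sketch of that conversion, with a made-up DummyTransformerConfig standing in for Megatron's TransformerConfig:

# torch.dtype values become strings; callables are removed before json.dump.
import json
from dataclasses import asdict, dataclass
from typing import Callable, Optional

import torch


@dataclass
class DummyTransformerConfig:
    hidden_size: int = 1024
    params_dtype: torch.dtype = torch.bfloat16
    init_method: Optional[Callable] = None


def to_json_safe(cfg) -> dict:
    d = asdict(cfg)
    convert = {torch.dtype: str}
    for key, value in list(d.items()):
        if type(value) in convert:
            d[key] = convert[type(value)](value)
        elif callable(value):
            d.pop(key)
    return d


print(json.dumps(to_json_safe(DummyTransformerConfig()), indent=2))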

if self.should_save_hf_model:
# wait for everyone to dump to local
4 changes: 2 additions & 2 deletions verl/workers/megatron_workers.py
@@ -453,8 +453,8 @@ def init_model(self):
optimizer_scheduler=self.actor_optimizer_scheduler,
use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
checkpoint_contents=self.config.actor.checkpoint,
bridge=self.bridge,
use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,
)
get_torch_device().empty_cache()
log_gpu_memory_usage("After init_model finish", logger=logger)
@@ -804,8 +804,8 @@ def init_model(self):
optimizer_scheduler=self.critic_optimizer_scheduler,
use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,
use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,
checkpoint_contents=self.config.checkpoint,
bridge=self.bridge,
use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,
)
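Both workers now forward use_dist_checkpointing from their Megatron config section instead of passing checkpoint_contents. A hypothetical sketch of the corresponding config knob (the surrounding layout is an assumption; only the megatron.use_dist_checkpointing path is taken from the diff):

# Minimal OmegaConf stand-in for the config tree the workers read.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "megatron": {
            "use_distributed_optimizer": True,
            "use_dist_checkpointing": True,  # False -> rely on the bridge for HF-format weights
        }
    }
)
print(cfg.megatron.use_dist_checkpointing)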

@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)