Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ def get_checkpoint_name(
return os.path.join(common_path, basename)

def generate_state_dict(
self, generate_model: bool = True, generate_optimizer: bool = True, generate_extra: bool = True
self,
generate_model: bool = True,
generate_optimizer: bool = True,
generate_extra: bool = True,
is_loading: bool = False,
):
# Build the state dict for dist checkpointing (used when saving and, with is_loading=True, when loading)
state_dict = {}
Expand All @@ -252,7 +256,7 @@ def generate_state_dict(
# Optimizer State Dict
if generate_optimizer:
torch.distributed.barrier()
optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)
optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict, is_loading=is_loading)
state_dict["optimizer"] = optimizer_sharded_states

if self.lr_scheduler is not None:
Expand Down Expand Up @@ -292,11 +296,20 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
if local_path is not None:
assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist."

# Needed to load the optimizer dist_ckpt: allowlist the optimizer classes as safe globals for torch.serialization
import transformer_engine

torch.serialization.add_safe_globals([torch.optim.AdamW])
torch.serialization.add_safe_globals([transformer_engine.pytorch.optimizers.fused_adam.FusedAdam])

dist_checkpoint_path = get_dist_checkpoint_path(local_path)

# Get State Dict for loading
sharded_state_dict = self.generate_state_dict(
self.should_load_model and self.use_dist_checkpointing, self.should_load_optimizer, self.should_load_extra
self.should_load_model and self.use_dist_checkpointing,
self.should_load_optimizer,
self.should_load_extra,
is_loading=True,
)
log_with_rank(f"Generated state dict for loading: {sharded_state_dict.keys()}", rank=self.rank, logger=logger)

Expand Down
Loading