28 changes: 28 additions & 0 deletions verl/utils/checkpoint/checkpoint_manager.py
@@ -63,6 +63,34 @@ def __init__(
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()

@property
def save_model(self) -> bool:
"""
Returns True if 'model' is in checkpoint_contents, indicating the model state should be loaded and saved.
"""
return "model" in self.checkpoint_contents

@property
def save_optimizer(self) -> bool:
"""
Returns True if 'optimizer' is in checkpoint_contents, indicating the optimizer state should be loaded and saved.
"""
return "optimizer" in self.checkpoint_contents

@property
def save_extra(self) -> bool:
"""
Returns True if 'extra' is in checkpoint_contents, indicating the extra state should be loaded and saved.
"""
return "extra" in self.checkpoint_contents

@property
def save_hf_model(self) -> bool:
"""
Returns True if 'hf_model' is in checkpoint_contents, indicating the model should be converted to hf model and saved.
"""
return "hf_model" in self.checkpoint_contents

def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):
raise NotImplementedError

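For reference, here is a minimal, verl-independent sketch of the flag pattern added above and of what different checkpoint_contents values enable. The class name ContentsFlags and the example lists are illustrative only, not part of this PR.

from typing import List


class ContentsFlags:
    """Mirrors the save_* properties added to BaseCheckpointManager above."""

    def __init__(self, checkpoint_contents: List[str]):
        # Same convention as BaseCheckpointManager: a list of strings naming
        # which pieces of state this manager handles.
        self.checkpoint_contents = checkpoint_contents

    @property
    def save_model(self) -> bool:
        return "model" in self.checkpoint_contents

    @property
    def save_optimizer(self) -> bool:
        return "optimizer" in self.checkpoint_contents

    @property
    def save_extra(self) -> bool:
        return "extra" in self.checkpoint_contents

    @property
    def save_hf_model(self) -> bool:
        return "hf_model" in self.checkpoint_contents


# A rollout-only manager keeps just the model shards.
rollout_flags = ContentsFlags(["model"])
assert rollout_flags.save_model and not rollout_flags.save_optimizer

# A full training manager keeps model, optimizer, and extra (rng / lr_scheduler) state.
train_flags = ContentsFlags(["model", "optimizer", "extra"])
assert train_flags.save_optimizer and train_flags.save_extra and not train_flags.save_hf_model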
116 changes: 68 additions & 48 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -51,8 +51,8 @@ class FSDPCheckpointManager(BaseCheckpointManager):
def __init__(
self,
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
optimizer: Optional[torch.optim.Optimizer] = None,
lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
checkpoint_contents: Optional[list] = None,
**kwargs,
@@ -63,7 +63,6 @@ def __init__(
assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
warnings.warn("`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2)
processing_class = kwargs.pop("tokenizer")
assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}"

super().__init__(
model,
@@ -73,6 +72,16 @@ def __init__(
checkpoint_contents=checkpoint_contents,
)

assert self.save_model, f"FSDPCheckpointManager must include ['model'], got {self.checkpoint_contents}"
if self.save_optimizer:
assert optimizer is not None, "optimizer must be provided when checkpoint_contents includes ['optimizer']"

if self.optimizer is not None and not self.save_optimizer:
print("Warning: optimizer is managed by FSDPCheckpointManager, but 'optimizer' not in checkpoint_contents. optimizer state will not be saved or loaded.")

if self.lr_scheduler is not None and not self.save_extra:
print("Warning: lr_scheduler is managed by FSDPCheckpointManager, but 'extra' not in checkpoint_contents. lr_scheduler state will not be saved or loaded.")

def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
"""
Load an FSDP checkpoint for this rank.
@@ -90,41 +99,48 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
return

# every rank download its own checkpoint
remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}")
local_model_path = copy_to_local(remote_model_path)
local_optim_path = copy_to_local(remote_optim_path)
local_extra_state_path = copy_to_local(remote_extra_state_path)

model_state_dict = torch.load(local_model_path, weights_only=False)
optimizer_state_dict = torch.load(local_optim_path, weights_only=False)
extra_state_dict = torch.load(local_extra_state_path, weights_only=False)

if del_local_after_load:
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) if self.save_model else None
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) if self.save_optimizer else None
with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
if self.save_model:
remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
local_model_path = copy_to_local(remote_model_path)
model_state_dict = torch.load(local_model_path, weights_only=False)
self.model.load_state_dict(model_state_dict)
print(f"[rank-{self.rank}]: Loading model from {remote_model_path}")

if self.save_optimizer:
remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
local_optim_path = copy_to_local(remote_optim_path)
optimizer_state_dict = torch.load(local_optim_path, weights_only=False)
self.optimizer.load_state_dict(optimizer_state_dict)
print(f"[rank-{self.rank}]: Loading optimizer from {remote_optim_path}")

if self.save_extra:
remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
local_extra_state_path = copy_to_local(remote_extra_state_path)
extra_state_dict = torch.load(local_extra_state_path, weights_only=False)
# recover random state
if "rng" in extra_state_dict:
# 'rng' may not exist for backward compatibility
self.load_rng_state(extra_state_dict["rng"])
print(f"[rank-{self.rank}]: Loading rng from {remote_extra_state_path}")

lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
if lr_scheduler_state_dict is not None and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
print(f"[rank-{self.rank}]: Loading lr_scheduler from {remote_extra_state_path}")

if self.rank == 0 and del_local_after_load:
try:
os.remove(local_model_path) if is_non_local(local_model_path) else None
os.remove(local_optim_path) if is_non_local(local_optim_path) else None
os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None
except Exception as e:
print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored")

lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
# recover random state
if "rng" in extra_state_dict:
# 'rng' may not exist for backward compatibility
self.load_rng_state(extra_state_dict["rng"])

if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
# wait for everyone to load checkpoints
torch.distributed.barrier()

def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
"""
@@ -150,8 +166,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
# record the previous global step
self.previous_global_step = global_step

# remove previous local_path
if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep:
# remove previous local_path, only rank 0 should do this
if self.rank == 0 and max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep:
keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
self.previous_saved_paths = self.previous_saved_paths[keep_start:]
@@ -165,24 +181,28 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
model_state_dict = self.model.state_dict()
optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None
lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None

extra_state_dict = {
"lr_scheduler": lr_scheduler_state_dict,
"rng": self.get_rng_state(),
}
model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")
torch.save(model_state_dict, model_path)
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
torch.save(extra_state_dict, extra_path)
if self.save_model:
model_state_dict = self.model.state_dict()
torch.save(model_state_dict, model_path)
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")

if self.save_optimizer:
optimizer_state_dict = self.optimizer.state_dict()
torch.save(optimizer_state_dict, optim_path)
print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}")

if self.save_extra:
lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
extra_state_dict = {
"lr_scheduler": lr_scheduler_state_dict,
"rng": self.get_rng_state(),
}
torch.save(extra_state_dict, extra_path)
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")

if self.rank == 0:
if fsdp_version(self.model) == 1:
@@ -205,7 +225,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
# wait for everyone to dump to local
torch.distributed.barrier()

if "hf_model" in self.checkpoint_contents:
if self.save_hf_model:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)

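To make the gating above concrete, here is a self-contained sketch of the same per-rank file naming and conditional save/load, using plain tensors instead of FSDP shards. world_size and rank are hard-coded for a single-process illustration and the directory is temporary; this is not the manager's actual code path.

import os
import tempfile

import torch

checkpoint_contents = ["model", "extra"]  # note: no "optimizer"
world_size, rank = 1, 0
local_path = tempfile.mkdtemp()

model_state = {"weight": torch.randn(4, 4)}
extra_state = {"lr_scheduler": None, "rng": torch.get_rng_state()}

model_path = os.path.join(local_path, f"model_world_size_{world_size}_rank_{rank}.pt")
optim_path = os.path.join(local_path, f"optim_world_size_{world_size}_rank_{rank}.pt")
extra_path = os.path.join(local_path, f"extra_state_world_size_{world_size}_rank_{rank}.pt")

# Save only what checkpoint_contents asks for.
if "model" in checkpoint_contents:
    torch.save(model_state, model_path)
if "optimizer" in checkpoint_contents:
    torch.save({}, optim_path)  # skipped in this configuration
if "extra" in checkpoint_contents:
    torch.save(extra_state, extra_path)

# Loading mirrors saving: files that were never requested are never opened,
# so a model-only checkpoint directory does not trigger a missing-optimizer error.
if "model" in checkpoint_contents:
    restored = torch.load(model_path, weights_only=False)
    assert torch.equal(restored["weight"], model_state["weight"])
if "optimizer" in checkpoint_contents:
    torch.load(optim_path, weights_only=False)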
14 changes: 7 additions & 7 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -199,7 +199,7 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
if local_path is None:
return

if "model" in self.checkpoint_contents:
if self.save_model:
Review comments on this line:

ETOgaosion (Collaborator) commented on May 27, 2025:
I find it a little strange to use save_xxx to control the behavior of loading checkpoints. Megatron has use_checkpoint_opt_param_scheduler or override_opt_param_scheduler to control the optimizer scheduler loading process; can you design a new mechanism? What about dividing self.checkpoint_contents into load/save?

Author (Collaborator) replied:
Hi @ETOgaosion, sorry for the late update, I just went through a busy week. Did you mean using two lists to control which contents to load/save, i.e. removing the current checkpoint_contents and introducing something like checkpoint_load_contents and checkpoint_save_contents?

Collaborator replied:
Yes, maybe; it would be better for achieving finer-grained and more flexible checkpoint choices.

ETOgaosion (Collaborator) commented on Jun 5, 2025:
Maybe our default API could look like this:

checkpoint_contents:
    save: [...]
    load: ${(...).checkpoint_contents.save}

Author (Collaborator) replied:
Okay, I understand.

(A sketch of this proposed save/load split follows this file's diff.)
model_path = get_model_checkpoint_path(local_path)
ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False)
@@ -208,10 +208,10 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte
model.load_state_dict(state_dict)
print(f"Loaded sharded model checkpoint from {model_path}")

if "optimizer" in self.checkpoint_contents:
if self.save_optimizer:
self.load_optimizer(local_path)

if "extra" in self.checkpoint_contents:
if self.save_extra:
self.load_rng_states(local_path)

if del_local_after_load:
@@ -233,7 +233,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
local_path = self.local_mkdir(local_path)

# Save Model
if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0:
if self.save_model and mpu.get_data_parallel_rank() == 0:
state_dicts = []

for vpp_rank, model in enumerate(self.model):
@@ -265,7 +265,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
hdfs_io.copy(src=hf_config_and_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)

if "hf_model" in self.checkpoint_contents:
if self.save_hf_model:
# wait for everyone to dump to local
state_dict = self.weight_saver(
self.model,
@@ -307,7 +307,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)

# Save Optimizer
if "optimizer" in self.checkpoint_contents:
if self.save_optimizer:
torch.distributed.barrier()

optimizer_path = get_optimizer_checkpoint_path(local_path)
@@ -316,7 +316,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
print(f"saving optimizer state to {optimizer_path}")

# Save RNG States
if "extra" in self.checkpoint_contents:
if self.save_extra:
torch.distributed.barrier()

rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False)
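The review thread above proposes splitting checkpoint_contents into separate save and load lists. The following is a hypothetical sketch of that direction using OmegaConf (which verl already uses for its configs); the config keys and the SplitContentsFlags helper are illustrative, not something this PR implements, and in a real verl config the interpolation path would carry the full prefix that the reviewer elided.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
checkpoint_contents:
  save: [model, optimizer, extra]
  # Default: load exactly what is saved; users may override this list.
  load: ${checkpoint_contents.save}
"""
)


class SplitContentsFlags:
    def __init__(self, contents_cfg):
        self.save_contents = list(contents_cfg.save)
        self.load_contents = list(contents_cfg.load)

    def should_save(self, item: str) -> bool:
        return item in self.save_contents

    def should_load(self, item: str) -> bool:
        return item in self.load_contents


flags = SplitContentsFlags(cfg.checkpoint_contents)
assert flags.should_save("optimizer") and flags.should_load("optimizer")

# Overriding only the load list would let a rollout job restore just the
# model weights from a full training checkpoint.
cfg.checkpoint_contents.load = ["model"]
flags = SplitContentsFlags(cfg.checkpoint_contents)
assert flags.should_load("model") and not flags.should_load("optimizer")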
15 changes: 14 additions & 1 deletion verl/workers/fsdp_workers.py
@@ -597,7 +597,7 @@ def init_model(self):
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
# load from checkpoint

if self._is_actor:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
@@ -637,6 +637,17 @@ def init_model(self):
checkpoint_contents=self.config.actor.checkpoint.contents,
)

if not self._is_actor and self._is_rollout:
# If ActorRolloutRefWorker is initialized as a standalone rollout,
# create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout.
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=None,
lr_scheduler=None,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_contents=["model"],
)

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
# Support all hardwares
@@ -835,6 +846,8 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
assert self._is_actor or (not self._is_actor and self._is_rollout)

if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
