try to refactor APIs
ETOgaosion committed Jun 11, 2025
commit d36e018975099208ad5525de7c4dcda44015a896
256 changes: 111 additions & 145 deletions recipe/spin/fsdp_workers.py

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions recipe/sppo/sppo_worker.py
@@ -112,6 +112,5 @@ def init_model(self):
 optimizer=self.actor.actor_optimizer,
 lr_scheduler=self.actor_lr_scheduler,
 processing_class=self.processor if self.processor is not None else self.tokenizer,
-checkpoint_load_contents=self.config.actor.checkpoint.load_contents,
-checkpoint_save_contents=self.config.actor.checkpoint.save_contents,
+checkpoint_contents=self.config.actor.checkpoint,
 )
15 changes: 9 additions & 6 deletions verl/utils/checkpoint/checkpoint_manager.py
@@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import shutil
import tempfile
from typing import Optional, Union
from typing import Union

import numpy as np
import torch
import torch.distributed
from filelock import FileLock
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.device import is_cuda_available, is_npu_available
@@ -47,9 +49,10 @@ def __init__(
 optimizer: torch.optim.Optimizer,
 lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
 processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
-checkpoint_load_contents: Optional[list] = None,
-checkpoint_save_contents: Optional[list] = None,
+checkpoint_contents: DictConfig = None,
 ):
+checkpoint_load_contents = checkpoint_contents.get("load", None) if checkpoint_contents else None
+checkpoint_save_contents = checkpoint_contents.get("save", None) if checkpoint_contents else None
 if checkpoint_load_contents is None:
 checkpoint_load_contents = ["model", "optimizer", "extra"]
 if checkpoint_save_contents is None:
@@ -94,21 +97,21 @@ def should_save_hf_model(self) -> bool:
Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf model and saved.
"""
return "hf_model" in self.checkpoint_save_contents

@property
def should_load_model(self) -> bool:
"""
Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.
"""
return "model" in self.checkpoint_load_contents

@property
def should_load_optimizer(self) -> bool:
"""
Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.
"""
return "optimizer" in self.checkpoint_load_contents

@property
def should_load_extra(self) -> bool:
"""
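For context, a minimal sketch (not part of this diff) of the consolidated argument introduced above: the constructor now reads the 'load' and 'save' keys from a single DictConfig and falls back to ['model', 'optimizer', 'extra'] when a key is missing. The config values below are illustrative placeholders, not verl's actual defaults.

```python
from omegaconf import DictConfig, OmegaConf

# Illustrative stand-in for the old checkpoint_load_contents / checkpoint_save_contents pair.
checkpoint_contents: DictConfig = OmegaConf.create(
    {
        "load": ["model", "optimizer", "extra"],
        "save": ["model", "hf_model"],
    }
)

# Mirrors the fallback logic in BaseCheckpointManager.__init__ shown above.
load_contents = checkpoint_contents.get("load", None) if checkpoint_contents else None
save_contents = checkpoint_contents.get("save", None) if checkpoint_contents else None
if load_contents is None:
    load_contents = ["model", "optimizer", "extra"]
if save_contents is None:
    save_contents = ["model", "optimizer", "extra"]

print(list(load_contents))          # ['model', 'optimizer', 'extra']
print("hf_model" in save_contents)  # True -> should_save_hf_model would return True
```

The properties further down in the file (should_save_hf_model, should_load_model, should_load_optimizer, ...) are then simple membership checks against these two lists.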
14 changes: 6 additions & 8 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed
 from accelerate import init_empty_weights
+from omegaconf import DictConfig
 from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin
@@ -50,10 +51,9 @@ class FSDPCheckpointManager(BaseCheckpointManager):
 lr_scheduler (LRScheduler): Learning-rate scheduler.
 processing_class (PreTrainedTokenizer or ProcessorMixin, optional):
 Pre-/post-processing artifact handler.
-checkpoint_load_contents (list[str], optional):
-Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].
-checkpoint_save_contents (list[str], optional):
-Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].
+checkpoint_contents DictConfig: Configuration for checkpoint contents.
+- 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].
+- 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra'].
 """

 def __init__(
@@ -62,8 +62,7 @@ def __init__(
 optimizer: Optional[torch.optim.Optimizer] = None,
 lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
 processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
-checkpoint_load_contents: Optional[list] = None,
-checkpoint_save_contents: Optional[list] = None,
+checkpoint_contents: DictConfig = None,
 **kwargs,
 ):
 if processing_class is None:
@@ -76,8 +75,7 @@
 optimizer,
 lr_scheduler=lr_scheduler,
 processing_class=processing_class,
-checkpoint_load_contents=checkpoint_load_contents,
-checkpoint_save_contents=checkpoint_save_contents,
+checkpoint_contents=checkpoint_contents,
 )

 def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
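As a side note (not part of this diff): the updated docstring states that both 'load' and 'save' must contain 'model'. A small hypothetical helper illustrating that contract — validate_checkpoint_contents is not a verl function, just a sketch of the stated requirement:

```python
from omegaconf import DictConfig, OmegaConf


def validate_checkpoint_contents(checkpoint_contents: DictConfig) -> None:
    """Hypothetical check for the 'must contain model' rule stated in the docstring."""
    for key in ("load", "save"):
        contents = checkpoint_contents.get(key, None)
        # A missing key falls back to the documented default, which already contains "model".
        contents = ["model", "optimizer", "extra"] if contents is None else list(contents)
        if "model" not in contents:
            raise ValueError(f"checkpoint_contents['{key}'] must contain 'model', got {contents}")


validate_checkpoint_contents(OmegaConf.create({"load": ["model"], "save": ["model", "hf_model"]}))  # passes
```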
8 changes: 3 additions & 5 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -15,13 +15,13 @@
 import logging
 import os
 import random
-from typing import Optional

 import numpy as np
 import torch
 import torch.distributed
 from megatron.core import mpu, tensor_parallel
 from megatron.core.dist_checkpointing.mapping import ShardedObject
+from omegaconf import DictConfig
 from transformers import GenerationConfig

 from verl.models.weight_loader_registry import get_weight_saver
@@ -73,17 +73,15 @@ def __init__(
 optimizer_scheduler,
 use_distributed_optimizer: bool,
 use_checkpoint_opt_param_scheduler: bool = False,
-checkpoint_load_contents: Optional[list] = None,
-checkpoint_save_contents: Optional[list] = None,
+checkpoint_contents: DictConfig = None,
 **kwargs,
 ):
 super().__init__(
 model,
 optimizer=optimizer,
 lr_scheduler=optimizer_scheduler,
 processing_class=tokenizer,
-checkpoint_load_contents=checkpoint_load_contents,
-checkpoint_save_contents=checkpoint_save_contents,
+checkpoint_contents=checkpoint_contents,
 )
 self.arch = arch
 self.config = config
14 changes: 8 additions & 6 deletions verl/workers/fsdp_workers.py
@@ -563,20 +563,23 @@ def init_model(self):
 optimizer=self.actor.actor_optimizer,
 lr_scheduler=self.actor_lr_scheduler,
 processing_class=self.processor if self.processor is not None else self.tokenizer,
-checkpoint_load_contents=self.config.actor.checkpoint.load_contents,
-checkpoint_save_contents=self.config.actor.checkpoint.save_contents,
+checkpoint_contents=self.config.actor.checkpoint,
 )

 if not self._is_actor and self._is_rollout:
 # If ActorRolloutRefWorker is initialized as a standalone rollout,
 # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout.
+from copy import deepcopy
+
+checkpoint_contents = deepcopy(self.config.actor.checkpoint)
+checkpoint_contents.load_contents = ["model"]
+checkpoint_contents.save_contents = []
 self.checkpoint_manager = FSDPCheckpointManager(
 model=self.actor_module_fsdp,
 optimizer=None,
 lr_scheduler=None,
 processing_class=self.processor if self.processor is not None else self.tokenizer,
-checkpoint_load_contents=["model"],
-checkpoint_save_contents=[],
+checkpoint_contents=checkpoint_contents,
 )

 @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -1034,8 +1037,7 @@ def init_model(self):
 optimizer=self.critic_optimizer,
 lr_scheduler=self.critic_lr_scheduler,
 processing_class=self.processor if self.processor is not None else self.tokenizer,
-checkpoint_load_contents=self.config.checkpoint.load_contents,
-checkpoint_save_contents=self.config.checkpoint.save_contents,
+checkpoint_contents=self.config.checkpoint,
 )

 @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
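For reference, a standalone sketch (not part of this diff) of the rollout-only override performed in the first hunk above: the actor's checkpoint config is copied and narrowed so a standalone rollout worker only loads model weights and never saves. The attribute names follow the load_contents / save_contents fields set in that hunk; the initial values are placeholders standing in for self.config.actor.checkpoint.

```python
from copy import deepcopy

from omegaconf import OmegaConf

# Placeholder standing in for self.config.actor.checkpoint.
actor_checkpoint = OmegaConf.create(
    {
        "load_contents": ["model", "optimizer", "extra"],
        "save_contents": ["model", "optimizer", "extra"],
    }
)

# Same narrowing as in ActorRolloutRefWorker.init_model: a rollout-only worker
# loads just the model weights and never writes checkpoints itself.
rollout_checkpoint = deepcopy(actor_checkpoint)
rollout_checkpoint.load_contents = ["model"]
rollout_checkpoint.save_contents = []

print(list(rollout_checkpoint.load_contents))  # ['model']
print(list(rollout_checkpoint.save_contents))  # []
```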
3 changes: 1 addition & 2 deletions verl/workers/megatron_workers.py
@@ -406,8 +406,7 @@ def init_model(self):
 optimizer_scheduler=self.actor_optimizer_scheduler,
 use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
 use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
-checkpoint_load_contents=self.config.actor.checkpoint.load_contents,
-checkpoint_save_contents=self.config.actor.checkpoint.save_contents,
+checkpoint_contents=self.config.actor.checkpoint,
 )
 torch.cuda.empty_cache()
 log_gpu_memory_usage("After init_model finish", logger=logger)