diff --git a/docs/experiment/ppo.rst b/docs/experiment/ppo.rst index 679c663e033..942cca54b62 100644 --- a/docs/experiment/ppo.rst +++ b/docs/experiment/ppo.rst @@ -27,6 +27,7 @@ NVIDIA GPUs .. _Qwen0.5b PRIME Script: https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen.sh .. _Qwen0.5b PRIME Wandb: https://api.wandb.ai/links/zefan-wang-thu-tsinghua-university/rxd1btvb .. _Megatron Qwen2 7b GRPO Script with Math and GSM8k: https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b_math_megatron.log +.. _Qwen7b GRPO FSDP2 Script and Logs: https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log +----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+ | Model | Method | Test score | Details | @@ -47,6 +48,8 @@ NVIDIA GPUs +----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+ | Qwen/Qwen2-7B-Instruct | GRPO | 89 | `Qwen7b GRPO Script`_ | +----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+ +| Qwen/Qwen2-7B-Instruct | GRPO (FSDP2) | 89.8 | `_Qwen7b GRPO FSDP2 Script and Logs`_ | ++----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+ | Qwen/Qwen2-7B-Instruct | GRPO (Megatron) | 89.6 | `Megatron Qwen2 7b GRPO Script with Math and GSM8k`_ | +----------------------------------+------------------------+------------+-----------------------------------------------------------------------------------------------+ | Qwen/Qwen2.5-7B-Instruct | ReMax | 97 | `Qwen7b ReMax Script`_, `Qwen7b ReMax Wandb`_ | diff --git a/tests/checkpoint/test_fsdp_ckpt.py b/tests/checkpoint/test_fsdp_ckpt.py index e1d6deeee21..3560da3a4f6 100644 --- a/tests/checkpoint/test_fsdp_ckpt.py +++ b/tests/checkpoint/test_fsdp_ckpt.py @@ -24,9 +24,10 @@ from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.distributed import initialize_global_process_group +from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, fully_shard -def test_fsdp_ckpt(): +def test_fsdp_ckpt(strategy="fsdp"): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" local_rank, rank, world_size = initialize_global_process_group() device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) @@ -39,16 +40,24 @@ def test_fsdp_ckpt(): model = model.to(device="cuda") # Wrap model with FSDP - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - - model = FSDP( - model, - use_orig_params=False, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - device_mesh=device_mesh, - ) + if strategy == "fsdp": + mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + + model = FSDP( + model, + use_orig_params=False, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + device_mesh=device_mesh, + ) + else: + mp_policy = 
MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) + fsdp_kwargs = { + "mesh": device_mesh, + "mp_policy": mp_policy, + } + apply_fsdp2(model, fsdp_kwargs, {}) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) @@ -116,7 +125,12 @@ def test_fsdp_ckpt(): # Cleanup shutil.rmtree(temp_dir) torch.distributed.barrier() + torch.distributed.destroy_process_group() if __name__ == "__main__": test_fsdp_ckpt() + if fully_shard is not None: + print("begin to test fsdp2") + test_fsdp_ckpt(strategy="fsdp2") + print("test fsdp2 passed!") diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 1b6668dce8a..4af2490a2e9 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -30,7 +30,7 @@ actor_rollout_ref: use_remove_padding: False use_liger: False actor: - strategy: fsdp # This is for backward-compatibility + strategy: fsdp # [fsdp, fsdp2], This is for backward-compatibility ppo_mini_batch_size: 256 ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: null @@ -67,11 +67,14 @@ actor_rollout_ref: min_num_params: 0 param_offload: False optimizer_offload: False + offload_policy: False # only for fsdp2, offload param\grad\optimizer during train + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] fsdp_size: -1 ref: strategy: fsdp fsdp_config: param_offload: False + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] wrap_policy: # transformer_layer_cls_to_wrap: None min_num_params: 0 @@ -129,7 +132,7 @@ actor_rollout_ref: critic: rollout_n: ${actor_rollout_ref.rollout.n} - strategy: fsdp + strategy: fsdp # [fsdp, fsdp2] optim: lr: 1e-5 lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime @@ -147,6 +150,8 @@ critic: fsdp_config: param_offload: False optimizer_offload: False + offload_policy: False # only for fsdp2, offload param\grad\optimizer during train + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] wrap_policy: # transformer_layer_cls_to_wrap: None min_num_params: 0 @@ -179,6 +184,7 @@ reward_model: wrap_policy: min_num_params: 0 param_offload: False + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] fsdp_size: -1 micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu micro_batch_size_per_gpu: null # set a number diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index fab382a37bb..72c9846beab 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -103,8 +103,8 @@ def run(self, config): processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none # define worker classes - if config.actor_rollout_ref.actor.strategy == "fsdp": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: + assert config.critic.strategy in ["fsdp", "fsdp2"] from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker @@ -145,7 +145,7 @@ def run(self, config): # - finally, we combine all the rewards together # - The reward type depends on the tag of the data if config.reward_model.enable: - if config.reward_model.strategy == "fsdp": + if config.reward_model.strategy in ["fsdp", "fsdp2"]: from verl.workers.fsdp_workers import RewardModelWorker elif config.reward_model.strategy == "megatron": from verl.workers.megatron_workers import RewardModelWorker diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index 9778390ed72..a3e4303e6c7 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -23,6 +23,7 @@ from transformers import PreTrainedTokenizer, ProcessorMixin from verl.utils.fs import copy_to_local, is_non_local +from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx from .checkpoint_manager import BaseCheckpointManager @@ -96,7 +97,7 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) - with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): self.model.load_state_dict(model_state_dict) if self.optimizer is not None: self.optimizer.load_state_dict(optimizer_state_dict) @@ -129,7 +130,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) with warnings.catch_warnings(): warnings.simplefilter("ignore") - with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): model_state_dict = self.model.state_dict() optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None lr_scheduler_state_dict = self.lr_scheduler.state_dict() if 
self.lr_scheduler is not None else None @@ -153,11 +154,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i # wait for everyone to dump to local torch.distributed.barrier() - if self.rank == 0: - hf_local_path = os.path.join(local_path, "huggingface") - os.makedirs(hf_local_path, exist_ok=True) + if self.rank == 0: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + if fsdp_version(self.model) == 1: self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path) - self.processing_class.save_pretrained(hf_local_path) + else: + self.model.config.save_pretrained(hf_local_path) + self.processing_class.save_pretrained(hf_local_path) torch.distributed.barrier() diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 310b5527dad..6d241611b0b 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -17,18 +17,26 @@ import json import math import os -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Dict import torch import torch.distributed as dist import torch.nn as nn +from packaging import version from torch.distributed import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name +if version.parse(torch.__version__) >= version.parse("2.6"): + from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard +elif version.parse(torch.__version__) >= version.parse("2.4"): + from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard +else: + fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None + def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: @@ -111,6 +119,10 @@ def lambda_policy_fn(module): @torch.no_grad() def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): + if fsdp_version(model) == 2: + offload_fsdp2_model_to_cpu(model, empty_cache) + return + assert isinstance(model, FSDP) # lazy init FSDP model _lazy_init(model, model) @@ -128,8 +140,20 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): torch.cuda.empty_cache() +@torch.no_grad() +def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): + for param in model.parameters(): + param.data = param.data.to(torch.device("cpu"), non_blocking=True) + if empty_cache: + torch.cuda.empty_cache() + + @torch.no_grad() def load_fsdp_model_to_gpu(model: FSDP): + if fsdp_version(model) == 2: + load_fsdp2_model_to_gpu(model) + return + assert isinstance(model, FSDP) # lazy init FSDP model _lazy_init(model, model) @@ -144,6 +168,13 @@ def load_fsdp_model_to_gpu(model: FSDP): flat_param._local_shard = flat_param.data +@torch.no_grad() +def load_fsdp2_model_to_gpu(model): + device = torch.cuda.current_device() + for param in model.parameters(): + param.data = param.data.to(device, non_blocking=True) + + @torch.no_grad() def offload_fsdp_optimizer(optimizer): if not optimizer.state: @@ -333,3 +364,88 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): return sub_mod return init_fn + + +def fsdp_version(model): + if isinstance(model, FSDP): + return 1 + elif isinstance(model, FSDPModule): + return 2 + else: + return 0 + + +def 
get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg): + if fsdp_version(model) == 1: + return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg) + else: + return nullcontext() + + +def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. + + Args: + model (`torch.nn.Module`): The model to load the state dict into + full_state (`dict`): The full state dict to load, can only be on rank 0 + """ + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + + # To broadcast, it needs to be instantiated in the GPU. + if dist.get_rank() == 0: + model = model.to(device=torch.cuda.current_device(), non_blocking=True) + else: + model = model.to_empty(device=torch.cuda.current_device()) + + cpu_offload = cpu_offload is not None + options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True) + set_model_state_dict(model, full_state, options=options) + + # rotary_emb is not in state_dict, so we need to broadcast it manually + for name, buf in model.named_buffers(): + dist.broadcast(buf, src=0) + + if cpu_offload: + model.to("cpu", non_blocking=True) + for buf in model.buffers(): + buf.data = buf.data.to(torch.cuda.current_device()) + + +def apply_fsdp2(model, fsdp_kwargs, config): + """model: AutoModelForCausalLM""" + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + + default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) + fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap) + + if isinstance(fsdp_transformer_layer_cls_to_wrap, str): + fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] + + assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None + + modules = [] + for name, module in model.named_modules(): + if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings): + modules.append(module) + + for idx, module in enumerate(modules): + fully_shard(module, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) # fsdp2 will not reshard_after_forward for root module + + +def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor""" + from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + else: + # prevent generators from being exhausted + parameters = list(parameters) + grads = [p.grad for p in parameters if p.grad is not None] + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True) + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 3b26d81fe6d..b6d970d3a48 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -31,6 +31,7 @@ from verl import DataProto from verl.trainer.ppo.core_algos 
import agg_loss, compute_policy_loss, kl_penalty from verl.utils.debug import GPUMemoryLogger +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import logprobs_from_logits @@ -161,6 +162,8 @@ def _optimizer_step(self): if isinstance(self.actor_module, FSDP): grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) + elif isinstance(self.actor_module, FSDPModule): + grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) else: grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index f87dc3a9cf4..3e63deb0850 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -28,6 +28,7 @@ from verl import DataProto from verl.trainer.ppo import core_algos from verl.utils.debug import GPUMemoryLogger +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import masked_mean @@ -114,6 +115,8 @@ def _optimizer_step(self): if isinstance(self.critic_module, FSDP): grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip) + elif isinstance(self.critic_module, FSDPModule): + grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) else: grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index c1b501c0cb9..0ed28e6f57b 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -37,6 +37,11 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + fsdp2_load_full_state_dict, + fsdp_version, get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, @@ -260,19 +265,43 @@ def _build_model_optimizer( # We force reference policy to use CPUOffload to save memory. 
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) - actor_module_fsdp = FSDP( - actor_module, - cpu_offload=cpu_offload, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, # zero3 - mixed_precision=mixed_precision, - sync_module_states=True, - device_mesh=self.device_mesh, - forward_prefetch=False, - ) + fsdp_strategy = self.config.actor.strategy + if fsdp_strategy == "fsdp": + actor_module_fsdp = FSDP( + actor_module, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, # zero3 + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + forward_prefetch=False, + ) + elif fsdp_strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True) + if role == "actor" and fsdp_config.offload_policy: + cpu_offload = CPUOffloadPolicy(pin_memory=True) + self._is_offload_param = False + self._is_offload_optimizer = False + else: + cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": fsdp_config.reshard_after_forward, + } + full_state = actor_module.state_dict() + apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) + actor_module_fsdp = actor_module + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) @@ -463,7 +492,8 @@ def init_model(self): ) # get the original unwrapped module - self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + if fsdp_version(self.actor_module_fsdp) == 1: + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @@ -614,7 +644,7 @@ def compute_log_prob(self, data: DataProto): # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1: + if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: self.actor.actor_module._handle.reshard(True) if self._is_offload_param: @@ -645,7 +675,7 @@ def compute_ref_log_prob(self, data: DataProto): # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module - if self.world_size > 1: + if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1: self.ref_policy.actor_module._handle.reshard(True) return output @@ -809,19 +839,40 @@ def _build_critic_model_optimizer(self, config): sharding_strategy = get_sharding_strategy(fsdp_mesh) # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation - critic_module = FSDP( - critic_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - 
device_mesh=self.device_mesh, - cpu_offload=None, - ) + if config.strategy == "fsdp": + critic_module = FSDP( + critic_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + elif config.strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True) + offload_policy = None + if fsdp_config.offload_policy: + self._is_offload_param = False + self._is_offload_optimizer = False + offload_policy = CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": offload_policy, + "reshard_after_forward": fsdp_config.reshard_after_forward, + } + full_state = critic_module.state_dict() + apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) + else: + raise NotImplementedError(f"Unknown strategy {config.strategy}") log_gpu_memory_usage("After critic FSDP", logger=None) @@ -1051,19 +1102,32 @@ def _build_model(self, config): fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) - reward_module = FSDP( - reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, # zero3 - sync_module_states=True, - cpu_offload=CPUOffload(offload_params=True), - forward_prefetch=False, - device_mesh=self.device_mesh, - ) - + if config.strategy == "fsdp": + reward_module = FSDP( + reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, # zero3 + sync_module_states=True, + cpu_offload=CPUOffload(offload_params=True), + forward_prefetch=False, + device_mesh=self.device_mesh, + ) + elif config.strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + cpu_offload = CPUOffloadPolicy(pin_memory=True) + fsdp_kwargs = { + "mesh": fsdp_mesh, + "offload_policy": cpu_offload, + "reshard_after_forward": config.model.fsdp_config.reshard_after_forward, + } + full_state = reward_module.state_dict() + apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config) + fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload) + else: + raise NotImplementedError(f"Unknown strategy: {config.strategy}") return reward_module @register(dispatch_mode=Dispatch.ONE_TO_ALL) diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index bc2ff6775d1..639a954a5c9 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -45,13 +45,12 @@ from verl import DataProto from verl.protocol import all_gather_data_proto from verl.utils.debug import log_gpu_memory_usage -from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu +from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.torch_functional import broadcast_dict_tensor, 
check_cuda_is_available from .base import BaseShardingManager # from vllm.distributed import parallel_state as sglang_ps - logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -81,9 +80,9 @@ def __init__( # Full params self.full_params = full_params - if full_params: + if full_params and fsdp_version(self.module) == 1: FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()) - else: + elif fsdp_version(self.module) == 1: FSDP.set_state_dict_type( self.module, state_dict_type=StateDictType.SHARDED_STATE_DICT, @@ -108,6 +107,8 @@ def __enter__(self): load_fsdp_model_to_gpu(self.module) params = self.module.state_dict() log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy + params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()} # Copy, not share memory self.update_weights(params) log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 8ed19882edf..655d769a56d 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -26,7 +26,7 @@ from verl.third_party.vllm import LLM, vllm_version from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu +from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.torch_functional import check_cuda_is_available from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader @@ -57,9 +57,9 @@ def __init__( # Full params self.full_params = full_params - if full_params: + if full_params and fsdp_version(self.module) == 1: FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig()) - else: + elif fsdp_version(self.module) == 1: FSDP.set_state_dict_type( self.module, state_dict_type=StateDictType.SHARDED_STATE_DICT, @@ -181,5 +181,6 @@ def update_params(self, updated_params): model = self.model_runner.model patch_vllm_moe_model_weight_loader(model) world_size = torch.distributed.get_world_size() - loaded_params = model.load_weights(((name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) for name, param in updated_params.items())) + device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy + loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) for name, param in updated_params.items())) logger.info("vLLM load weights, loaded_params: %d", len(loaded_params))
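
Usage note (not part of the patch): the changes above wire FSDP2, i.e. the composable `fully_shard` API, through the trainer config (`strategy: fsdp2`), the FSDP workers, the checkpoint manager, and the sharding managers. The sketch below is a minimal illustration of how the new helpers in `verl/utils/fsdp_utils.py` are meant to fit together, mirroring the `strategy == "fsdp2"` branch of `_build_model_optimizer`. It assumes torch >= 2.4, a multi-GPU `torchrun` launch, and an illustrative model name; it is not code taken from the PR.

# Minimal sketch (assumptions as above): wrap a HF causal LM with FSDP2 and refill the
# sharded parameters from a rank-0 full state dict.
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from verl.utils.distributed import initialize_global_process_group
from verl.utils.fsdp_utils import (
    MixedPrecisionPolicy,
    apply_fsdp2,
    fsdp2_load_full_state_dict,
)


def build_fsdp2_model(model_name: str = "Qwen/Qwen2-0.5B-Instruct"):  # name is illustrative
    _, _, world_size = initialize_global_process_group()
    mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
    full_state = model.state_dict()  # kept so the sharded modules can be refilled below

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
    )
    fsdp_kwargs = {"mesh": mesh, "mp_policy": mp_policy, "reshard_after_forward": True}

    # Shards each transformer block (and untied embeddings), then the root module.
    # An empty config falls back to the model's _no_split_modules for the wrap policy.
    apply_fsdp2(model, fsdp_kwargs, {})
    # Broadcasts the rank-0 full state dict into the sharded parameters in place.
    fsdp2_load_full_state_dict(model, full_state, mesh, cpu_offload=None)
    return model

Enabling the path end to end is then a config switch, e.g. `actor_rollout_ref.actor.strategy=fsdp2` and `critic.strategy=fsdp2`, optionally with `offload_policy=True` and a custom `reshard_after_forward` in the corresponding `fsdp_config` sections.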
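
The optimizer step also needs a DTensor-aware gradient clip, since `clip_grad_norm_` only exists on the FSDP1 wrapper. The fragment below shows the dispatch pattern the patch adds to `dp_actor.py` / `dp_critic.py`; the surrounding names (`module`, `grad_clip`, `clip_and_step`) are hypothetical, and the extra `FSDPModule is not None` guard is a defensive addition for older torch builds where the FSDP2 symbols resolve to None.

# Hypothetical optimizer-step fragment: FSDP1 clips on the wrapper, FSDP2 modules go
# through the new DTensor-safe helper, anything else uses the vanilla torch utility.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_


def clip_and_step(module: torch.nn.Module, optimizer, grad_clip: float = 1.0):
    if isinstance(module, FSDP):
        grad_norm = module.clip_grad_norm_(max_norm=grad_clip)
    elif FSDPModule is not None and isinstance(module, FSDPModule):
        grad_norm = fsdp2_clip_grad_norm_(module.parameters(), max_norm=grad_clip)
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=grad_clip)

    # Skip the update on a non-finite norm, otherwise apply it.
    if torch.isfinite(grad_norm):
        optimizer.step()
    optimizer.zero_grad()
    return grad_norm

The same dispatch-on-`fsdp_version` idea is what lets `offload_fsdp_model_to_cpu` / `load_fsdp_model_to_gpu` and the checkpoint manager's `get_fsdp_state_ctx` serve both wrappers without changing their call sites.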