support fsdp2
lxg2015 authored and lixiaoguang12 committed May 1, 2025
commit d6297a6e0539cb5114275c81574b6d3fdcd698de
36 changes: 24 additions & 12 deletions tests/checkpoint/test_fsdp_ckpt.py
@@ -24,9 +24,10 @@

from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.distributed import initialize_global_process_group
from verl.utils.fsdp_utils import fully_shard, MixedPrecisionPolicy, apply_fsdp2


def test_fsdp_ckpt():
def test_fsdp_ckpt(strategy='fsdp'):
assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
local_rank, rank, world_size = initialize_global_process_group()
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))
@@ -39,16 +40,23 @@ def test_fsdp_ckpt():
model = model.to(device="cuda")

# Wrap model with FSDP
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

model = FSDP(
model,
use_orig_params=False,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=device_mesh,
)
if strategy == 'fsdp':
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

model = FSDP(
model,
use_orig_params=False,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=device_mesh)
else:
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True)
fsdp_kwargs = {
'mesh': device_mesh,
'mp_policy': mp_policy,
}
apply_fsdp2(model, fsdp_kwargs)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
@@ -116,7 +124,11 @@ def test_fsdp_ckpt():
# Cleanup
shutil.rmtree(temp_dir)
torch.distributed.barrier()

torch.distributed.destroy_process_group()

if __name__ == "__main__":
test_fsdp_ckpt()
if fully_shard is not None:
print('begin to test fsdp2')
test_fsdp_ckpt(strategy='fsdp2')
print('test fsdp2 passed!')
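For reference, here is a minimal standalone sketch of the fully_shard wrapping pattern the new fsdp2 test branch exercises (via apply_fsdp2). It assumes torch >= 2.6, where fully_shard and MixedPrecisionPolicy are exported from torch.distributed.fsdp, an already-initialized process group on CUDA devices, and a toy nn.Sequential model in place of the test's HuggingFace model:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_toy_model(world_size: int) -> nn.Module:
    # 1-D data-parallel mesh, mirroring the test's init_device_mesh call
    mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))
    model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).cuda()
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
    )
    # fully_shard converts each module in place into an FSDPModule; sharding the
    # submodules first and the root last is the same call sequence apply_fsdp2 uses
    for layer in model:
        fully_shard(layer, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    return model
```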
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_trainer.yaml
@@ -30,7 +30,7 @@ actor_rollout_ref:
use_remove_padding: False
use_liger: False
actor:
strategy: fsdp # This is for backward-compatibility
strategy: fsdp # [fsdp, fsdp2]; this is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
6 changes: 3 additions & 3 deletions verl/trainer/main_ppo.py
@@ -103,8 +103,8 @@ def run(self, config):
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none

# define worker classes
if config.actor_rollout_ref.actor.strategy == "fsdp":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
if config.actor_rollout_ref.actor.strategy in ['fsdp', 'fsdp2']:
assert config.critic.strategy in ['fsdp', 'fsdp2']
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker

@@ -145,7 +145,7 @@ def run(self, config):
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == "fsdp":
if config.reward_model.strategy in ['fsdp', 'fsdp2']:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
17 changes: 10 additions & 7 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -25,7 +25,7 @@
from verl.utils.fs import copy_to_local, is_non_local

from .checkpoint_manager import BaseCheckpointManager

from verl.utils.fsdp_utils import get_fsdp_state_ctx, fsdp_version

class FSDPCheckpointManager(BaseCheckpointManager):
"""
@@ -96,7 +96,7 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte

state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
@@ -129,7 +129,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
model_state_dict = self.model.state_dict()
optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None
lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
@@ -153,11 +153,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
# wait for everyone to dump to local
torch.distributed.barrier()

if self.rank == 0:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)
if self.rank == 0:
hf_local_path = os.path.join(local_path, 'huggingface')
os.makedirs(hf_local_path, exist_ok=True)
if fsdp_version(self.model) == 1:
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
self.processing_class.save_pretrained(hf_local_path)
else:
self.model.config.save_pretrained(hf_local_path)
self.processing_class.save_pretrained(hf_local_path)

torch.distributed.barrier()

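The reason get_fsdp_state_ctx can return a nullcontext() for FSDP2 (see fsdp_utils.py below) is that fully_shard needs no FSDP.state_dict_type(...) wrapper: model.state_dict() already yields DTensor-sharded parameters. A minimal sketch of that behaviour, assuming torch >= 2.6, an initialized process group, and a toy model:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


def sharded_state_dict_demo() -> dict:
    model = nn.Linear(64, 64).cuda()
    fully_shard(model)  # shards over the default (whole-world) device mesh
    state = model.state_dict()  # values are DTensors, already sharded per rank
    for name, value in state.items():
        # each DTensor carries its mesh and placements; .to_local() would give the
        # rank-local shard that a sharded checkpoint save actually writes to disk
        print(name, type(value).__name__, getattr(value, "placements", None))
    return state
```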
70 changes: 68 additions & 2 deletions verl/utils/fsdp_utils.py
@@ -17,7 +17,7 @@
import json
import math
import os
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Dict

import torch
@@ -28,7 +28,13 @@
from torch.distributed.fsdp._runtime_utils import _lazy_init
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from transformers.trainer_pt_utils import get_module_class_from_name

from packaging import version
if version.parse(torch.__version__) >= version.parse('2.6'):
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy
elif version.parse(torch.__version__) >= version.parse('2.4'):
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy
else:
fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None

def init_fn(x: torch.nn.Module):
if torch.distributed.get_rank() != 0:
@@ -111,6 +117,10 @@ def lambda_policy_fn(module):

@torch.no_grad()
def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
if fsdp_version(model) == 2:
offload_fsdp2_model_to_cpu(model, empty_cache)
return

assert isinstance(model, FSDP)
# lazy init FSDP model
_lazy_init(model, model)
@@ -127,9 +137,18 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
if empty_cache:
torch.cuda.empty_cache()

@torch.no_grad()
def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
model.to('cpu', non_blocking=True)
if empty_cache:
torch.cuda.empty_cache()

@torch.no_grad()
def load_fsdp_model_to_gpu(model: FSDP):
Contributor:
What is the call site for load_fsdp_model_to_gpu? When applying FSDP or fully_shard, the model is already moved to CUDA.

Contributor Author:
This is used for the actor model. When FSDP2's offload_policy is disabled but param_offload is enabled, load_fsdp_model_to_gpu is called before actor.update_policy and offload_fsdp_model_to_cpu is called after it, so the behavior matches FSDP1.
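A sketch of the cycle described in the reply above; the wrapper function and the actor/param_offload names are paraphrased from verl's FSDP actor worker and should be read as assumptions, not as this diff's exact code:

```python
from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu


def update_policy_with_param_offload(actor, actor_module_fsdp, data, param_offload: bool):
    # With FSDP2's offload_policy disabled but param_offload enabled, the sharded
    # weights live on CPU between steps and are moved explicitly, as with FSDP1.
    if param_offload:
        load_fsdp_model_to_gpu(actor_module_fsdp)      # bring shards onto the GPU
    metrics = actor.update_policy(data)                # forward/backward/optimizer step
    if param_offload:
        offload_fsdp_model_to_cpu(actor_module_fsdp)   # free GPU memory until next step
    return metrics
```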

if fsdp_version(model) == 2:
load_fsdp2_model_to_gpu(model)
return

assert isinstance(model, FSDP)
# lazy init FSDP model
_lazy_init(model, model)
@@ -143,6 +162,10 @@ def load_fsdp_model_to_gpu(model: FSDP):
# the following still keeps id(._local_shard) != id(.data)
flat_param._local_shard = flat_param.data

@torch.no_grad()
def load_fsdp2_model_to_gpu(model):
device_id = torch.cuda.current_device()
model.to(f"cuda:{device_id}", non_blocking=True)

@torch.no_grad()
def offload_fsdp_optimizer(optimizer):
@@ -333,3 +356,46 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
return sub_mod

return init_fn


def fsdp_version(model):
if isinstance(model, FSDP):
return 1
elif isinstance(model, FSDPModule):
return 2
else:
return 0


def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):
if fsdp_version(model) == 1:
return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg)
else:
return nullcontext()


def fsdp2_sharding_strategy(device_mesh):
sharding_strategy = False # zero1,2
if device_mesh.ndim == 1:
sharding_strategy = True # zero3
elif device_mesh.ndim == 2:
sharding_strategy = torch.cuda.device_count() # hsdp
Contributor (@weifengpy, Apr 23, 2025):
Why is reshard_after_forward decided by device_mesh.ndim? I might be missing some context here. For HSDP, fully_shard(device_mesh=(replicate, shard)) accepts reshard_after_forward=True/False/int; device_mesh and reshard_after_forward look orthogonal to me.

Contributor Author (@lxg2015, Apr 24, 2025):
verl initializes the device_mesh with 2 dims for HSDP (as described in its init), so a device_mesh-specific decision is made here, the same as for FSDP1. Or should it be changed to this:

sharding_strategy = device_mesh.get_group(-1).size()

Contributor Author (@lxg2015):
Hi @weifengpy, let me explain what I understand: reshard_after_forward only controls the number of shards for the parameters, while in device_mesh=(replicate, shard) the size of the shard dim determines the number of shards for the optimizer state and the gradients, so the two are orthogonal. Did I understand correctly? If so, should I expose the reshard_after_forward parameter to the user regardless of whether HSDP is used? Thanks.

Contributor (@weifengpy):
> Did I understand correctly?
Your understanding is correct; they are orthogonal.
> Should I expose the reshard_after_forward parameter to the user regardless of whether it is HSDP or not?
Yes, it should be exposed to the user. In practice we set reshard_after_forward=False for FSDP2 + pipeline parallel.

Contributor Author (@lxg2015):
Thanks, I have exposed this parameter in ppo_trainer.yaml. Since verl does not currently support pipeline parallel, I have not set it there.

else:
raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
return sharding_strategy


def apply_fsdp2(model, fsdp_kwargs):
'''model: AutoModelForCausalLM
'''
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

fsdp_mesh = fsdp_kwargs.get('mesh')

# refer torchtitan
for idx, layer in enumerate(model.model.layers):
reshard_after_forward = fsdp2_sharding_strategy(fsdp_mesh)
if idx == len(model.model.layers) - 1:
reshard_after_forward = False
fully_shard(layer, **fsdp_kwargs, reshard_after_forward=reshard_after_forward)
fully_shard(model, **fsdp_kwargs, reshard_after_forward=False)
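Following the thread above, here is a sketch of what passing reshard_after_forward through explicitly could look like on an HSDP mesh. The function name and the (2, 4) mesh shape (8 GPUs, replicate x shard) are illustrative assumptions, not verl's final interface:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_layers_hsdp(model, layers, reshard_after_forward=True):
    # 2-D (replicate, shard) mesh -> HSDP; a 1-D mesh would be plain FSDP (ZeRO-3-like)
    mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("ddp", "fsdp"))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for idx, layer in enumerate(layers):
        # as in the torchtitan recipe, the last layer skips resharding because its
        # parameters are needed again as soon as backward starts
        last = idx == len(layers) - 1
        fully_shard(layer, mesh=mesh, mp_policy=mp_policy,
                    reshard_after_forward=False if last else reshard_after_forward)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=False)
    return model
```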