[fsdp] feat: support fsdp2 training and inference in fsdp_workers #1026
@@ -17,7 +17,7 @@
 import json
 import math
 import os
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Dict

 import torch
|
|
@@ -28,7 +28,13 @@
 from torch.distributed.fsdp._runtime_utils import _lazy_init
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
 from transformers.trainer_pt_utils import get_module_class_from_name

+from packaging import version
+if version.parse(torch.__version__) >= version.parse('2.6'):
+    from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy
+elif version.parse(torch.__version__) >= version.parse('2.4'):
+    from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy
+else:
+    fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None

 def init_fn(x: torch.nn.Module):
     if torch.distributed.get_rank() != 0:
|
|
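These version-gated imports expose the FSDP2 building blocks (`fully_shard`, `MixedPrecisionPolicy`, `FSDPModule`, `CPUOffloadPolicy`) only on torch >= 2.4, switching to the public `torch.distributed.fsdp` path on 2.6. A minimal sketch of how a caller might probe the gate and build the policies that later go into the `fully_shard` kwargs; the dtype and offload choices are illustrative assumptions, not values taken from this PR:

```python
import torch
# assumes the version-gated names above: fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy

def build_fsdp2_policies(param_offload: bool = False):
    # fully_shard is None on torch < 2.4, so fail early with a clear message
    if fully_shard is None:
        raise RuntimeError("FSDP2 (fully_shard) requires torch >= 2.4")
    # bf16 parameters for compute, fp32 gradient reduction (illustrative choice)
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    # pass offload_policy=CPUOffloadPolicy() to fully_shard only when param offload is wanted
    offload_policy = CPUOffloadPolicy() if param_offload else None
    return mp_policy, offload_policy
```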
@@ -111,6 +117,10 @@ def lambda_policy_fn(module):

 @torch.no_grad()
 def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
+    if fsdp_version(model) == 2:
+        offload_fsdp2_model_to_cpu(model, empty_cache)
+        return
+
     assert isinstance(model, FSDP)
     # lazy init FSDP model
     _lazy_init(model, model)
|
|
@@ -127,9 +137,18 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
     if empty_cache:
         torch.cuda.empty_cache()

+@torch.no_grad()
+def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
+    model.to('cpu', non_blocking=True)
+    if empty_cache:
+        torch.cuda.empty_cache()

 @torch.no_grad()
 def load_fsdp_model_to_gpu(model: FSDP):
+    if fsdp_version(model) == 2:
+        load_fsdp2_model_to_gpu(model)
+        return
+
     assert isinstance(model, FSDP)
     # lazy init FSDP model
     _lazy_init(model, model)
|
|
@@ -143,6 +162,10 @@ def load_fsdp_model_to_gpu(model: FSDP):
         # the following still keeps id(._local_shard) != id(.data)
         flat_param._local_shard = flat_param.data

+@torch.no_grad()
+def load_fsdp2_model_to_gpu(model):
+    device_id = torch.cuda.current_device()
+    model.to(f"cuda:{device_id}", non_blocking=True)

 @torch.no_grad()
 def offload_fsdp_optimizer(optimizer):
|
|
@@ -333,3 +356,46 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True):
         return sub_mod

     return init_fn
+
+
+def fsdp_version(model):
+    if isinstance(model, FSDP):
+        return 1
+    elif isinstance(model, FSDPModule):
+        return 2
+    else:
+        return 0
+
+
+def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):
+    if fsdp_version(model) == 1:
+        return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg)
+    else:
+        return nullcontext()
+
+
+def fsdp2_sharding_strategy(device_mesh):
+    sharding_strategy = False  # zero1,2
+    if device_mesh.ndim == 1:
+        sharding_strategy = True  # zero3
+    elif device_mesh.ndim == 2:
+        sharding_strategy = torch.cuda.device_count()  # hsdp
+    else:
+        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
+    return sharding_strategy
+
+
+def apply_fsdp2(model, fsdp_kwargs):
+    '''model: AutoModelForCausalLM
+    '''
+    assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+
+    fsdp_mesh = fsdp_kwargs.get('mesh')
+
+    # refer torchtitan
+    for idx, layer in enumerate(model.model.layers):
+        reshard_after_forward = fsdp2_sharding_strategy(fsdp_mesh)
+        if idx == len(model.model.layers) - 1:
+            reshard_after_forward = False
+        fully_shard(layer, **fsdp_kwargs, reshard_after_forward=reshard_after_forward)
+    fully_shard(model, **fsdp_kwargs, reshard_after_forward=False)
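For context, a hedged sketch of how these new helpers could be wired together at a call site: build a 1-D device mesh, shard the model with `apply_fsdp2`, then save a sharded checkpoint through `get_fsdp_state_ctx` (a no-op context for FSDP2, since its `state_dict()` already returns DTensor shards). Here `model` stands for a Hugging Face causal-LM module as in the docstring; the mesh shape, dtypes, and checkpoint configs are illustrative assumptions, not values from this PR:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import StateDictType, ShardedStateDictConfig, ShardedOptimStateDictConfig

world_size = torch.distributed.get_world_size()
# 1-D mesh -> fsdp2_sharding_strategy(...) returns True, i.e. reshard after forward (ZeRO-3-like)
mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))

fsdp_kwargs = {
    "mesh": mesh,
    "mp_policy": MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    # "offload_policy": CPUOffloadPolicy(),  # optional CPU offload of sharded params/grads
}
apply_fsdp2(model, fsdp_kwargs)  # shards each decoder layer, then the root module

# Uniform checkpoint path across FSDP1/FSDP2: FSDP1 enters FSDP.state_dict_type(...),
# FSDP2 gets a nullcontext and state_dict() is used directly.
with get_fsdp_state_ctx(model, StateDictType.SHARDED_STATE_DICT,
                        ShardedStateDictConfig(offload_to_cpu=True),
                        ShardedOptimStateDictConfig(offload_to_cpu=True)):
    sharded_state = model.state_dict()
```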
Review comment: what's the callsite for `load_fsdp_model_to_gpu`? When applying `FSDP` or `fully_shard`, we move the model to CUDA.
Reply: This runs on the actor model. When FSDP2's `offload_policy` is disabled while `param_offload` is enabled, the worker calls `load_fsdp_model_to_gpu` before `actor.update_policy` and `offload_fsdp_model_to_cpu` after `actor.update_policy`, which matches the FSDP1 behavior.
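A minimal sketch of the flow described in this reply, assuming hypothetical worker-side names (`actor`, `actor.actor_module`, and `data` are placeholders, not verl's exact API):

```python
# Param-offload round trip around the policy update, as described above:
# parameters rest on CPU between steps and are moved to GPU only for update_policy.
def update_policy_with_param_offload(actor, data, param_offload: bool = True):
    if param_offload:
        load_fsdp_model_to_gpu(actor.actor_module)       # FSDP1 or FSDP2 path via fsdp_version()
    metrics = actor.update_policy(data)                  # forward/backward/optimizer step on GPU
    if param_offload:
        offload_fsdp_model_to_cpu(actor.actor_module)    # release GPU memory until the next step
    return metrics
```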