Commit 1466760

[misc] fix: no need to use world_size to decide whether to use full_tensor in FSDP2 (#1529)
[misc] fix: no need to use world_size to decide whether to use full_tensor() for FSDP2 state_dict() when world_size==1

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

This PR simplifies the parameter-loading logic in `FSDPVLLMShardingManager` by removing an unnecessary `world_size` check when deciding whether to call `full_tensor()` on parameters obtained from an FSDP2 model's `state_dict()`. Under FSDP2, sharded parameters are `DTensor` instances regardless of world size, so the type alone determines whether gathering is needed.

### High-Level Design

The change modifies the `update_params` method. When loading weights into the vLLM model, parameters from the FSDP `state_dict()` that are `DTensor` instances (which they are under FSDP2 even when `world_size == 1`) are converted to full tensors via `param.full_tensor()`. The conversion now happens whenever the parameter is a `DTensor`, without the additional, potentially incorrect, `world_size != 1` check.

### Specific Changes

Skip. See file changes.

### API

No

### Usage Example

No

### Test

No CI changes.

### Additional Info.

- **Issue Number**: No
- **Training**: FSDP
- **Inference**: vLLM

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add CI test(s) if necessary.
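The core idea of the change — gather a parameter into a full tensor exactly when it is a `DTensor`, with no extra `world_size` condition — can be sketched in isolation. The snippet below is a hypothetical illustration, not verl's actual code: `FakeDTensor` and `materialize` are stand-ins invented for this example so it runs without an initialized process group (a real `torch.distributed.tensor.DTensor.full_tensor()` all-gathers shards across ranks).

```python
# Hypothetical sketch of the type-based dispatch the PR adopts.
# FakeDTensor stands in for torch.distributed.tensor.DTensor.

class FakeDTensor:
    """Stand-in for a sharded tensor; full_tensor() reassembles the full value."""

    def __init__(self, shard):
        self.shard = shard

    def full_tensor(self):
        # A real DTensor would all-gather shards from every rank here;
        # we pretend two ranks each hold an identical shard.
        return self.shard * 2


def materialize(param):
    # The PR's rule: gather if and only if the param is a DTensor.
    # No world_size check — under FSDP2 params are DTensors even
    # when world_size == 1, and plain tensors pass through untouched.
    return param.full_tensor() if isinstance(param, FakeDTensor) else param


params = {"w": FakeDTensor(3), "b": 5}
full = {name: materialize(p) for name, p in params.items()}
print(full)  # {'w': 6, 'b': 5}
```

The previous `world_size != 1 and hasattr(param, "full_tensor")` condition would have skipped the gather in single-rank runs and handed a still-sharded `DTensor` to `load_weights`; dispatching on the type removes that edge case.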
1 parent 11622fc commit 1466760

File tree

1 file changed: +9 −4 lines


verl/workers/sharding_manager/fsdp_vllm.py

Lines changed: 9 additions & 4 deletions
@@ -21,6 +21,12 @@
 from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

+try:
+    # for torch 2.5+
+    from torch.distributed.tensor import DTensor
+except ImportError:
+    from torch.distributed._tensor import DTensor
+
 from verl import DataProto
 from verl.protocol import all_gather_data_proto
 from verl.third_party.vllm import LLM, vllm_version
@@ -52,13 +58,13 @@ def __init__(
         self.inference_engine = inference_engine
         # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None

-        if 'vllm_v_0_6_3' in str(type(self.inference_engine)) or 'vllm_v_0_5_4' in str(type(self.inference_engine)):
+        if "vllm_v_0_6_3" in str(type(self.inference_engine)) or "vllm_v_0_5_4" in str(type(self.inference_engine)):
             # vLLM <= v0.6.3
             self.model_runner = self.inference_engine.llm_engine.model_executor.worker.model_runner if self.inference_engine else None
         else:
             # vLLM > v0.6.3
             self.model_runner = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if self.inference_engine else None
-
+
         self.model_config = model_config
         self.device_mesh = device_mesh
         self.offload_param = offload_param
@@ -188,7 +194,6 @@ def postprocess_data(self, data: DataProto) -> DataProto:
     def update_params(self, updated_params):
         model = self.model_runner.model
         patch_vllm_moe_model_weight_loader(model)
-        world_size = torch.distributed.get_world_size()
         device = torch.cuda.current_device()  # used when fsdp2 set cpu_offload_policy
-        loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) for name, param in updated_params.items()))
+        loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items()))
         logger.info("vLLM load weights, loaded_params: %d", len(loaded_params))

0 commit comments
