Skip to content
Prev Previous commit
fix wrap_policy usage in test
  • Loading branch information
PeterSH6 committed May 2, 2025
commit 97a0f239f9fb0af5bd640d47035da6bba2a68221
4 changes: 2 additions & 2 deletions verl/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ def apply_fsdp2(model, fsdp_kwargs, config):
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = config.wrap_policy.get("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap)

fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap)
if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

Expand Down