Skip to content
Prev Previous commit
Next Next commit
Check modules_to_save for prompt tuning methods
  • Loading branch information
nemo committed Apr 9, 2025
commit 777448bbe1115ed4f2c53357b45834ab03fdebd6
11 changes: 6 additions & 5 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ def delete_adapter(self, adapter_name: str) -> None:
def modules_to_save(self) -> Optional[set[str]]:
modules: set[str] = set()
for config in self.peft_config.values():
if config.modules_to_save:
if hasattr(config, "modules_to_save"):
# modules_to_save can only be a sequence of str, not a str
modules.update(config.modules_to_save)

Expand Down Expand Up @@ -1523,10 +1523,11 @@ def __init__(
# config is relevant for this, the `modules_to_save` attribute can follow later.
super().__init__(model, peft_config, adapter_name, **kwargs)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break
if hasattr(peft_config, "modules_to_save"):
for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))
Expand Down
1 change: 0 additions & 1 deletion src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,6 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:

# Adapters should never match modules listed in `modules_to_save`, since that would guarantee
# conflicting behavior between the `ModulesToSaveWrapper` internals and the potential adapter.
# TODO extend this to AuxiliaryTrainingWrapper in this PR if possible
modules_to_save = getattr(config, "modules_to_save", None)
if modules_to_save:
if any(re.match(rf"(^|.*\.){m}($|\..*)", key) for m in modules_to_save):
Expand Down
Loading