Reviewer comments
nemo committed Apr 7, 2025
commit a5267bf6dfbf8c380d9d5ef7ec1bd33aede7e970
12 changes: 7 additions & 5 deletions src/peft/peft_model.py
@@ -57,7 +57,6 @@
_prepare_prompt_learning_config,
_set_adapter,
_set_trainable,
- get_modules_to_save_from_config,
get_peft_model_state_dict,
id_tensor_storage,
infer_device,
@@ -954,7 +953,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
else:
self.modules_to_save.update(peft_config.modules_to_save)
# this may add a new ModulesToSaveWrapper
- _set_trainable(self, adapter_name, module_names=get_modules_to_save_from_config(peft_config))
+ _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
@@ -1492,6 +1491,9 @@ def __init__(
else:
peft_config.modules_to_save.extend(classifier_module_names)

+ # The modification of peft_config must happen before the init call, as the `modules_to_save` information
+ # is used to prevent the target layer matching from picking up `modules_to_save` layers. Only the
+ # config is relevant for this; the `modules_to_save` attribute can follow later.
super().__init__(model, peft_config, adapter_name, **kwargs)

if self.modules_to_save is None:
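As an aside (not part of the diff), here is a minimal sketch of the ordering constraint described in the comment added above: the config must carry `modules_to_save` before the base `__init__` runs, because adapter injection consults the config, not the wrapper attribute. Class names, the `Cfg` namespace, and the matching logic below are simplified assumptions, not the actual PEFT internals.

```python
from types import SimpleNamespace

import torch.nn as nn


class BaseWrapper:
    """Stand-in for the base PeftModel __init__, which performs adapter injection."""

    def __init__(self, model, peft_config, adapter_name):
        # Injection reads `modules_to_save` from the *config* to decide which layers to skip,
        # so the config must already be populated at this point.
        skip = getattr(peft_config, "modules_to_save", None) or []
        self.targeted = [
            name
            for name, _ in model.named_modules()
            if name and not any(name == m or name.endswith("." + m) for m in skip)
        ]


class SeqClsWrapper(BaseWrapper):
    """Stand-in for a classification wrapper such as PeftModelForSequenceClassification."""

    def __init__(self, model, peft_config, adapter_name):
        # 1) extend the config first ...
        peft_config.modules_to_save = ["classifier", "score"]
        # 2) ... so the base __init__ can exclude those layers while matching targets
        super().__init__(model, peft_config, adapter_name)
        # 3) the wrapper-level attribute can be filled in afterwards
        self.modules_to_save = set(peft_config.modules_to_save)


# Tiny usage example with a hypothetical config object:
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
wrapped = SeqClsWrapper(model, SimpleNamespace(modules_to_save=None), "default")
```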
@@ -1505,7 +1507,7 @@ def __init__(
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
- _set_trainable(self, adapter_name, module_names=get_modules_to_save_from_config(peft_config))
+ _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
@@ -2296,7 +2298,7 @@ def __init__(
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
- _set_trainable(self, adapter_name, module_names=get_modules_to_save_from_config(peft_config))
+ _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
@@ -2517,7 +2519,7 @@ def __init__(
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
- _set_trainable(self, adapter_name, module_names=get_modules_to_save_from_config(peft_config))
+ _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
5 changes: 2 additions & 3 deletions src/peft/tuners/tuners_utils.py
@@ -38,7 +38,6 @@
SEQ_CLS_HEAD_NAMES,
)
from peft.utils.integrations import init_empty_weights
- from peft.utils.other import get_modules_to_save_from_config
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
@@ -513,7 +512,7 @@ def inject_adapter(
# All targeted modules were excluded
raise ValueError(
"All modules were excluded. This is likely unintended. "
"Check your `target_modules` and `exclude_modules` configuration."
"Check your `target_modules`, `exclude_modules` and `modules_to_save` configuration."
)
elif not excluded_modules and unmatched_modules:
# None of the targeted modules matched
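A hypothetical illustration (my note, untested) of the situation the updated message now covers: if every targeted module is also listed in `modules_to_save`, matching excludes them all and this error should fire. The config values below are assumptions chosen for the example.

```python
from peft import LoraConfig

# Assumed example: "score" is both a LoRA target and a module to save.
config = LoraConfig(target_modules=["score"], modules_to_save=["score"])
# get_peft_model(model, config)
# expected (if I read the change correctly) to raise:
#   ValueError: All modules were excluded. This is likely unintended. Check your
#   `target_modules`, `exclude_modules` and `modules_to_save` configuration.
```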
@@ -1016,7 +1015,7 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
# Adapters should never match on modules to save modules as it is a guarantee for conflicts of behavior
# between `ModulesToSaveWrapper` internals and the potential adapter.
# TODO extend this to AuxiliaryTrainingWrapper in this PR if possible
- modules_to_save = get_modules_to_save_from_config(config)
+ modules_to_save = getattr(config, "modules_to_save", None)
if modules_to_save:
if any(re.match(rf"(^|.*\.){m}($|\..*)", key) for m in modules_to_save):
return _ExcludedModule()
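For reference (not part of the diff), a small check of how the exclusion regex above treats module keys; the example keys and the `modules_to_save` value are assumptions:

```python
import re

modules_to_save = ["classifier"]
# The pattern matches "classifier" only as a full path segment, so submodules of a
# saved module are excluded too, while names that merely contain the substring are not.
for key in ["classifier", "model.classifier", "model.classifier.dense", "model.my_classifier_head"]:
    excluded = any(re.match(rf"(^|.*\.){m}($|\..*)", key) for m in modules_to_save)
    print(key, "->", excluded)
# classifier -> True
# model.classifier -> True
# model.classifier.dense -> True
# model.my_classifier_head -> False
```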
2 changes: 0 additions & 2 deletions src/peft/utils/__init__.py
@@ -42,7 +42,6 @@
cast_mixed_precision_params,
get_auto_gptq_quant_linear,
get_gptqmodel_quant_linear,
- get_modules_to_save_from_config,
get_quantization_config,
id_tensor_storage,
infer_device,
@@ -84,7 +83,6 @@
"cast_mixed_precision_params",
"get_auto_gptq_quant_linear",
"get_gptqmodel_quant_linear",
"get_modules_to_save_from_config",
"get_peft_model_state_dict",
"get_quantization_config",
"id_tensor_storage",
11 changes: 1 addition & 10 deletions src/peft/utils/other.py
@@ -751,7 +751,7 @@ def _set_trainable(

if not module_names:
# This is useful for the case that the PEFT config does not have `modules_to_save`, e.g.
- # in the case of prompt tuning and friends as returned by `get_modules_to_save_from_config`.
+ # in the case of prompt tuning and friends.
return

trainable_modules = []
@@ -1126,12 +1126,3 @@ def get_pattern_key(pattern_keys: Sequence[str], key_to_match: str) -> str:
return key

return key_to_match
-
-
- def get_modules_to_save_from_config(peft_config) -> Optional[list[str]]:
-     """Utility for retrieving `modules_to_save` from a PEFT config.
-     This is useful for exceptional tuners like `PromptTuning` which do not have this attribute.
-     """
-     if hasattr(peft_config, "modules_to_save"):
-         return peft_config.modules_to_save
-     return None
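For context (my note, not part of the diff): the removed helper is replaced at each call site by an inline `getattr` with a `None` default, which behaves the same way for configs such as prompt tuning that lack a `modules_to_save` attribute. A minimal sketch of the equivalence, using made-up config classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptTuningLikeConfig:
    # hypothetical stand-in for a PEFT config without `modules_to_save`
    num_virtual_tokens: int = 20


@dataclass
class LoraLikeConfig:
    # hypothetical stand-in for a PEFT config with `modules_to_save`
    modules_to_save: Optional[list[str]] = None


def old_helper(peft_config) -> Optional[list[str]]:
    # behavior of the removed `get_modules_to_save_from_config`
    if hasattr(peft_config, "modules_to_save"):
        return peft_config.modules_to_save
    return None


for cfg in (PromptTuningLikeConfig(), LoraLikeConfig(modules_to_save=["classifier"])):
    # the inline replacement used throughout this commit
    assert getattr(cfg, "modules_to_save", None) == old_helper(cfg)
```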
10 changes: 3 additions & 7 deletions tests/test_seq_classifier.py
@@ -13,9 +13,7 @@

import pytest
import torch
- from transformers import (
-     AutoModelForSequenceClassification,
- )
+ from transformers import AutoModelForSequenceClassification

from peft import (
AdaLoraConfig,
@@ -181,10 +179,8 @@

class TestSequenceClassificationModels(PeftCommonTester):
r"""
- Test if the PeftModel behaves as expected. This includes:
- - test if the model has the expected methods
-
- We use pytest for debugging purposes to test each model individually.
+ Tests for basic coverage of AutoModelForSequenceClassification and classification-specific cases.
+ Most of the functionality is probably already covered by other tests.
"""

transformers_class = AutoModelForSequenceClassification