24 changes: 14 additions & 10 deletions src/peft/peft_model.py
@@ -953,7 +953,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
else:
self.modules_to_save.update(peft_config.modules_to_save)
# this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
@@ -1483,27 +1483,31 @@ class PeftModelForSequenceClassification(PeftModel):
def __init__(
self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
) -> None:
super().__init__(model, peft_config, adapter_name, **kwargs)

classifier_module_names = ["classifier", "score"]
if self.modules_to_save is None:
self.modules_to_save = set(classifier_module_names)
else:
self.modules_to_save.update(classifier_module_names)

if hasattr(peft_config, "modules_to_save"):
if peft_config.modules_to_save is None:
peft_config.modules_to_save = classifier_module_names[:]
else:
peft_config.modules_to_save.extend(classifier_module_names)

# The modification of peft_config must happen before the init call as the `modules_to_save` information
# will be used to guard the target layer matching against matching `modules_to_save` layers. Only the
# config is relevant for this; the `modules_to_save` attribute can follow later.
super().__init__(model, peft_config, adapter_name, **kwargs)

if self.modules_to_save is None:
self.modules_to_save = set(classifier_module_names)
else:
self.modules_to_save.update(classifier_module_names)

for name, _ in self.base_model.named_children():
if any(module_name in name for module_name in self.modules_to_save):
self.cls_layer_name = name
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
@@ -2294,7 +2298,7 @@ def __init__(
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
@@ -2515,7 +2519,7 @@ def __init__(
break

# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=peft_config.modules_to_save)
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""
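
The recurring change in this file is swapping direct attribute access for `getattr(peft_config, "modules_to_save", None)`: prompt-learning configs such as `PromptTuningConfig` do not define `modules_to_save`, so the lookup has to tolerate a missing attribute. A minimal sketch of that pattern, using made-up stand-in config classes rather than the real PEFT ones:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeLoraConfig:
    # adapter-style configs carry a modules_to_save list
    modules_to_save: Optional[list] = field(default_factory=lambda: ["classifier", "score"])


@dataclass
class FakePromptTuningConfig:
    # prompt-learning configs have no modules_to_save attribute at all
    num_virtual_tokens: int = 10


for cfg in (FakeLoraConfig(), FakePromptTuningConfig()):
    module_names = getattr(cfg, "modules_to_save", None)
    print(type(cfg).__name__, "->", module_names)
# FakeLoraConfig -> ['classifier', 'score']
# FakePromptTuningConfig -> None   (a direct cfg.modules_to_save would raise AttributeError)
```
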
10 changes: 9 additions & 1 deletion src/peft/tuners/tuners_utils.py
@@ -512,7 +512,7 @@ def inject_adapter(
# All targeted modules were excluded
raise ValueError(
"All modules were excluded. This is likely unintended. "
"Check your `target_modules` and `exclude_modules` configuration."
"Check your `target_modules`, `exclude_modules` and `modules_to_save` configuration."
)
elif not excluded_modules and unmatched_modules:
# None of the targeted modules matched
@@ -1012,6 +1012,14 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
elif any(key.endswith(f".{exclude_key}") for exclude_key in config.exclude_modules):
return _ExcludedModule()

# Adapters should never match modules_to_save modules, as that would guarantee conflicting behavior
# between the `ModulesToSaveWrapper` internals and the potential adapter.
# TODO extend this to AuxiliaryTrainingWrapper in this PR if possible
modules_to_save = getattr(config, "modules_to_save", None)
if modules_to_save:
if any(re.match(rf"(^|.*\.){m}($|\..*)", key) for m in modules_to_save):
return _ExcludedModule()

if isinstance(config.target_modules, str):
target_module_found = re.fullmatch(config.target_modules, key)
elif key in config.target_modules:
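
The new guard in `check_target_module_exists` uses the pattern `(^|.*\.){m}($|\..*)` (with `{m}` filled in from `modules_to_save`) to decide whether a module key is, or lives below, a `modules_to_save` entry. A small standalone sketch of that matching logic, with invented module keys for illustration:

```python
import re

modules_to_save = ["classifier"]


def is_excluded(key: str) -> bool:
    # same pattern as the new guard: the key is excluded if any modules_to_save
    # entry appears as a full path component of the module name
    return any(re.match(rf"(^|.*\.){m}($|\..*)", key) for m in modules_to_save)


print(is_excluded("classifier"))           # True  - the modules_to_save module itself
print(is_excluded("classifier.dense"))     # True  - a submodule of it
print(is_excluded("bert.pooler.dense"))    # False - unrelated layer
print(is_excluded("my_classifier.dense"))  # False - only full path components count
```
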
5 changes: 5 additions & 0 deletions src/peft/utils/other.py
@@ -749,6 +749,11 @@ def _set_trainable(
if wrapper_cls is None:
wrapper_cls = ModulesToSaveWrapper

if not module_names:
# This is useful for the case that the PEFT config does not have `modules_to_save`, e.g.
# in the case of prompt tuning and friends.
return

trainable_modules = []
found_modules = set()
# disable removal of duplicates to support targeting tied weights
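
With the early return above, `_set_trainable` simply does nothing when it receives an empty or `None` `module_names`, which is what the `getattr(..., None)` call sites in `peft_model.py` now pass for prompt-learning configs. A rough stand-in to show the control flow, not the real implementation:

```python
def set_trainable_sketch(model, adapter_name, module_names):
    # simplified stand-in for peft.utils.other._set_trainable
    if not module_names:
        # e.g. prompt tuning: the config has no modules_to_save,
        # so there is nothing to wrap in a ModulesToSaveWrapper
        return
    for name, module in model.named_modules():
        if any(name == m or name.endswith(f".{m}") for m in module_names):
            # the real code wraps the matching module and marks it trainable here
            ...
```
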
248 changes: 248 additions & 0 deletions tests/test_seq_classifier.py
@@ -0,0 +1,248 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

import pytest
import torch
from transformers import AutoModelForSequenceClassification

from peft import (
AdaLoraConfig,
BOFTConfig,
BoneConfig,
CPTConfig,
FourierFTConfig,
HRAConfig,
IA3Config,
LoraConfig,
OFTConfig,
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
PromptTuningInit,
VBLoRAConfig,
VeraConfig,
)

from .testing_common import PeftCommonTester


PEFT_SEQ_CLS_MODELS_TO_TEST = [
"hf-internal-testing/tiny-random-BertForSequenceClassification",
"hf-internal-testing/tiny-random-RobertaForSequenceClassification",
]


ALL_CONFIGS = [
(
AdaLoraConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
"total_step": 1,
},
),
(
BOFTConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
},
),
(
BoneConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
"r": 2,
},
),
(
CPTConfig,
{
"task_type": "SEQ_CLS",
"cpt_token_ids": [0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing
"cpt_mask": [1, 1, 1, 1, 1, 1, 1, 1],
"cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4],
},
),
(
FourierFTConfig,
{
"task_type": "SEQ_CLS",
"n_frequency": 10,
"target_modules": None,
},
),
(
HRAConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
},
),
(
IA3Config,
{
"task_type": "SEQ_CLS",
"target_modules": None,
"feedforward_modules": None,
},
),
(
LoraConfig,
{
"task_type": "SEQ_CLS",
"r": 8,
"lora_alpha": 32,
"target_modules": None,
"lora_dropout": 0.05,
"bias": "none",
},
),
# LoRA + trainable tokens
(
LoraConfig,
{
"task_type": "SEQ_CLS",
"r": 8,
"lora_alpha": 32,
"target_modules": None,
"lora_dropout": 0.05,
"bias": "none",
"trainable_token_indices": [0, 1, 3],
},
),
(
OFTConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
},
),
(
PrefixTuningConfig,
{
"task_type": "SEQ_CLS",
"num_virtual_tokens": 10,
},
),
(
PromptEncoderConfig,
{
"task_type": "SEQ_CLS",
"num_virtual_tokens": 10,
"encoder_hidden_size": 32,
},
),
(
PromptTuningConfig,
{
"task_type": "SEQ_CLS",
"num_virtual_tokens": 10,
},
),
(
VBLoRAConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
"vblora_dropout": 0.05,
"vector_length": 1,
"num_vectors": 2,
},
),
(
VeraConfig,
{
"task_type": "SEQ_CLS",
"r": 8,
"target_modules": None,
"vera_dropout": 0.05,
"projection_prng_key": 0xFF,
"d_initial": 0.1,
"save_projection": True,
"bias": "none",
},
),
]


class TestSequenceClassificationModels(PeftCommonTester):
r"""
Tests for basic coverage of AutoModelForSequenceClassification and classification-specific cases.
Most of the functionality is probably already covered by other tests.
"""

transformers_class = AutoModelForSequenceClassification

def skipTest(self, reason=""):
# for backwards compatibility with unittest style test classes
pytest.skip(reason)

def prepare_inputs_for_testing(self):
input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device)
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
return {"input_ids": input_ids, "attention_mask": attention_mask}

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_attributes_parametrized(self, model_id, config_cls, config_kwargs):
self._test_model_attr(model_id, config_cls, config_kwargs.copy())

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_adapter_name(self, model_id, config_cls, config_kwargs):
self._test_adapter_name(model_id, config_cls, config_kwargs.copy())

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_prepare_for_training_parametrized(self, model_id, config_cls, config_kwargs):
self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_prompt_tuning_text_prepare_for_training(self, model_id, config_cls, config_kwargs):
if config_cls != PromptTuningConfig:
pytest.skip(f"This test does not apply to {config_cls}")
config_kwargs = config_kwargs.copy()
config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT
config_kwargs["prompt_tuning_init_text"] = "This is a test prompt."
config_kwargs["tokenizer_name_or_path"] = model_id
self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs.copy())

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained_pickle(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs.copy(), safe_serialization=False)

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs.copy())

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained_selected_adapters_pickle(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(
model_id, config_cls, config_kwargs.copy(), safe_serialization=False
)

@pytest.mark.parametrize("model_id", PEFT_SEQ_CLS_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_from_pretrained_config_construction(self, model_id, config_cls, config_kwargs):
self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs.copy())
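
To exercise the same code path outside the test harness, something along these lines should work, reusing one of the tiny Hub models and the prompt-tuning settings from the parametrization above (untested sketch):

```python
from transformers import AutoModelForSequenceClassification

from peft import PromptTuningConfig, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained(
    "hf-internal-testing/tiny-random-BertForSequenceClassification"
)
# Prompt-learning configs define no `modules_to_save`; the getattr/early-return
# guards in this PR treat that as "nothing extra to wrap".
config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
peft_model = get_peft_model(base, config)
peft_model.print_trainable_parameters()
```
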