Description
System Info
Python 3.11.9
transformers==4.40.2
peft==0.11.2
Who can help?
@BenjaminBossan
A bug occurs in the PEFT library when loading multiple LoRA adapters that each have their own modules_to_save configuration. The modules_to_save of the first adapter (e.g., adapter_1) is also applied to subsequently loaded adapters (e.g., adapter_2) instead of each adapter keeping its own independent configuration. As a result, modules listed in modules_to_save for adapter_1 also receive a trainable copy under adapter_2, which is unintended and can affect fine-tuning accuracy: the wrapper ends up with duplicate entries where only the respective adapter's modules should be present.
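For reference, the overlap can be inspected programmatically. The helper below is an illustrative sketch; it assumes the peft_model built in the reproduction further down and that ModulesToSaveWrapper is importable from peft.utils, as in peft 0.11.x:

from peft.utils import ModulesToSaveWrapper

def report_modules_to_save(model):
    # Print, for every ModulesToSaveWrapper in the model, which adapter names
    # currently have their own trainable copy of the wrapped module
    for name, module in model.named_modules():
        if isinstance(module, ModulesToSaveWrapper):
            print(name, "->", list(module.modules_to_save.keys()))

Running report_modules_to_save(peft_model) after the reproduction below shows lm_head holding both "adapter_1" and "adapter_2", instead of only "adapter_1".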
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
The following example code demonstrates the issue by printing the model structure, in which adapter_2 contains a copy of a module that should belong only to adapter_1.
Example Code
import os
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel
# Get the directory of the current Python script
script_dir = os.path.dirname(os.path.abspath(__file__))
# Define relative paths for adapters
adapter_1_path = os.path.join(script_dir, "adapter_1")
adapter_2_path = os.path.join(script_dir, "adapter_2")
# Load base model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
# Define LoRA configs with different modules_to_save
lora_config_1 = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["c_attn"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
modules_to_save=["lm_head"]
)
lora_config_2 = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["c_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
modules_to_save=["wte"]
)
# Apply and save the first adapter
os.makedirs(adapter_1_path, exist_ok=True)
model_with_lora_1 = get_peft_model(base_model, lora_config_1, adapter_name="adapter_1")
model_with_lora_1.save_pretrained(adapter_1_path)
# Apply and save the second adapter on a fresh base model,
# since get_peft_model modifies the base model in place
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
os.makedirs(adapter_2_path, exist_ok=True)
model_with_lora_2 = get_peft_model(base_model, lora_config_2, adapter_name="adapter_2")
model_with_lora_2.save_pretrained(adapter_2_path)
# Load a fresh base model and wrap it in PeftModel by loading the first adapter
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, os.path.join(adapter_1_path, "adapter_1"), adapter_name="adapter_1")
# Load the second adapter into the PeftModel
peft_model.load_adapter(os.path.join(adapter_2_path, "adapter_2"), adapter_name="adapter_2")
# Display structure and inspect unexpected 'modules_to_save' overlap
print("Expected `modules_to_save` for each adapter:")
print("Adapter 1 `modules_to_save`: ['lm_head']")
print("Adapter 2 `modules_to_save`: ['wte']")
print("\nActual model structure and `modules_to_save` contents:\n")
print(peft_model.transformer.wte)
print(peft_model.lm_head)

The code output will be:
Expected `modules_to_save` for each adapter:
Adapter 1 `modules_to_save`: ['lm_head']
Adapter 2 `modules_to_save`: ['wte']
Actual model structure and `modules_to_save` contents:
ModulesToSaveWrapper(
(original_module): Embedding(50257, 768)
(modules_to_save): ModuleDict(
(adapter_2): Embedding(50257, 768)
)
)
ModulesToSaveWrapper(
(original_module): Linear(in_features=768, out_features=50257, bias=False)
(modules_to_save): ModuleDict(
(adapter_1): Linear(in_features=768, out_features=50257, bias=False)
(adapter_2): Linear(in_features=768, out_features=50257, bias=False)
)
)

Expected behavior
As you can see, a copy of the lm_head module is also built for adapter_2, even though adapter_2's config does not list lm_head in modules_to_save. The expected output is shown below:
Expected `modules_to_save` for each adapter:
Adapter 1 `modules_to_save`: ['lm_head']
Adapter 2 `modules_to_save`: ['wte']
Actual model structure and `modules_to_save` contents:
ModulesToSaveWrapper(
(original_module): Embedding(50257, 768)
(modules_to_save): ModuleDict(
(adapter_2): Embedding(50257, 768)
)
)
ModulesToSaveWrapper(
(original_module): Linear(in_features=768, out_features=50257, bias=False)
(modules_to_save): ModuleDict(
(adapter_1): Linear(in_features=768, out_features=50257, bias=False)
)
)
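Until this is fixed, one possible manual workaround (an untested sketch that relies only on the wrapper structure shown above, not on any official PEFT API) is to drop the spurious adapter_2 copy from the lm_head wrapper after both adapters have been loaded:

# Untested sketch: remove the adapter_2 copy that should not exist on lm_head
lm_head_wrapper = peft_model.lm_head  # ModulesToSaveWrapper, as printed above
if "adapter_2" in lm_head_wrapper.modules_to_save:
    del lm_head_wrapper.modules_to_save["adapter_2"]
print(peft_model.lm_head)

With the entry removed, the original lm_head should be used whenever adapter_2 is active, which matches the expected structure above.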