Move set_additional_trainable_modules
Move `set_additional_trainable_modules` to `inject_adapter` for adapters such as LoRA, or, in the case of prompt tuning adapters, to their respective initialization point (while keeping the order of operations intact).

Before this change, a significant portion of `modules_to_save` initialization was removed from
`check_target_layer_exists` (called from `inject_adapter`). That code only handled the `modules_to_save`
parameter in cases where the function was called directly (e.g., via `LoraModel.add_weighted_adapter`),
which also meant that trainable tokens were completely ignored in these cases, and it duplicated code
from `_set_trainable`.

The removal prompted the need for a replacement, which this change provides: on adapter injection we
now always check whether additional trainable modules are needed, not only during `PeftModel` init.
nemo committed Apr 15, 2025
commit b133062dd2aab99af034a3f871e48ef5aad79d00
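For context, the "additional trainable modules" being resolved here come from config fields such as `modules_to_save` and `trainable_token_indices`. Below is a minimal sketch of a configuration that exercises this path; the model name and layer names are illustrative assumptions, not part of this commit:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Any causal LM works similarly; opt-125m is just a small, ungated example.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["lm_head"],        # fully trained and saved alongside the adapter
    trainable_token_indices=[0, 1, 2],  # only these embedding rows become trainable
)

# With this change, the wrappers for `modules_to_save` and trainable tokens are
# attached during adapter injection (`inject_adapter`) rather than only in
# `PeftModel.__init__`.
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```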
77 changes: 13 additions & 64 deletions src/peft/peft_model.py
@@ -41,7 +41,7 @@
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.utils.integrations import init_empty_weights
from peft.utils.other import TrainableTokensWrapper
from peft.utils.other import set_additional_trainable_modules

from . import __version__
from .config import PeftConfig
@@ -53,7 +53,6 @@
PeftType,
TaskType,
_get_batch_size,
_get_input_embeddings_name,
_prepare_prompt_learning_config,
_set_adapter,
_set_trainable,
@@ -130,8 +129,6 @@ def __init__(
with ctx():
self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)

self.set_additional_trainable_modules(peft_config, adapter_name)

if hasattr(self.base_model, "_cast_adapter_dtype"):
self.base_model._cast_adapter_dtype(
adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
@@ -931,8 +928,20 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_us

peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
self._setup_prompt_encoder(adapter_name)
set_additional_trainable_modules(
model=self.base_model,
peft_config=peft_config,
model_config=BaseTuner.get_model_config(self),
adapter_name=adapter_name,
)
elif peft_config.is_adaption_prompt:
self.base_model.add_adapter(adapter_name, peft_config)
set_additional_trainable_modules(
model=self.base_model,
peft_config=peft_config,
model_config=BaseTuner.get_model_config(self),
adapter_name=adapter_name,
)
else:
self.peft_config[adapter_name] = peft_config
self.base_model.inject_adapter(
@@ -943,8 +952,6 @@
del self.peft_config[adapter_name]
raise

self.set_additional_trainable_modules(peft_config, adapter_name)

def delete_adapter(self, adapter_name: str) -> None:
"""
Deletes an existing adapter.
@@ -977,64 +984,6 @@ def modules_to_save(self) -> Optional[set[str]]:
return None
return modules

def set_additional_trainable_modules(self, peft_config, adapter_name):
if getattr(peft_config, "modules_to_save", None) is not None:
# this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
else:
layer_name = _get_input_embeddings_name(self.model, "embed_tokens")
target_layers = {layer_name: peft_config.trainable_token_indices}

if self.modules_to_save:
for target_layer in target_layers:
if target_layer in self.modules_to_save:
raise ValueError(
"The embedding layer is already marked to be trained fully, either specify "
f'`modules_to_save=[..., "{target_layer}", ...]` or '
f"`trainable_tokens={{'{target_layer}': x}}` but not both."
)

# we are not adding these module names to `self.modules_to_save` as this is strictly reserved for the
# `ModulesToSaveWrapper`.

for target_layer, token_indices in target_layers.items():
_set_trainable(
self,
adapter_name,
module_names=[target_layer],
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
token_indices=token_indices,
)

# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
model_config = BaseTuner.get_model_config(self)
if (
model_config.get("tie_word_embeddings", False)
# some models may be misconfigured to have weight tying enabled but don't define tied weights keys
and self.model._tied_weights_keys is not None
and isinstance(self.model.get_input_embeddings(), TrainableTokensWrapper)
):
# the embedding layer is modified and we want weight tying.
module_keys = [".".join(n.split(".")[:-1]) for n in self.model._tied_weights_keys]

token_adapter = self.model.get_input_embeddings().token_adapter
_set_trainable(
self,
adapter_name,
module_names=module_keys,
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
token_indices=token_adapter.token_indices[adapter_name],
tied_adapter=self.model.get_input_embeddings().token_adapter,
)

def get_layer_status(self) -> list[TunerLayerStatus]:
"""Get the status of each adapter layer in the model.

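As a usage-level illustration of the `add_adapter` path changed above (adapter name and config are assumptions for the sketch): for non-prompt-learning adapters the auxiliary wrappers are now set up inside `inject_adapter`, so the explicit follow-up call in `PeftModel.add_adapter` could be dropped.

```python
from peft import LoraConfig

# Assuming `peft_model` was created as in the earlier sketch.
second = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["lm_head"],
)

# `add_adapter` delegates to `inject_adapter`, which now also wraps `modules_to_save`
# (and trainable tokens, if configured), so no separate call is needed afterwards.
peft_model.add_adapter("second", second)
peft_model.set_adapter("second")
```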
9 changes: 8 additions & 1 deletion src/peft/tuners/tuners_utils.py
@@ -38,7 +38,7 @@
SEQ_CLS_HEAD_NAMES,
)
from peft.utils.integrations import init_empty_weights
from peft.utils.other import AuxiliaryTrainingWrapper
from peft.utils.other import AuxiliaryTrainingWrapper, set_additional_trainable_modules
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
@@ -544,6 +544,13 @@ def inject_adapter(
if adapter_name in n:
p.requires_grad = False

set_additional_trainable_modules(
model=model,
peft_config=peft_config,
model_config=BaseTuner.get_model_config(self),
adapter_name=adapter_name,
)

def merge_adapter(self, adapter_names: Optional[list[str]] = None, safe_merge: bool = False) -> None:
"""
This method merges the adapter layers into the base model.
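The call added to `inject_adapter` above is what lets code paths that never go through `PeftModel.__init__` pick up the auxiliary wrappers. A hedged sketch of one such path named in the commit message, `add_weighted_adapter` (adapter names and weights are illustrative):

```python
# Assuming `peft_model` has two LoRA adapters, "default" and "second", both
# configured with `modules_to_save=["lm_head"]` as in the sketches above.
peft_model.add_weighted_adapter(
    adapters=["default", "second"],
    weights=[0.5, 0.5],
    adapter_name="merged",
    combination_type="linear",
)

# The merged adapter is created via `inject_adapter`, so its `modules_to_save`
# handling now happens there instead of being skipped because `PeftModel.__init__`
# never ran for it.
peft_model.set_adapter("merged")
```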
2 changes: 2 additions & 0 deletions src/peft/utils/__init__.py
@@ -46,6 +46,7 @@
id_tensor_storage,
infer_device,
prepare_model_for_kbit_training,
set_additional_trainable_modules,
shift_tokens_right,
transpose,
)
@@ -92,6 +93,7 @@
"prepare_model_for_kbit_training",
"register_peft_method",
"replace_lora_weights_loftq",
"set_additional_trainable_modules",
"set_peft_model_state_dict",
"shift_tokens_right",
"transpose",
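Since the helper is re-exported from `peft.utils` above, it can be imported from either location; a one-line sketch, although user code normally does not need it because `inject_adapter` calls it internally:

```python
from peft.utils import set_additional_trainable_modules  # same object as peft.utils.other.set_additional_trainable_modules
```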
61 changes: 61 additions & 0 deletions src/peft/utils/other.py
@@ -1205,3 +1205,64 @@ def get_pattern_key(pattern_keys: Sequence[str], key_to_match: str) -> str:
return key

return key_to_match


def set_additional_trainable_modules(model, peft_config, model_config, adapter_name):
"""Handle the resolution of additional trainable modules (also called AuxiliaryTrainingWrapper)
by checking the config if such modules are requested and adding them to the model.

Currently trainable tokens and modules to save are considered additional trainable modules.
"""
if getattr(peft_config, "modules_to_save", None) is not None:
# this may add a new ModulesToSaveWrapper
_set_trainable(model, adapter_name, module_names=getattr(peft_config, "modules_to_save", None))

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
else:
layer_name = _get_input_embeddings_name(model, "embed_tokens")
target_layers = {layer_name: peft_config.trainable_token_indices}

modules_to_save = getattr(peft_config, "modules_to_save", None)
if modules_to_save is not None:
for target_layer in target_layers:
if target_layer in modules_to_save:
raise ValueError(
"The embedding layer is already marked to be trained fully, either specify "
f'`modules_to_save=[..., "{target_layer}", ...]` or '
f"`trainable_tokens={{'{target_layer}': x}}` but not both."
)

for target_layer, token_indices in target_layers.items():
_set_trainable(
model,
adapter_name,
module_names=[target_layer],
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
token_indices=token_indices,
)

# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
if (
model_config.get("tie_word_embeddings", False)
# some models may be misconfigured to have weight tying enabled but don't define tied weights keys
and model._tied_weights_keys is not None
and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
):
# the embedding layer is modified and we want weight tying.
module_keys = [".".join(n.split(".")[:-1]) for n in model._tied_weights_keys]

token_adapter = model.get_input_embeddings().token_adapter
_set_trainable(
model,
adapter_name,
module_names=module_keys,
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
token_indices=token_adapter.token_indices[adapter_name],
tied_adapter=model.get_input_embeddings().token_adapter,
)
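To see the effect of this helper after injection, one could inspect which modules ended up wrapped. A small sketch assuming the config from the earlier examples; the printed module paths depend on the base model and are only indicative:

```python
from peft.utils.other import ModulesToSaveWrapper, TrainableTokensWrapper

# Modules wrapped by `set_additional_trainable_modules`: full trainable copies for
# `modules_to_save`, token-level deltas for `trainable_token_indices`.
wrapped = {
    name: type(module).__name__
    for name, module in peft_model.named_modules()
    if isinstance(module, (ModulesToSaveWrapper, TrainableTokensWrapper))
}
print(wrapped)
# e.g. {'...lm_head': 'ModulesToSaveWrapper', '...embed_tokens': 'TrainableTokensWrapper'}
```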