
Fix #2477: Regression accessing modules_to_save #2481

Merged
githubnemo merged 11 commits into huggingface:main from githubnemo:issue/modules-to-save-regression on Apr 17, 2025

Conversation

@githubnemo
Collaborator

Commit ed3c828 introduced adapter-local modules_to_save initialization, which prevented needless initialization but also broke prompt tuning methods, as they don't have the modules_to_save attribute.

This change also introduces a sequence classification test suite that also covers prompt tuning methods. While not comprehensive, it is sufficient to catch this error and can be extended over time.

While working on this and testing RoBERTa, there was also an issue with the default target modules of AdaLoRA, which default to dense (among other modules). This is problematic for PeftModelForSequenceClassification, which marks classification.* as modules_to_save. But since the classification layer is also a dense layer, it will be targeted by AdaLoRA. To prevent such situations in the future, a general exemption was made in check_target_module_exists to always skip keys listed in modules_to_save. For this to work, the config modification done in PeftModelForSequenceClassification needed changing.

There's an open TODO to extend the exemption to all AuxiliaryTrainingWrapper classes. I wanted to get feedback on this change first, but do you think that would make sense as well, @BenjaminBossan?
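As a rough illustration of the exemption idea, here is a minimal sketch with simplified matching logic; it is not the actual PEFT implementation, and the helper name is made up:

```python
import re


def _matches_target(target_modules, key: str) -> bool:
    # Simplified stand-in for PEFT's suffix/regex matching of target_modules.
    if isinstance(target_modules, str):
        return re.fullmatch(target_modules, key) is not None
    return any(key == name or key.endswith("." + name) for name in target_modules)


def check_target_module_exists(peft_config, key: str) -> bool:
    """Decide whether `key` should be wrapped by the tuner (simplified sketch)."""
    # Exemption: never target a module that belongs to modules_to_save,
    # otherwise e.g. AdaLoRA's default "dense" target would also catch the
    # classifier head that PeftModelForSequenceClassification keeps as a full copy.
    for saved in getattr(peft_config, "modules_to_save", None) or []:
        if key == saved or key.startswith(saved + ".") or key.endswith("." + saved):
            return False
    return _matches_target(peft_config.target_modules, key)
```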

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@BenjaminBossan BenjaminBossan left a comment

What a rabbit hole :-/

Generally, I think this direction is good; I added a couple of small comments.

> While working on this and testing RoBERTa there was also an issue with the default target of AdaLoRA as it defaults to dense (among other modules). This is problematic for PeftModelForSequenceClassification as they mark classification.* as modules_to_save.

I assume this is tested indirectly through the new tests? If not, let's add a test.

"""
if hasattr(peft_config, "modules_to_save"):
return peft_config.modules_to_save
return None
Member

Having this function is fine with me; I just wonder why you chose not to go with getattr(peft_config, "modules_to_save", None).

Collaborator Author

For some reason I thought it was necessary :) Changed to getattr.
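For reference, a minimal sketch of the getattr form; the helper name here is illustrative, not necessarily the one used in the diff:

```python
def get_peft_modules_to_save(peft_config):
    # Prompt tuning configs have no modules_to_save attribute; getattr makes
    # the lookup safe and simply returns None in that case.
    return getattr(peft_config, "modules_to_save", None)
```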

import pytest
import torch
from transformers import (
    AutoModelForSequenceClassification,
Member

Can be one line


class TestSequenceClassificationModels(PeftCommonTester):
r"""
Test if the PeftModel behaves as expected. This includes:
Member

This docstring can be adjusted/scrapped, right? Let's maybe mention instead that it is intentional that not the whole test battery is run here.

@githubnemo
Collaborator Author

> I assume this is tested indirectly through the new tests? If not, let's add a test.

Yes, the save_pretrained tests do that since ModulesToSaveWrapper would expect model.classifier.weight but, when targeted by AdaLoRA, that key would not exist.

WDYT about exempting all aux. training wrappers from being targeted? As it currently stands, that would require manually checking the corresponding config arguments (modules_to_save, trainable_token_indices, ...). Since this happens before the model is modified, there's no better way to detect whether we're dealing with aux. training wrappers.

I noticed that test_modules_to_save_targets_tuner_layer_raises fails now since the layers are silently ignored. Is that check now redundant?
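To make the AdaLoRA/classifier clash mentioned above concrete, here is a small self-contained sketch; the module names follow a RoBERTa-style classification head and the matching logic is simplified:

```python
# AdaLoRA's default targets for RoBERTa-like models include "dense";
# PeftModelForSequenceClassification marks the classification head as
# modules_to_save. The head itself contains a dense layer, so without an
# exemption both mechanisms claim the same module.
target_modules = ["query", "key", "value", "dense"]
modules_to_save = ["classifier"]

keys = [
    "roberta.encoder.layer.0.attention.self.query",
    "roberta.encoder.layer.0.output.dense",
    "classifier.dense",      # part of the classification head
    "classifier.out_proj",
]

for key in keys:
    targeted = any(key.endswith(t) for t in target_modules)
    saved = any(key == m or key.startswith(m + ".") for m in modules_to_save)
    print(f"{key}: targeted_by_adalora={targeted}, in_modules_to_save={saved}")
# "classifier.dense" comes out True/True -- the clash that the exemption avoids.
```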

Member

@BenjaminBossan BenjaminBossan left a comment

> WDYT about exempting all aux. training wrappers from being targeted?

👍

> I noticed that test_modules_to_save_targets_tuner_layer_raises fails now since the layers are silently ignored. Is that check now redundant?

Yes, the test can be removed or maybe adapted to check that the same layer is not double-wrapped. WDYT?

nemo added 6 commits April 8, 2025 19:56
This code was *probably* for dealing with modules_to_save when calling
inject_adapter directly. However, since the only place that does this is
the PEFT mixed module which already deals with modules_to_save this
code is deemed superfluous.

This also makes ignoring `modules_to_save` during targeting
easier since we can use the code in `check_target_module_exists` for every
case (targeting nested layer in modules_to_save module + direct targeting of
modules_to_save module).
Otherwise the model's classification head will be re-initialized regularly, breaking assumptions.
@githubnemo githubnemo force-pushed the issue/modules-to-save-regression branch 2 times, most recently from 67db9ba to f903d78 on April 10, 2025 15:12
githubnemo pushed a commit to githubnemo/peft that referenced this pull request Apr 14, 2025
This change is a breakout of the changes in PR huggingface#2481 to form a patch release.
There are no additional tests.
Move `set_additional_trainable_modules` to `inject_adapter` in case of adapters such as LoRAs, or,
in case of prompt tuning adapters, to their respective initialization point (while keeping the order
of operations intact).

Before this change, a significant portion of `modules_to_save` initialization had been removed from
`check_target_layer_exists` (called from `inject_adapter`). That code only handled the `modules_to_save`
parameter in cases where the function was called directly (e.g., via `LoraModel.add_weighted_adapter`),
which also meant that trainable tokens were completely ignored in these cases, and it duplicated code
from `_set_trainable`.

The removal prompted the need for a replacement, which is this change: on adapter injection we now
always check whether additional trainable modules are needed, not only during `PeftModel` init.
githubnemo added a commit that referenced this pull request Apr 15, 2025
This change is a breakout of the changes in PR #2481 to form a patch release.
There are no additional tests but the HF_HUB_OFFLINE feature was merged to improve the CI experience.

* Testing common uses situational HF_HUB_OFFLINE (#2490)

Employ offline mode when the model was already accessed once from the hub in order to speed up the CI and make the process less prone to rate limiting.

The idea is that we can mark contexts so that, once they have been visited for a specific model id, we can assume the model is cached locally and set HF_HUB_OFFLINE=1 for that context. This PR tests this concept for testing_common, which already covers a big chunk of the tests and probably has the biggest gain given the amount of change.

We already saw that the assumption does not always hold: in the prompt tuning tests (_test_prepare_input_for_generation) there is a case where the tokenizer is not used for model X the first time but is used the second time - since the hub is set to offline for the second visit, the tokenizer's from_pretrained call will fail. This problem is alleviated by adding the tokenizer name to the model id as the cache identifier.

(cherry picked from commit 1083964)
(Removed delete adapter tests)
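The concept described in that commit can be sketched as a context manager; the name and the cache-tracking set below are illustrative, not the actual test helpers:

```python
import os
from contextlib import contextmanager

_seen_model_ids = set()


@contextmanager
def hub_offline_once(model_id: str):
    """On first use of model_id, allow hub access; afterwards assume the files
    are cached locally and set HF_HUB_OFFLINE=1 for the duration of the block."""
    previous = os.environ.get("HF_HUB_OFFLINE")
    if model_id in _seen_model_ids:
        os.environ["HF_HUB_OFFLINE"] = "1"
    try:
        yield
    finally:
        _seen_model_ids.add(model_id)
        if previous is None:
            os.environ.pop("HF_HUB_OFFLINE", None)
        else:
            os.environ["HF_HUB_OFFLINE"] = previous
```

As noted in the commit message, the cache key would need to include the tokenizer name as well, since a model id alone does not guarantee the tokenizer files are cached.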
@githubnemo
Collaborator Author

I addressed the comments and also fixed the failing tests with respect to the code removal in check_target_layer_exists. To quote the commit message:

Move set_additional_trainable_modules to inject_adapter in case of adapters such as LoRAs, or,
in case of prompt tuning adapters, to their respective initialization point (while keeping the order
of operations intact).

Before this change, a significant portion of modules_to_save initialization had been removed from
check_target_layer_exists (called from inject_adapter). That code only handled the modules_to_save
parameter in cases where the function was called directly (e.g., via LoraModel.add_weighted_adapter),
which also meant that trainable tokens were completely ignored in these cases, and it duplicated code
from _set_trainable.

The removal prompted the need for a replacement, which is this change: on adapter injection we now
always check whether additional trainable modules are needed, not only during PeftModel init.
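Conceptually, the new flow looks roughly like the following sketch; the class and method signatures are simplified stand-ins, not the actual PEFT internals:

```python
class TunerSketch:
    """Simplified stand-in for a PEFT tuner; not the real implementation."""

    def __init__(self):
        self.extra_trainable = {}

    def set_additional_trainable_modules(self, peft_config, adapter_name):
        # modules_to_save may be absent (e.g. prompt tuning configs), so use getattr.
        modules_to_save = getattr(peft_config, "modules_to_save", None)
        if modules_to_save:
            self.extra_trainable[adapter_name] = list(modules_to_save)

    def inject_adapter(self, model, adapter_name, peft_config):
        # ... target matching and layer replacement would happen here ...
        # New: register additional trainable modules on every injection,
        # not only during PeftModel.__init__, so direct calls to
        # inject_adapter (e.g. via add_weighted_adapter) behave consistently.
        self.set_additional_trainable_modules(peft_config, adapter_name)
```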

Member

@BenjaminBossan BenjaminBossan left a comment

The PR LGTM, thanks, nice work.

One small thing: It would be good to have tests for the transformers and diffusers integrations when using modules_to_save (and possibly trainable_token_indices?) to check requires_grad. I would be fine with adding them later, in which case the PR can be merged.
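A minimal sketch of the kind of requires_grad check meant here, using get_peft_model directly rather than the full Trainer/diffusers integration; the checkpoint id is illustrative and any sequence classification model would do:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

# Illustrative tiny checkpoint; substitute any sequence classification model.
base = AutoModelForSequenceClassification.from_pretrained(
    "hf-internal-testing/tiny-random-RobertaForSequenceClassification"
)
config = LoraConfig(task_type="SEQ_CLS", modules_to_save=["classifier"])
model = get_peft_model(base, config)

# The copy of the classifier created for modules_to_save must stay trainable.
assert any(
    param.requires_grad
    for name, param in model.named_parameters()
    if "modules_to_save" in name and "classifier" in name
)
```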

@githubnemo
Collaborator Author

OK, let's do the tests in a separate PR. Thanks for the review :)

@githubnemo githubnemo merged commit 36160a5 into huggingface:main Apr 17, 2025
14 checks passed
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Apr 22, 2025
PR huggingface#2481 added sequence classification tests to PEFT. The test matrix
included CPT. However, CPT only supports the task type CAUSAL_LM. These
tests still passed but now started failing with:

> AttributeError: object has no attribute 'prepare_inputs_for_generation'

This is probably due to a change in transformers, but since causal LM was
never meant to work, the actual fix is to remove CPT from the seq cls
test matrix.

Since CPT automatically changes the task type to CAUSAL_LM, this mistake
can be hard to spot. Therefore, this PR also adds a warning if users
pass the wrong task type.
BenjaminBossan added a commit that referenced this pull request Apr 23, 2025
PR #2481 added sequence classification tests to PEFT. The test matrix
included CPT. However, CPT only supports the task type CAUSAL_LM. These
tests still passed but now started failing with:

> AttributeError: object has no attribute 'prepare_inputs_for_generation'

This is probably due to a change in transformers, but since causal LM was
never meant to work, the actual fix is to remove CPT from the seq cls
test matrix.

Since CPT automatically changes the task type to CAUSAL_LM, this mistake
can be hard to spot. Therefore, this PR also adds a warning if users
pass the wrong task type. In the future, this will raise an error.
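A hedged sketch of what such a task-type guard could look like; the helper name, placement, and message are illustrative, not the exact PEFT code:

```python
import warnings
from typing import Optional


def check_cpt_task_type(task_type: Optional[str]) -> str:
    # CPT only supports causal LM, so warn on any other task type (this is
    # planned to become an error later) and fall back to CAUSAL_LM, which is
    # what CPT silently switched to before.
    if task_type is not None and task_type != "CAUSAL_LM":
        warnings.warn(
            f"CPT only supports task_type='CAUSAL_LM', but got '{task_type}'; "
            "falling back to CAUSAL_LM. This will raise an error in the future."
        )
    return "CAUSAL_LM"
```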
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
…face#2481)

* Fix huggingface#2477: Regression accessing `modules_to_save`

Commit 501df99 introduced adapter-local modules_to_save initialization which prevented
needless initialization but also broke prompt tuning methods as they don't have the `modules_to_save`
attribute.

This change also introduces a sequence classification test suite that also tests prompt tuning methods.
While not comprehensive it is sufficient to catch this error and can be extended over time.

While working on this and testing RoBERTa there was also an issue with the default target of `AdaLoRA`
as it defaults to `dense` (among other modules). This is problematic for `PeftModelForSequenceClassification`
as they mark `classification.*` as `modules_to_save`. But since the classification layer is also a dense layer
it will be targeted by `AdaLoRA`. To prevent such situations in the future a general exemption was made in
`check_target_module_exists` to always avoid keys in `modules_to_save`. For this to work the config modification
done in `PeftModelForSequenceClassification` needed changing.

* Remove presumably superfluous code from inject_adapter

This code was *probably* for dealing with modules_to_save when calling
inject_adapter directly. However, since the only place that does this is
the PEFT mixed module which already deals with modules_to_save this
code is deemed superfluous.

This also makes ignoring `modules_to_save` during targeting
easier since we can use the code in `check_target_module_exists` for every
case (targeting nested layer in modules_to_save module + direct targeting of
modules_to_save module).

* Move `set_additional_trainable_modules`

Move `set_additional_trainable_modules` to `inject_adapter` in case of adapters such as LoRAs, or,
in case of prompt tuning adapters, to their respective initialization point (while keeping the order
of operations intact).

Before this change a significant portion of `modules_to_save` initialization was removed from
`check_target_layer_exists` (called from `inject_adapter`) which only handled the `modules_to_save`
parameter in cases where this function was called directly (e.g., via `LoraModel.add_weighted_adapter`).
This also meant that trainable tokens was completely ignored in these cases. It also copied code from
`_set_trainable`.

The removal prompted the need to find a replacement which is this change: on adapter injection we will
now always check if there need to be additional trainable modules, not only during `PeftModel` init.
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
…ce#2507)

PR huggingface#2481 added sequence classification tests to PEFT. The test matrix
included CPT. However, CPT only supports the task type CAUSAL_LM. These
tests still passed but now started failing with:

> AttributeError: object has no attribute 'prepare_inputs_for_generation'

This is probably due to a change in transformers, but since causal LM was
never meant to work, the actual fix is to remove CPT from the seq cls
test matrix.

Since CPT automatically changes the task type to CAUSAL_LM, this mistake
can be hard to spot. Therefore, this PR also adds a warning if users
pass the wrong task type. In the future, this will raise an error.
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
…face#2481)

* Fix huggingface#2477: Regression accessing `modules_to_save`

Commit ed3c828 introduced adapter-local modules_to_save initialization which prevented
needless initialization but also broke prompt tuning methods as they don't have the `modules_to_save`
attribute.

This change also introduces a sequence classification test suite that also tests prompt tuning methods.
While not comprehensive it is sufficient to catch this error and can be extended over time.

While working on this and testing RoBERTa there was also an issue with the default target of `AdaLoRA`
as it defaults to `dense` (among other modules). This is problematic for `PeftModelForSequenceClassification`
as they mark `classification.*` as `modules_to_save`. But since the classification layer is also a dense layer
it will be targeted by `AdaLoRA`. To prevent such situations in the future a general exemption was made in
`check_target_module_exists` to always avoid keys in `modules_to_save`. For this to work the config modification
done in `PeftModelForSequenceClassification` needed changing.

* Remove presumably superfluous code from inject_adapter

This code was *probably* for dealing with modules_to_save when calling
inject_adapter directly. However, since the only place that does this is
the PEFT mixed module which already deals with modules_to_save this
code is deemed superfluous.

This also makes ignoring `modules_to_save` during targeting
easier since we can use the code in `check_target_module_exists` for every
case (targeting nested layer in modules_to_save module + direct targeting of
modules_to_save module).

* Move `set_additional_trainable_modules`

Move `set_additional_trainable_modules` to `inject_adapter` in case of adapters such as LoRAs, or,
in case of prompt tuning adapters, to their respective initialization point (while keeping the order
of operations intact).

Before this change a significant portion of `modules_to_save` initialization was removed from
`check_target_layer_exists` (called from `inject_adapter`) which only handled the `modules_to_save`
parameter in cases where this function was called directly (e.g., via `LoraModel.add_weighted_adapter`).
This also meant that trainable tokens was completely ignored in these cases. It also copied code from
`_set_trainable`.

The removal prompted the need to find a replacement which is this change: on adapter injection we will
now always check if there need to be additional trainable modules, not only during `PeftModel` init.
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
…ce#2507)

PR huggingface#2481 added sequence classification tests to PEFT. The test matrix
included CPT. However, CPT only supports the task type CAUSAL_LM. These
tests still passed but now started failing with:

> AttributeError: object has no attribute 'prepare_inputs_for_generation'

This is probably due to a change in transformers, but since causal LM was
never meant to work, the actual fix is to remove CPT from the seq cls
test matrix.

Since CPT automatically changes the task type to CAUSAL_LM, this mistake
can be hard to spot. Therefore, this PR also adds a warning if users
pass the wrong task type. In the future, this will raise an error.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 14, 2025
When using prompt learning methods, modules_to_save was not correctly
set automatically. This is really bad when using, for instance, sequence
classification tasks, which require the classifier layer to be added to
modules_to_save.

The issue was introduced in huggingface#2220 where it is wrongly assumed that the
PEFT config always has a modules_to_save attribute, which is not true
for prompt learning. In huggingface#2481, this was partly fixed by using getattr to
avoid an error. However, this did not resolve the fundamental issue that
for prompt learning, there is no such attribute, resulting in
modules_to_save not being applied.

This PR proposes to fix this by adding modules_to_save to the prompt
learning configs.
BenjaminBossan added a commit that referenced this pull request Jul 14, 2025
When using prompt learning methods, modules_to_save was not correctly
set automatically. This is really bad when using, for instance, sequence
classification tasks, which require the classifier layer to be added to
modules_to_save.

The issue was introduced in #2220 where it is wrongly assumed that the
PEFT config always has a modules_to_save attribute, which is not true
for prompt learning. In #2481, this was partly fixed by using getattr to
avoid an error. However, this did not resolve the fundamental issue that
for prompt learning, there is no such attribute, resulting in
modules_to_save not being applied.

This PR proposes to fix this by adding modules_to_save to the prompt
learning configs.
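A hedged sketch of the proposed direction, giving a prompt-learning-style config an explicit modules_to_save field; the class name and fields are illustrative, not the real PEFT config:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class PromptLearningConfigSketch:
    # Illustrative stand-in: giving prompt learning configs an explicit
    # modules_to_save field means e.g. sequence classification heads are
    # picked up again, matching the direction of this follow-up fix.
    num_virtual_tokens: int = 20
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Modules to fully train and save alongside the prompt embeddings."},
    )
```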
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 28, 2025
When using prompt learning methods, modules_to_save was not correctly
set automatically. This is really bad when using, for instance, sequence
classification tasks, which require the classifier layer to be added to
modules_to_save.

The issue was introduced in huggingface#2220 where it is wrongly assumed that the
PEFT config always has a modules_to_save attribute, which is not true
for prompt learning. In huggingface#2481, this was partly fixed by using getattr to
avoid an error. However, this did not resolve the fundamental issue that
for prompt learning, there is no such attribute, resulting in
modules_to_save not being applied.

This PR proposes to fix this by adding modules_to_save to the prompt
learning configs.