FIX #2295: Warn when user reloads modified model #2306
When `get_peft_model` is called on a model that was already modified by a previous `get_peft_model` call, even passing a different config may not change the trainable parameter count, e.g. when the new target modules are only a subset of the previous ones. With this patch, calling `get_peft_model` on an already modified model issues a warning that hints at calling `.unload()` first.
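A minimal sketch of the scenario (toy model and module names made up for illustration; this is not code from the PR):

```python
import torch
from peft import LoraConfig, get_peft_model

# Toy base model; the module names "0" and "1" come from nn.Sequential.
base_model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 2))

# First call: inject LoRA into both linear layers.
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["0", "1"]))
peft_model.print_trainable_parameters()

# Second call on the same, already modified base model with a smaller config:
# the LoRA layers from the first call are still injected, so the trainable
# parameter count does not shrink. With this patch, a UserWarning pointing to
# `.unload()` is emitted here.
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["0"]))
peft_model.print_trainable_parameters()

# The suggested remedy: restore the base model first, then re-apply the new config.
base_model = peft_model.unload()
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["0"]))
peft_model.print_trainable_parameters()
```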
@@ -31,6 +31,7 @@
     PeftModelForSeq2SeqLM,
     PeftModelForSequenceClassification,
     PeftModelForTokenClassification,
+    get_layer_status,
 )
 from .tuners import (
     AdaLoraConfig,
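For context on the new import: `get_layer_status` returns one status entry per injected tuner layer and raises a `ValueError` when no adapter layers are present, which is what the check added below relies on. A rough sketch with a toy model, assuming `get_layer_status` is also exposed at the top-level `peft` namespace:

```python
import torch
from peft import LoraConfig, get_layer_status, get_peft_model

model = torch.nn.Sequential(torch.nn.Linear(10, 2))

try:
    get_layer_status(model)  # plain torch model, no adapter layers injected
except ValueError:
    print("no adapter layers found")

peft_model = get_peft_model(model, LoraConfig(target_modules="0"))
print(len(get_layer_status(peft_model)))  # > 0 once LoRA layers are injected
```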
@@ -181,6 +182,21 @@ def get_peft_model(
     new_name = model.__dict__.get("name_or_path", None)
     peft_config.base_model_name_or_path = new_name

+    # Especially in notebook environments there could be a case that a user
+    # wants to experiment with different configuration values. However, it
+    # is likely that there won't be any changes for new configs on an already
+    # initialized PEFT model. The best we can do is warn the user about it.
+    try:
+        if len(get_layer_status(model)) > 0:
+            warnings.warn(
+                "You are trying to modify a model with PEFT for a "
+                "second time. If you want to reload the model with a "
+                "different config, make sure to call `.unload()` before."
+            )
+    except ValueError:
+        # not a PEFT model or no adapters in use
+        pass
+
     if (old_name is not None) and (old_name != new_name):
         warnings.warn(
             f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
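Not part of the diff: while experimenting (e.g. in a notebook), the new `UserWarning` can be escalated to an error so a stale reload is impossible to miss. A hedged sketch using only the standard `warnings` machinery and a toy model:

```python
import warnings

import torch
from peft import LoraConfig, get_peft_model

model = torch.nn.Sequential(torch.nn.Linear(10, 2))
get_peft_model(model, LoraConfig(target_modules="0"))

with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)  # turn the warning into an exception
    try:
        get_peft_model(model, LoraConfig(target_modules="0"))
    except UserWarning as exc:
        print(f"second call rejected: {exc}")
```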
@@ -0,0 +1,39 @@
+import pytest
+import torch
+
+
+class TestGetPeftModel:
+    RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"
+
+    @pytest.fixture
+    def get_peft_model(self):
+        from peft import get_peft_model
+
+        return get_peft_model
+
+    @pytest.fixture
+    def lora_config(self):
+        from peft import LoraConfig
+
+        return LoraConfig(target_modules="0")
+
+    @pytest.fixture
+    def base_model(self):
+        return torch.nn.Sequential(torch.nn.Linear(10, 2))
+
+    def test_get_peft_model_warns_when_reloading_model(self, get_peft_model, lora_config, base_model):
+        get_peft_model(base_model, lora_config)
+
+        with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
+            get_peft_model(base_model, lora_config)
+
+    def test_get_peft_model_proposed_fix_in_warning_help(self, get_peft_model, lora_config, base_model, recwarn):
+        peft_model = get_peft_model(base_model, lora_config)
+        peft_model.unload()
+        get_peft_model(base_model, lora_config)
+
+        warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)
+
+        for warning in recwarn:
+            if warning_checker.matches(warning):
+                pytest.fail("Warning raised even though model was unloaded.")
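For readers unfamiliar with pytest's `recwarn` fixture used in the second test: it records every warning raised during the test, and the test fails only if one of them matches the reload pattern. A rough standalone equivalent using only the standard library (toy model, not part of the PR):

```python
import re
import warnings

import torch
from peft import LoraConfig, get_peft_model

base_model = torch.nn.Sequential(torch.nn.Linear(10, 2))
peft_model = get_peft_model(base_model, LoraConfig(target_modules="0"))
peft_model.unload()  # removes the injected LoRA layers from base_model in place

with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")
    get_peft_model(base_model, LoraConfig(target_modules="0"))

pattern = re.compile(r"You are trying to modify a model .*")
assert not any(pattern.search(str(w.message)) for w in recorded), (
    "Warning raised even though model was unloaded."
)
```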
Member:
Maybe a bit of an edge case, but currently we would not detect if a user tries to create, say, a LoRA model based on a base model that was already modified with a prompt learning method, like prefix tuning. This is because those methods don't add any tuner layers.

Collaborator (Author):
But that case is also not error-prone, or is it?

Member:
Honestly, I'm not quite sure if it would just work or not. Coincidentally, there is a recent issue asking a similar question (#2307).

Collaborator (Author):
For our sanity, I'd suggest going forward with the current state of things, and once we know more about the interplay between soft-prompting vs. soft-prompting and soft-prompting vs. LoRA, we can adapt the tests and/or implementation. WDYT?

Review comment on the diff:
We have 120 char limits in the project, could you please configure your formatter accordingly?