FIX #2295: Warn when user reloads modified model #2306

@@ -14,35 +14,40 @@
 import pytest
 import torch
 
-from peft import get_peft_model
+from peft import get_peft_model, LoraConfig
 
 
 class TestGetPeftModel:
     RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"
 
     @pytest.fixture
-    def lora_config(self):
-        from peft import LoraConfig
-
+    def lora_config_0(self):
         return LoraConfig(target_modules="0")
 
     @pytest.fixture
     def base_model(self):
-        return torch.nn.Sequential(torch.nn.Linear(10, 2))
+        return torch.nn.Sequential(torch.nn.Linear(10, 2), torch.nn.Linear(2, 10))
 
-    def test_get_peft_model_warns_when_reloading_model(self, lora_config, base_model):
-        get_peft_model(base_model, lora_config)
+    def test_get_peft_model_warns_when_reloading_model(self, lora_config_0, base_model):
+        get_peft_model(base_model, lora_config_0)
 
         with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
-            get_peft_model(base_model, lora_config)
+            get_peft_model(base_model, lora_config_0)
 
-    def test_get_peft_model_proposed_fix_in_warning_helps(self, lora_config, base_model, recwarn):
-        peft_model = get_peft_model(base_model, lora_config)
+    def test_get_peft_model_proposed_fix_in_warning_helps(self, lora_config_0, base_model, recwarn):
+        peft_model = get_peft_model(base_model, lora_config_0)
         peft_model.unload()
-        get_peft_model(base_model, lora_config)
+        get_peft_model(base_model, lora_config_0)
 
         warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)
 
         for warning in recwarn:
             if warning_checker.matches(warning):
                 pytest.fail("Warning raised even though model was unloaded.")

Member:
Maybe a bit of an edge case, but currently we would not detect if a user tries to create, say, a LoRA model based on a base model that was already modified with a prompt learning method, like prefix tuning. This is because those methods don't add any adapter layers to the base model.

Collaborator (Author):
But that case is also not error-prone, is it?

Member:
Honestly, I'm not quite sure whether it would just work or not. Coincidentally, there is a recent issue asking a similar question (#2307).

Collaborator (Author):
For our sanity, I'd suggest going forward with the current state of things; once we know more about the interplay of soft-prompting vs. soft-prompting and soft-prompting vs. LoRA, we can adapt the tests and/or the implementation. WDYT?

+
+    def test_get_peft_model_repeated_invocation(self, lora_config_0, base_model):
+        peft_model = get_peft_model(base_model, lora_config_0)
+
+        lora_config_1 = LoraConfig(target_modules="base_model.model.1")
+
+        with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
+            get_peft_model(peft_model, lora_config_1)
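
To make the expected behavior concrete, here is a minimal usage sketch mirroring the test fixtures above (same toy model and config as base_model and lora_config_0; the commented-out second call marks where the warning fires):

import torch

from peft import LoraConfig, get_peft_model

# Same toy setup as the base_model and lora_config_0 fixtures above.
base_model = torch.nn.Sequential(torch.nn.Linear(10, 2), torch.nn.Linear(2, 10))
config = LoraConfig(target_modules="0")

peft_model = get_peft_model(base_model, config)  # first call: no warning
# get_peft_model(base_model, config)  # calling again on the same, now-modified
#                                     # base model emits the UserWarning
peft_model.unload()  # restores the original modules in place
get_peft_model(base_model, config)  # after unload(): no warning again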
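
For context, a hedged sketch of the kind of check this PR adds to get_peft_model, assuming detection works by scanning the model for PEFT adapter layers; the helper name and exact message wording are illustrative, only the message prefix is pinned down by RELOAD_WARNING_EXPECTED_MATCH:

import warnings

from peft.tuners.tuners_utils import BaseTunerLayer

def _warn_if_already_modified(model):  # illustrative name, not the actual PEFT API
    # Adapter methods such as LoRA replace modules of the base model with
    # BaseTunerLayer subclasses, so finding one means the model has already
    # been modified by a previous get_peft_model call.
    if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
        warnings.warn(
            "You are trying to modify a model with PEFT for a second time. "
            "If you want to reload the model with a different config, "
            "call `.unload()` on the PEFT model first.",
            UserWarning,
        )

As the thread above notes, prompt learning methods (e.g. prefix tuning) keep their parameters outside the base model's modules, so a check of this shape would not fire for them.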