Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Test repeated invocation on peft model
  • Loading branch information
nemo committed Jan 7, 2025
commit f261184ad08c897cb57632ffcae5dbe87fd4a3f5
29 changes: 17 additions & 12 deletions tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,40 @@
import pytest
import torch

from peft import get_peft_model

from peft import get_peft_model, LoraConfig

class TestGetPeftModel:
RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"

@pytest.fixture
def lora_config_0(self):
    """LoRA config that targets the first module (name "0") of the
    sequential base model fixture."""
    # Module-level `from peft import LoraConfig` makes the old
    # function-local import unnecessary.
    return LoraConfig(target_modules="0")

@pytest.fixture
def base_model(self):
    """Minimal two-layer base model.

    Two linear layers are needed so that distinct LoRA configs can target
    different submodules ("0" vs "base_model.model.1") in the
    repeated-invocation test.
    """
    return torch.nn.Sequential(torch.nn.Linear(10, 2), torch.nn.Linear(2, 10))

def test_get_peft_model_warns_when_reloading_model(self, lora_config_0, base_model):
    """Calling get_peft_model twice on the same (already modified) base
    model must emit the reload warning."""
    # First call mutates base_model in place by injecting adapter layers.
    get_peft_model(base_model, lora_config_0)

    # Second call on the already-modified model should warn the user.
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(base_model, lora_config_0)

def test_get_peft_model_proposed_fix_in_warning_helps(self, lora_config_0, base_model, recwarn):
    """Following the warning's advice (unload the peft model first) must
    silence the reload warning on the next get_peft_model call."""
    peft_model = get_peft_model(base_model, lora_config_0)
    # `unload()` restores the base model, so the subsequent call starts
    # from a clean model and should not warn.
    peft_model.unload()
    get_peft_model(base_model, lora_config_0)

    # Use the warning matcher only to inspect `recwarn`; entering it as a
    # context manager would *require* the warning to fire.
    warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)
    for warning in recwarn:
        if warning_checker.matches(warning):
            pytest.fail("Warning raised even though model was unloaded.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a bit of an edge case, but currently we would not detect if a user tries to create, say, a LoRA model based on a base model that was already modified with a prompt learning method, like prefix tuning. This is because those methods don't add any BaseTunerLayers (which is what get_layer_status relies on). Implementing such a check is probably not easy and out of scope for this PR. We could still add an xfail-ing test though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that case shouldn't be error-prone either — or is it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I'm not quite sure if it would just work or not. Coincidentally, there is a recent issue asking a similar question (#2307).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our sanity I'd suggest going forward with the current state of things, and once we know more about the interplay between soft-prompting vs. soft-prompting and soft-prompting vs. LoRA, we can adapt the tests and/or implementation. WDYT?


def test_get_peft_model_repeated_invocation(self, lora_config_0, base_model):
    """Wrapping an existing PeftModel with a second LoRA config should
    still trigger the reload warning."""
    peft_model = get_peft_model(base_model, lora_config_0)

    # Target the second linear layer; on the wrapped model its qualified
    # name is "base_model.model.1" — NOTE(review): confirm this matches
    # peft's module naming for nested PeftModel wrapping.
    lora_config_1 = LoraConfig(target_modules="base_model.model.1")

    # Passing the PeftModel itself (not the base model) must also warn.
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(peft_model, lora_config_1)
Loading