Merged
Changes from 2 commits
16 changes: 16 additions & 0 deletions src/peft/mapping.py
@@ -31,6 +31,7 @@
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
get_layer_status,
)
from .tuners import (
AdaLoraConfig,
@@ -181,6 +182,21 @@ def get_peft_model(
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name

# Especially in notebook environments there could be a case that a user

Member
We have a 120-char line limit in the project; could you please configure your formatter accordingly?

# wants to experiment with different configuration values. However, it
# is likely that there won't be any changes for new configs on an already
# initialized PEFT model. The best we can do is warn the user about it.
try:
if len(get_layer_status(model)) > 0:

Member
Nice idea to use get_layer_status. I wonder if, in this case, a simple

`if any(isinstance(module, BaseTunerLayer) for module in model.modules())`

would not serve the purpose better. This check would stop as soon as the first PEFT layer is found, while get_layer_status would do a lot of unnecessary extra work.
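For illustration, a rough sketch of that early-exit check, pulled out into a standalone helper (the helper name `_warn_if_already_peft_tuned` is hypothetical and not part of this PR; inside `get_peft_model` the body would simply run inline on `model`):

```python
import warnings

from torch import nn

from peft.tuners.tuners_utils import BaseTunerLayer


def _warn_if_already_peft_tuned(model: nn.Module) -> None:
    # hypothetical helper: any() short-circuits at the first BaseTunerLayer it finds,
    # whereas get_layer_status walks every module and builds a full status report
    if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
        warnings.warn(
            "You are trying to modify a model with PEFT for a "
            "second time. If you want to reload the model with a "
            "different config, make sure to call `.unload()` before."
        )
```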

Collaborator Author
For some reason I thought that p-tuning/prompt-tuning are also layer tuners (they aren't), which is why I went with the more complex get_layer_status. But you're correct, a simple check suffices.

warnings.warn(
"You are trying to modify a model with PEFT for a "
"second time. If you want to reload the model with a "
"different config, make sure to call `.unload()` before."
)
except ValueError:
# not a PEFT model or no adapters in use
pass

if (old_name is not None) and (old_name != new_name):
warnings.warn(
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
39 changes: 39 additions & 0 deletions tests/test_mapping.py
@@ -0,0 +1,39 @@
import pytest
import torch


class TestGetPeftModel:
RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"

@pytest.fixture
def get_peft_model(self):
from peft import get_peft_model

return get_peft_model

@pytest.fixture
def lora_config(self):
from peft import LoraConfig

return LoraConfig(target_modules="0")

@pytest.fixture
def base_model(self):
return torch.nn.Sequential(torch.nn.Linear(10, 2))

def test_get_peft_model_warns_when_reloading_model(self, get_peft_model, lora_config, base_model):
get_peft_model(base_model, lora_config)

with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
get_peft_model(base_model, lora_config)

Member
How about also adding a check where the user calls get_peft_model on the PeftModel instance itself?
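For illustration, an untested sketch of such a test for `TestGetPeftModel` (the test name is made up; note that the fixture's `target_modules="0"` regex no longer matches the nested module paths inside the PeftModel, so the second wrap is allowed to fail after the warning has fired):

```python
def test_get_peft_model_warns_when_called_on_peft_model(self, get_peft_model, lora_config, base_model):
    # hypothetical test: pass the PeftModel instance itself back into get_peft_model
    peft_model = get_peft_model(base_model, lora_config)

    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        try:
            get_peft_model(peft_model, lora_config)
        except ValueError:
            # re-wrapping may fail because the fixture's target_modules regex "0"
            # does not match the nested module names; the warning is emitted
            # before that point, which is all this test cares about
            pass
```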


def test_get_peft_model_proposed_fix_in_warning_helps(self, get_peft_model, lora_config, base_model, recwarn):
peft_model = get_peft_model(base_model, lora_config)
peft_model.unload()
get_peft_model(base_model, lora_config)

warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)

for warning in recwarn:
if warning_checker.matches(warning):
pytest.fail("Warning raised even though model was unloaded.")

Member
Maybe a bit of an edge case, but currently we would not detect if a user tries to create, say, a LoRA model based on a base model that was already modified with a prompt learning method, like prefix tuning. This is because those methods don't add any BaseTunerLayers (which is what get_layer_status relies on). Implementing such a check is probably not easy and out of scope for this PR. We could still add an xfail-ing test though.
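For illustration, an untested sketch of what such an xfail-ing test could look like in `TestGetPeftModel` (the test name, the tiny checkpoint, and the config values are illustrative choices, not part of this PR; prompt learning needs a transformer base model with a config, so the `nn.Sequential` fixture can't be reused here):

```python
@pytest.mark.xfail(
    reason="prompt learning methods do not add BaseTunerLayer modules, so re-wrapping is not detected",
    strict=True,
)
def test_get_peft_model_warns_after_prompt_learning(self, get_peft_model):
    # hypothetical test: apply prefix tuning first, then try to apply LoRA on top
    from peft import LoraConfig, PrefixTuningConfig
    from transformers import AutoModelForCausalLM

    base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    get_peft_model(base_model, PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=10))

    # expected to fail for now: prefix tuning leaves the base model's modules untouched,
    # so the new check has nothing to detect and no warning is raised
    with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
        get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj"]))
```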

Collaborator Author
But that case isn't error-prone either, is it?

Member
Honestly, I'm not quite sure if it would just work or not. Coincidentally, there is a recent issue asking a similar question (#2307).

Collaborator Author
For our sanity, I'd suggest going forward with the current state of things; once we know more about the interplay between soft-prompting vs. soft-prompting and soft-prompting vs. LoRA, we can adapt the tests and/or implementation. WDYT?
