Merged
Changes from 1 commit
Commits
61 commits
49fab86
[WIP] Add LoRA multihead attention module
BenjaminBossan Jan 5, 2024
d8e9589
Make style
BenjaminBossan Jan 5, 2024
0e188a3
Remove commented code
BenjaminBossan Jan 5, 2024
b409d81
Remove assignment of weight to new module
BenjaminBossan Jan 5, 2024
173062c
Make state_dict and named_parameters work
BenjaminBossan Jan 5, 2024
1e007f5
Extend test coverage a bit
BenjaminBossan Jan 8, 2024
557c4a1
Clean ups after reviewer feedback:
BenjaminBossan Jan 9, 2024
add1f51
Reviewer feedback: removed another unnecessary arg
BenjaminBossan Jan 9, 2024
e44e030
Make style
BenjaminBossan Jan 9, 2024
8d62579
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jan 9, 2024
c5d8a6b
Apply LoRA also to the out_proj of MHA
BenjaminBossan Jan 12, 2024
9dc4a4d
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 7, 2024
c3fb2ce
Fix bug with incorrectly set gradient
BenjaminBossan Feb 7, 2024
17d407b
Fix failing tests
BenjaminBossan Feb 7, 2024
4cbf6e9
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 26, 2024
e0cae11
Move to pytest style asserts
BenjaminBossan Feb 26, 2024
52c8d9b
Fix safe merging code
BenjaminBossan Feb 26, 2024
977c84b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 11, 2024
96d376d
No need to set bias for MHA anymore, see #1530
BenjaminBossan Mar 11, 2024
0c17476
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 26, 2024
4b8db0c
Fix style
BenjaminBossan Mar 26, 2024
7e91712
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan May 21, 2024
e12070b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jul 25, 2024
7b6c7cb
Remove duplicate merge
BenjaminBossan Jul 25, 2024
e6ab8ed
Raise error for multi adapter batch inference
BenjaminBossan Jul 25, 2024
8ec6c3c
Raise error for DoRA + MHA
BenjaminBossan Jul 25, 2024
f6ba465
Fix error when adding multiple adapters to MHA
BenjaminBossan Jul 25, 2024
fb18886
Better way of param initialization
BenjaminBossan Jul 26, 2024
4ff2ec3
Add tests for broken loading and workaround
BenjaminBossan Jul 26, 2024
d1f6ab2
make style
BenjaminBossan Jul 26, 2024
65363be
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 3, 2024
7ba2e68
Fix wrong merge conflict resolution in test
BenjaminBossan Sep 4, 2024
6ef04b0
Ensure that base weights have requires_grad False
BenjaminBossan Sep 4, 2024
07c7240
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 4, 2024
cc3ac3d
Remove xpass-ing test
BenjaminBossan Sep 4, 2024
03c466f
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 12, 2024
e558caa
MAINT: Give stale bot permissions for PRs too (#2064)
BenjaminBossan Sep 12, 2024
38f4a98
ENH BOFT don't save boft_P buffer (#2050)
sywangyi Sep 13, 2024
7e5c61d
FIX Command line args in PiSSA preprocess (#2053)
keakon Sep 13, 2024
183bf52
MNT Update deprecated evaluation_strategy (#1664)
muellerzr Sep 13, 2024
b970607
ENH Multi adapters in same batch: modules_to_save (#1990)
saeid93 Sep 17, 2024
732e8e7
FIX Bug that prevents BOFT from loading 2 adapters (#2068)
BenjaminBossan Sep 18, 2024
79e2b38
TST Skip some quantization tests on XPU (#2074)
faaany Sep 18, 2024
61e6934
Improve test coverage for initialization of MHA
BenjaminBossan Sep 18, 2024
ced2f15
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Oct 14, 2024
4c31bbc
Fix bug with unloading multihead attention layer
BenjaminBossan Oct 21, 2024
1dbb9a5
Fix bug in unloading
BenjaminBossan Oct 22, 2024
e094234
Fix for low_cpu_mem_usage
BenjaminBossan Nov 1, 2024
e90af48
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 1, 2024
30a08e7
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 1, 2024
09f5ea6
Add tests for init_empty_weights
BenjaminBossan Nov 26, 2024
6a83bd7
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Nov 26, 2024
3b0471a
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Dec 9, 2024
465a85e
Add MHA to modules unsupported by EVA
BenjaminBossan Dec 9, 2024
266f9da
Add comment on why/how empty init works
BenjaminBossan Jan 6, 2025
39e755e
Expose attributes of underlying MHA module
BenjaminBossan Jan 6, 2025
4857858
Apply suggestions from code review
BenjaminBossan Jan 6, 2025
74cbba6
Remove trailing whitespace
BenjaminBossan Jan 6, 2025
14deb9f
Linting..
BenjaminBossan Jan 6, 2025
ba2a8dd
Reviewer comment: Add comments for clarification
BenjaminBossan Jan 8, 2025
ac10b18
Reviewer feedback: Remove q_proj_weight
BenjaminBossan Jan 8, 2025
Apply LoRA also to the out_proj of MHA
Before, LoRA was applied only to the in_proj. Now it is also applied to
the out_proj.

Unfortunately, there is no easy way to just apply a normal lora.Linear
to the out_proj by targeting it with target_modules. If that worked, it
would be much nicer to do that, so that users can decide for themselves
if they want to apply LoRA to the out_proj or not.

The reason why it doesn't work is twofold:

1. We cannot really control the order in which LoRA is applied, so when
   the LoRA adapter is injected into out_proj, the whole MHA layer may
   already be wrapped by lora.MultiheadAttention.
2. Even if we successfully applied a normal lora.Linear to the out_proj,
   it would not work correctly. This is because the forward method of
   out_proj is not used at all by nn.MultiheadAttention. Instead, it
   just passes the weight and bias to F.multi_head_attention_forward.
   Therefore, we must ensure that the weights are merged and unmerged
   correctly, same as for in_proj, and we cannot do that if we use a
   normal lora.Linear.

Note that the test test_merge_layers for MHA fails. This is most likely
because of an existing bug in how merging is implemented, see PR #1355.
Once that is merged, the test should pass.
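
As a standalone illustration of point 2 (not part of the commit), the following sketch patches out_proj.forward on a plain nn.MultiheadAttention and shows that the patch is never invoked, since F.multi_head_attention_forward only reads out_proj.weight and out_proj.bias; a LoRA layer that relies on out_proj's forward method would therefore have no effect.

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

calls = []
original_forward = mha.out_proj.forward

def spying_forward(*args, **kwargs):
    # would only run if nn.MultiheadAttention routed the output projection through out_proj.forward
    calls.append(1)
    return original_forward(*args, **kwargs)

mha.out_proj.forward = spying_forward

x = torch.randn(2, 4, 8)
mha(x, x, x)
print(len(calls))  # 0 -- out_proj.forward is bypassed, so only merging/unmerging its weights can make LoRA take effect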
BenjaminBossan committed Jan 12, 2024
commit c5d8a6b6398621a2e1598b64098dd77a4b9f7d0e
78 changes: 76 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -698,6 +698,10 @@ class MultiheadAttention(nn.Module, LoraLayer):
This is currently only implemented for the case of `_qkv_same_embed_dim = True`, i.e. query, key, and value having
the same dimension.

Note: LoRA is applied to both the in_proj (query/key/value) and out_proj. There is currently no way to specify only
one of them. Don't try to apply LoRA to the out_proj of MultiheadAttention by targeting that layer specifically,
since the forward method of that layer is not being used, hence the LoRA adapter would be ignored.

This is a little bit hacky because of the way that MultiheadAttention is implemented in PyTorch. It works by
merging the weights before the forward call and unmerging them after the forward call.
"""
@@ -723,6 +727,23 @@ def __init__(
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

# Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
if isinstance(base_layer.out_proj, nn.Linear):
self.base_layer.out_proj = Linear(
base_layer.out_proj,
adapter_name,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
**kwargs,
)
# bias is accessed directly by nn.MultiheadAttention
self.base_layer.out_proj.bias = self.base_layer.out_proj.get_base_layer().bias
else:
raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -753,21 +774,41 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
if active_adapter in self.lora_A.keys():
base_layer = self.get_base_layer()
if safe_merge:
# merging in_proj
# TODO: work with separate weights
orig_weights = base_layer.in_proj_weight.data.detach().clone()
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

del base_layer.in_proj_weight
base_layer.in_proj_weight = orig_weights

# merging out_proj
orig_weights = base_layer.out_proj.weight.data.detach().clone()
orig_weights += base_layer.out_proj.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
del base_layer.out_proj.get_base_layer().weight
base_layer.out_proj.get_base_layer().weight = orig_weights
else:
# merging in_proj
# TODO: work with separate weights
weight_merged = base_layer.in_proj_weight.data.detach() + self.get_delta_weight(active_adapter)
del base_layer.in_proj_weight
base_layer.in_proj_weight = weight_merged
Collaborator
Shouldn't this throw an exception? AFAICS we're assigning a tensor to a parameter value:

foo = torch.nn.Linear(10, 100)
foo.weight = foo.weight.detach() # raises

What am I missing?

Member Author
It's true that we change the type here, I guess you could consider this part of the hack to make this work. At the end, through _restore_weights, the correct type is restored.

Collaborator
ah, yes. I missed the del statement which unregisters the parameter and, thus, removes the setattr constraint. WDYT about something along the lines of

# unregister parameter implicitly and overwrite using merged weights; gradients are computed
# after forward and, thus, after unmerging (see forward()), therefore this is safe to do.
del base_layer.in_proj_weight
base_layer.in_proj_weight = orig_weights_in

Member Author
Done
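
A minimal standalone sketch of the mechanics discussed in this thread (not part of the diff): deleting the registered parameter lifts the nn.Module constraint that only an nn.Parameter may be assigned under that name, and register_parameter later restores the original parameter type, as _restore_weights does further down.

import torch
import torch.nn as nn

lin = nn.Linear(10, 100)

try:
    lin.weight = lin.weight.detach()  # assigning a plain tensor to a registered parameter
except TypeError as exc:
    print("raises:", exc)

merged = lin.weight.data + 1.0        # stand-in for base weight + LoRA delta
del lin.weight                        # unregisters the parameter, removing the setattr constraint
lin.weight = merged                   # a plain tensor is now accepted as an ordinary attribute

# _restore_weights-style cleanup: re-register the name as a proper nn.Parameter
weight = lin.weight
del lin.weight
lin.register_parameter("weight", nn.Parameter(weight))
print(type(lin.weight))               # <class 'torch.nn.parameter.Parameter'>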


# merging out_proj
weight_merged = base_layer.out_proj.weight.data.detach() + base_layer.out_proj.get_delta_weight(
active_adapter
)
del base_layer.out_proj.get_base_layer().weight
base_layer.out_proj.get_base_layer().weight = weight_merged
# self.get_base_layer().out_proj.merge(adapter_names=[active_adapter])
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -782,7 +823,14 @@ def unmerge(self) -> None:
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
# in_proj
self.get_base_layer().in_proj_weight.data -= self.get_delta_weight(active_adapter)
# out_proj
self.get_base_layer().out_proj.weight.data -= self.get_base_layer().out_proj.get_delta_weight(
active_adapter
)

self.get_base_layer().out_proj.unmerge()

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
@@ -828,15 +876,34 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
# merge all adapters that are active for this module
out_proj = self.get_base_layer().out_proj
if out_proj.active_adapters != self.active_adapters:
# We have a case where in_proj and out_proj have diverging merged adapters. We cannot
# really deal with this correctly, thus it's better to raise than possibly create a hard-to-debug mess
cls_name = self.get_base_layer().__class__.__name__
raise ValueError(
f"The out_proj layer of {cls_name} has merged layers but {cls_name} itself doesn't; please ensure "
"that either both or none have merged layers"
)

# Merge all adapters that are active for this module, i.e. the LoRA weights for in_proj and out_proj.
# in_proj uses nn.Parameters, therefore, there is no forward method to be used and we have to explicitly
# merge for the LoRA weights to have an effect:
# https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1020
# For out_proj, we have an nn.Linear (or rather: NonDynamicallyQuantizableLinear), but its forward method
# is not used:
# https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1267-L1271
# Therefore, its LoRA weights also need to be merged to have an effect.
active_adapters = [a for a in self.active_adapters if a in self.lora_A]
try:
self.merge(adapter_names=active_adapters)
out_proj.merge(adapter_names=active_adapters)
result = self.base_layer(x, *args, **kwargs)
finally:
# it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
# i.e. there was no merged layer before
self.unmerge()
out_proj.unmerge()

result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
return result
@@ -848,12 +915,19 @@ def _restore_weights(self):
# We cannot call register_parameter for merging/unmerging because that cuts them off from the autograd graph.
# Note that this is hacky, since we need to ensure that _restore_weights is called by each method that needs it.

# in_proj
# TODO work with separate weights
base_layer = self.get_base_layer()
weight = base_layer.in_proj_weight.data
del base_layer.in_proj_weight
base_layer.register_parameter("in_proj_weight", nn.Parameter(weight))

# out_proj
base_layer = base_layer.out_proj.get_base_layer()
weight = base_layer.weight.data
del base_layer.weight
base_layer.register_parameter("weight", nn.Parameter(weight))

def state_dict(self, *args, **kwargs):
self._restore_weights()
return super().state_dict(*args, **kwargs)
5 changes: 3 additions & 2 deletions tests/test_custom_models.py
@@ -59,7 +59,7 @@
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha"]}),
("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha", "lin0"]}),
("MHA 2 LoRA", "MHA", LoraConfig, {"target_modules": ["mha", "lin0"]}),
#######
# IA³ #
#######
@@ -552,7 +552,8 @@ def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwa

@parameterized.expand(TEST_CASES)
def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs):
# An explicit test that when using LoRA on a custom model, only the LoRA parameters are updated during training
# An explicit test that when using an adapter on a custom model, only the LoRA parameters are updated during
# training
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
config = config_cls(
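
For context, a hypothetical end-to-end usage sketch of the feature added in this PR (the TransformerBlock module and its dimensions are made up for illustration, and this assumes a PEFT build that includes the new lora.MultiheadAttention layer): target the whole mha module via target_modules; LoRA is then applied to both in_proj and out_proj, as described in the docstring above.

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=16, num_heads=2):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.lin0 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        attn_out, _ = self.mha(x, x, x)
        return self.lin0(attn_out)

base_model = TransformerBlock()
# target the whole MHA module; targeting "mha.out_proj" on its own would be ignored (see docstring)
config = LoraConfig(target_modules=["mha"], r=8, lora_alpha=16)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()

x = torch.randn(2, 4, 16)
out = peft_model(x)
print(out.shape)  # torch.Size([2, 4, 16])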