[WIP] Add LoRA multihead attention module #1324
Before, LoRA was applied only to the in_proj. Now it is also applied to the out_proj.

Unfortunately, there is no easy way to apply a normal lora.Linear to the out_proj by targeting it with target_modules. If that worked, it would be much nicer, since users could decide for themselves whether they want to apply LoRA to the out_proj or not. The reason why it doesn't work is twofold:

1. We cannot really control the order in which LoRA is applied, so by the time the LoRA adapter would be injected into out_proj, the whole MHA layer may already be wrapped by lora.MultiheadAttention.
2. Even if we successfully applied a normal lora.Linear to the out_proj, it would not work correctly. This is because the forward method of out_proj is not used at all by nn.MultiheadAttention; instead, it just passes the weight and bias to F.multi_head_attention_forward (see the small sketch below). Therefore, we must ensure that the weights are merged and unmerged correctly, same as for in_proj, and we cannot do that with a normal lora.Linear.

Note that the test test_merge_layers for MHA currently fails. This is most likely because of an existing bug in how merging is implemented, see PR #1355. Once that is merged, the test should pass.
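To illustrate point 2, here is a minimal standalone sketch (not part of this PR) showing that nn.MultiheadAttention never calls `out_proj.forward`; it only reads `out_proj.weight` and `out_proj.bias`:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# If out_proj.forward were used, this forward hook would fire on every call.
calls = []
mha.out_proj.register_forward_hook(lambda module, inputs, output: calls.append(1))

x = torch.randn(3, 2, 8)  # (seq_len, batch, embed_dim)
mha(x, x, x)

# The hook never fires: nn.MultiheadAttention passes out_proj.weight and out_proj.bias
# directly to F.multi_head_attention_forward instead of calling the out_proj module.
print(calls)  # []
```

Because of this, a LoRA wrapper that only intercepts out_proj's forward would be silently ignored; the delta has to be merged into the weight instead.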
@@ -698,6 +698,10 @@ class MultiheadAttention(nn.Module, LoraLayer):
    This is currently only implemented for the case of `_qkv_same_embed_dim = True`, i.e. query, key, and value having
    the same dimension.

    Note: LoRA is applied to both the in_proj (query/key/value) and out_proj. There is currently no way to specify only
    one of them. Don't try to apply LoRA to the out_proj of MultiheadAttention by targeting that layer specifically,
    since the forward method of that layer is not being used, hence the LoRA adapter would be ignored.

    This is a little bit hacky because of the way that MultiheadAttention is implemented in PyTorch. It works by
    merging the weights before the forward call and unmerging them after the forward call.
    """
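For orientation, a rough usage sketch of how this layer is meant to be reached (the toy `Model` and the attribute name `attn` are made up for illustration; `LoraConfig`, `get_peft_model`, and `print_trainable_parameters` are the usual PEFT entry points, and targeting an nn.MultiheadAttention of course assumes this PR's branch):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=4)
        self.lin = nn.Linear(64, 2)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return self.lin(out)

# Targeting the whole attention module wraps it in lora.MultiheadAttention;
# LoRA is then applied to both in_proj and out_proj, as described above.
config = LoraConfig(target_modules=["attn"], r=8, lora_alpha=16)
peft_model = get_peft_model(Model(), config)
peft_model.print_trainable_parameters()
```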
@@ -723,6 +727,23 @@ def __init__(
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

        # Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
        if isinstance(base_layer.out_proj, nn.Linear):
            self.base_layer.out_proj = Linear(
                base_layer.out_proj,
                adapter_name,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                init_lora_weights=init_lora_weights,
                use_rslora=use_rslora,
                **kwargs,
            )
            # bias is accessed directly by nn.MultiheadAttention
            self.base_layer.out_proj.bias = self.base_layer.out_proj.get_base_layer().bias
        else:
            raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights
@@ -753,21 +774,41 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if safe_merge:
                    # merging in_proj
                    # TODO: work with separate weights
                    orig_weights = base_layer.in_proj_weight.data.detach().clone()
                    orig_weights += self.get_delta_weight(active_adapter)

                    if not torch.isfinite(orig_weights).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

                    del base_layer.in_proj_weight
                    base_layer.in_proj_weight = orig_weights

                    # merging out_proj
                    orig_weights = base_layer.out_proj.weight.data.detach().clone()
                    orig_weights += base_layer.out_proj.get_delta_weight(active_adapter)
                    if not torch.isfinite(orig_weights).all():
                        raise ValueError(
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )
                    del base_layer.out_proj.get_base_layer().weight
                    base_layer.out_proj.get_base_layer().weight = orig_weights
                else:
                    # merging in_proj
                    # TODO: work with separate weights
                    weight_merged = base_layer.in_proj_weight.data.detach() + self.get_delta_weight(active_adapter)
                    del base_layer.in_proj_weight
                    base_layer.in_proj_weight = weight_merged
Collaborator: Shouldn't this throw an exception? AFAICS we're assigning a tensor to a parameter value:

    foo = torch.nn.Linear(10, 100)
    foo.weight = foo.weight.detach()  # raises

What am I missing?

Member (Author): It's true that we change the type here, I guess you could consider this part of the hack to make this work. At the end, through ...

Collaborator: ah, yes. I missed the

    # unregister parameter implicitly and overwrite using merged weights; gradients are computed
    # after forward and, thus, after unmerging (see forward()), therefore this is safe to do.
    del base_layer.in_proj_weight
    base_layer.in_proj_weight = orig_weights_in

Member (Author): Done
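A standalone sketch of the point settled above: assigning a plain tensor to a registered parameter name raises, but deleting the parameter first unregisters it, so the subsequent assignment just sets an ordinary tensor attribute, which is what `merge()` relies on (toy `nn.Linear`, not the actual MHA code):

```python
import torch
import torch.nn as nn

lin = nn.Linear(10, 10)
merged = lin.weight.data + 1.0  # stand-in for "base weight + LoRA delta"

# lin.weight = merged  # would raise: cannot assign a plain Tensor to parameter 'weight'

del lin.weight          # unregisters the parameter
lin.weight = merged     # now just sets a regular tensor attribute

print(type(lin.weight))                          # <class 'torch.Tensor'>, no longer nn.Parameter
print("weight" in dict(lin.named_parameters()))  # False
```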
                    # merging out_proj
                    weight_merged = base_layer.out_proj.weight.data.detach() + base_layer.out_proj.get_delta_weight(
                        active_adapter
                    )
                    del base_layer.out_proj.get_base_layer().weight
                    base_layer.out_proj.get_base_layer().weight = weight_merged
                    # self.get_base_layer().out_proj.merge(adapter_names=[active_adapter])
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
@@ -782,7 +823,14 @@ def unmerge(self) -> None:
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.lora_A.keys():
                # in_proj
                self.get_base_layer().in_proj_weight.data -= self.get_delta_weight(active_adapter)
                # out_proj
                self.get_base_layer().out_proj.weight.data -= self.get_base_layer().out_proj.get_delta_weight(
                    active_adapter
                )

                self.get_base_layer().out_proj.unmerge()

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
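For reference, the merge/unmerge above boils down to the standard LoRA arithmetic; the sketch below uses made-up dimensions and a random `lora_B` purely for illustration (in practice `lora_B` starts at zero): the delta is `lora_B @ lora_A` scaled by `lora_alpha / r`, added on merge and subtracted again on unmerge.

```python
import torch

d, r = 16, 4
W = torch.randn(d, d)   # base weight, e.g. a slice of in_proj_weight or out_proj.weight
A = torch.randn(r, d)   # lora_A
B = torch.randn(d, r)   # lora_B (zero-initialized in practice, random here for illustration)
scaling = 2.0           # lora_alpha / r

delta_W = (B @ A) * scaling
W_merged = W + delta_W           # merge()
W_restored = W_merged - delta_W  # unmerge()

# The round-trip restores the base weight up to floating point error.
assert torch.allclose(W, W_restored, atol=1e-6)
```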
@@ -828,15 +876,34 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            # merge all adapters that are active for this module
            out_proj = self.get_base_layer().out_proj
            if out_proj.active_adapters != self.active_adapters:
                # We have a case where in_proj and out_proj have diverging merged adapters. We cannot really deal
                # with this correctly, thus it's better to raise than possibly create a hard to debug mess.
                cls_name = self.get_base_layer().__class__.__name__
                raise ValueError(
                    f"The out_proj layer of {cls_name} has merged layers but {cls_name} itself doesn't; please ensure "
                    "that either both or none have merged layers"
                )

            # Merge all adapters that are active for this module, i.e. the LoRA weights for in_proj and out_proj.
            # in_proj uses nn.Parameters, therefore, there is no forward method to be used and we have to explicitly
            # merge for the LoRA weights to have an effect:
            # https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1020
            # For out_proj, we have an nn.Linear (or rather: NonDynamicallyQuantizableLinear), but its forward method
            # is not used:
            # https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1267-L1271
            # Therefore, its LoRA weights also need to be merged to have an effect.
            active_adapters = [a for a in self.active_adapters if a in self.lora_A]
            try:
                self.merge(adapter_names=active_adapters)
                out_proj.merge(adapter_names=active_adapters)
                result = self.base_layer(x, *args, **kwargs)
            finally:
                # it's safe to call unmerge(), which unmerges all adapters, because we checked that not self.merged,
                # i.e. there was no merged layer before
                self.unmerge()
                out_proj.unmerge()

        result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
        return result
@@ -848,12 +915,19 @@ def _restore_weights(self):
        # We cannot call register_parameter for merging/unmerging because that cuts them off from the autograd graph.
        # Note that this is hacky, since we need to ensure that _restore_weights is called by each method that needs it.

        # in_proj
        # TODO work with separate weights
        base_layer = self.get_base_layer()
        weight = base_layer.in_proj_weight.data
        del base_layer.in_proj_weight
        base_layer.register_parameter("in_proj_weight", nn.Parameter(weight))

        # out_proj
        base_layer = base_layer.out_proj.get_base_layer()
        weight = base_layer.weight.data
        del base_layer.weight
        base_layer.register_parameter("weight", nn.Parameter(weight))

    def state_dict(self, *args, **kwargs):
        self._restore_weights()
        return super().state_dict(*args, **kwargs)
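To see why `_restore_weights` is needed before `state_dict`, here is a minimal standalone sketch on a plain `nn.Linear` (not the actual MHA code): after the del-and-reassign trick from `merge()`, the weight is no longer a registered parameter and drops out of `state_dict()`; re-registering it brings it back.

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)

# Mimic what merge() leaves behind: the weight becomes a plain tensor attribute.
w = lin.weight.data
del lin.weight
lin.weight = w
print("weight" in lin.state_dict())  # False: state_dict only sees registered parameters/buffers

# What _restore_weights() does: drop the plain attribute and re-register it as a parameter.
weight = lin.weight
del lin.weight
lin.register_parameter("weight", nn.Parameter(weight))
print("weight" in lin.state_dict())  # True
```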