ENH: Targeting multiple parameters on the same module #2665
Conversation
When the target_parameters feature for LoRA was introduced in huggingface#2638, there was one gap, namely the possibility to target multiple nn.Parameters on the same module. (There was only a workaround involving multiple adapters, but that is not user friendly.) With this PR, it is now possible to achieve this.

The mechanism to enable this is a bit crude, namely allowing multiple ParamWrappers to be nested. This should generally be fine as long as there are only a couple of nn.Parameters being targeted on the same module. When there are dozens or hundreds, this approach could lead to slowdowns or other issues.

A side effect of this implementation is that the ParamWrapper, when it removes the parametrization, now only removes its own parametrization. When using nn.utils.parametrize.remove_parametrizations, all parametrizations are removed, which is bad when we have nested parametrizations.
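For context, here is a minimal torch-only sketch, independent of PEFT, of what parametrizing two different nn.Parameters on the same module looks like and of the fact that remove_parametrizations only affects the named tensor. AddDelta and the Linear layer are illustrative stand-ins for the LoRA proxy and the wrapped base layer, not actual PEFT code:

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

class AddDelta(nn.Module):
    """Toy parametrization that adds a learnable delta, standing in for the LoRA proxy."""
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, W):
        return W + self.delta

lin = nn.Linear(4, 4)
# Target two different parameters on the same module, analogous to nesting two ParamWrappers.
parametrize.register_parametrization(lin, "weight", AddDelta(lin.weight.shape))
parametrize.register_parametrization(lin, "bias", AddDelta(lin.bias.shape))
print(list(lin.parametrizations.keys()))  # ['weight', 'bias']

# Removing the parametrizations for one tensor name leaves the other one intact.
parametrize.remove_parametrizations(lin, "weight", leave_parametrized=False)
print(parametrize.is_parametrized(lin, "weight"))  # False
print(parametrize.is_parametrized(lin, "bias"))    # True
```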
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
githubnemo
left a comment
Surprisingly small change :)
Looks good, I appreciate the thorough description.
Some remarks below
src/peft/tuners/lora/layer.py
Outdated
if isinstance(self.base_layer, ParamWrapper):
    param = getattr(self.get_base_layer(), self.parameter_name)
else:
    param = getattr(self.base_layer, self.parameter_name)
return param
Why the distinction, why not always use self.get_base_layer()? Let's also document the reason
It was just defensive programming, but yeah, I don't think always calling self.get_base_layer() has any practical disadvantage.
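For reference, a simplified sketch of the recursive unwrapping that get_base_layer() performs (the real method lives on PEFT's BaseTunerLayer; this is only meant to illustrate why it also resolves nested ParamWrappers):

```python
def get_base_layer(layer):
    # Keep unwrapping as long as the current object wraps another layer, so that
    # nested ParamWrapper instances all resolve to the innermost module.
    base_layer = layer
    while hasattr(base_layer, "base_layer"):
        base_layer = base_layer.base_layer
    return base_layer
```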
else:
    # More than one parametrization, only remove this specific one. Unfortunately, torch does not implement a
    # way to only remove a single parametrization, so we implement this ourselves and follow original torch code
    # closely:
    # see https://github.com/pytorch/pytorch/blob/7d296d5c19750cecd82e2b95f6fb0f8dd918282e/torch/nn/utils/parametrize.py#L731-L737
    original = base_layer.parametrizations[parameter_name].original
    # Delete the property that manages the parametrization
    delattr(base_layer.__class__, parameter_name)
    # Delete the ParametrizationList
    del base_layer.parametrizations[parameter_name]
    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(base_layer, parameter_name, original)
Maybe I'm misunderstanding the comment but AFAICS remove_parametrizations only removes the parametrization for tensor_name, in this case self.parameter_name? Has this changed between pytorch versions?
@githubnemo That sounds right. To expand on this: len(base_layer.parametrizations) == 1 tells us that only one parameter has a parametrization, and based on the check prior to it, we know that it is parameter_name. However, it doesn't tell us how many parametrizations are applied to that parameter; i.e. we just know we have only one ParametrizationList, which may have length >= 1.
Since we're interested in just the parametrizations on self.parameter_name we should instead check len(base_layer.parametrizations[parameter_name]) == 1. If it is, assume the remove_parametrizations API can be safely used, even if we've got more than one target parameter on the module.
Otherwise, we shouldn't be deleting the entire ParametrizationList; just the specific _LoraParameterProxy.
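To illustrate the distinction with a minimal torch-only sketch (nn.Identity stands in for the LoRA proxy): module.parametrizations is keyed by tensor name, and each value is a ParametrizationList whose length is the number of parametrizations stacked on that tensor.

```python
from torch import nn
from torch.nn.utils import parametrize

lin = nn.Linear(4, 4)
parametrize.register_parametrization(lin, "weight", nn.Identity())
parametrize.register_parametrization(lin, "weight", nn.Identity())  # stacked on the same tensor
parametrize.register_parametrization(lin, "bias", nn.Identity())

print(len(lin.parametrizations))            # 2 -> number of parametrized tensors on the module
print(len(lin.parametrizations["weight"]))  # 2 -> parametrizations stacked on "weight"

# remove_parametrizations drops the entire stack registered under "weight";
# the parametrization on "bias" is left untouched.
parametrize.remove_parametrizations(lin, "weight", leave_parametrized=False)
print(parametrize.is_parametrized(lin, "weight"))  # False
print(parametrize.is_parametrized(lin, "bias"))    # True
```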
Yes, you are both right, this doesn't really make sense as is. For now, I "reverted" to nn.utils.parametrize.remove_parametrizations, which is enough for this PR, and added a # TODO for your case, Matthew. Let's fix that in a separate PR.
tests/test_target_parameters.py
Outdated
with torch.nn.utils.parametrize.cached():
    return W + self.delta_weight
This leaks the implementation; wouldn't it be better to call the original implementation from the mock, like a decorator?
It is better, but I was lazy :-p
Fixed now.
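As an illustration of the decorator-style mocking suggested above, here is a self-contained sketch (Proxy is an illustrative stand-in for the actual proxy class under test; the real test patches PEFT's class instead):

```python
import torch
from torch import nn
from unittest.mock import patch

class Proxy(nn.Module):
    # Stand-in for the parametrization proxy whose forward the test wants to spy on.
    def __init__(self, delta_weight):
        super().__init__()
        self.delta_weight = delta_weight

    def forward(self, W):
        return W + self.delta_weight

calls = []
original_forward = Proxy.forward

def spying_forward(self, W):
    # Record what the test needs, then delegate to the real implementation
    # instead of re-implementing `W + self.delta_weight` inside the test.
    calls.append(W.shape)
    return original_forward(self, W)

proxy = Proxy(torch.ones(2, 2))
with patch.object(Proxy, "forward", spying_forward):
    out = proxy(torch.zeros(2, 2))

assert len(calls) == 1
assert torch.equal(out, torch.ones(2, 2))
```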
Alternative approaches

Some alternative approaches were discussed internally, but the chosen one was considered the most practical.

- Allow more than one adapted parameter per LoRA layer. This would require nested dicts for the LoRA parameters, something like self.lora_A[adapter_name][parameter_name]. We don't have this anywhere so far, and it would probably break implicit assumptions about PEFT layers in many places (like parsing of state_dict keys), requiring many adjustments (see the sketch after this list).
- Have an auxiliary module that contains the individual LoRA layers that target the individual parameters. This could be the cleanest solution and would probably be more efficient if there is a huge number of targeted parameters per module. However, it also brings extra complexity, as it requires implementing the logic for routing the information to the right parameter, and it may be a solution to a problem that is irrelevant in practice (a large number of targets per module).
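To make the state_dict concern concrete, here is a hypothetical sketch of the rejected nested-dict layout next to a flat layout (the names and shapes are purely illustrative, not actual PEFT code):

```python
from torch import nn

# Flat layout: one LoRA A module per adapter name.
lora_A_flat = nn.ModuleDict({"default": nn.Linear(16, 8, bias=False)})

# Rejected alternative: an extra nesting level keyed by the targeted parameter name,
# i.e. something like self.lora_A[adapter_name][parameter_name].
lora_A_nested = nn.ModuleDict({
    "default": nn.ModuleDict({
        "gate_up_proj": nn.Linear(16, 8, bias=False),
        "down_proj": nn.Linear(16, 8, bias=False),
    })
})

# The state_dict keys gain a level, which is what would break existing key-parsing assumptions:
print(list(lora_A_flat.state_dict().keys()))    # ['default.weight']
print(list(lora_A_nested.state_dict().keys()))  # ['default.gate_up_proj.weight', 'default.down_proj.weight']
```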