feat(lokr, loha): add 1x1 Conv2d and Conv1d support #2515
Changes to the LoHa layer module:
@@ -45,15 +45,23 @@ def _available_adapters(self) -> Set[str]:

     def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...]):
         # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L130C9-L143C75
-        if len(shape) == 4:
+        if len(shape) == 4:  # Conv2d
             self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
             self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0]))  # out_dim, 1-mode
             self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))  # in_dim , 2-mode

             self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
             self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0]))  # out_dim, 1-mode
             self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))  # in_dim , 2-mode
-        else:
+        elif len(shape) == 3:  # Conv1d
+            self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
+            self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0]))  # out_dim, 1-mode
+            self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))  # in_dim , 2-mode
+
+            self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
+            self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0]))  # out_dim, 1-mode
+            self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))  # in_dim , 2-mode
+        else:  # Linear
             self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r))
             self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))
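For orientation, here is a minimal sketch (my own, not the PR's code; sizes are made up and the alpha/r scaling applied by the layer is omitted) of how the flattened LoHa factors combine into a delta weight in the non-effective case, which now covers Linear, 1x1 Conv2d, and kernel-size-1 Conv1d:

```python
import torch

out_dim, in_dim, r = 16, 64, 4
w1_a, w1_b = torch.randn(out_dim, r), torch.randn(r, in_dim)
w2_a, w2_b = torch.randn(out_dim, r), torch.randn(r, in_dim)

# LoHa delta: element-wise (Hadamard) product of two low-rank matrices
delta = (w1_a @ w1_b) * (w2_a @ w2_b)                 # (out_dim, in_dim)

# The same flattened delta serves all three layer types after a reshape
delta_conv2d = delta.reshape(out_dim, in_dim, 1, 1)   # 1x1 Conv2d weight layout
delta_conv1d = delta.reshape(out_dim, in_dim, 1)      # kernel_size-1 Conv1d weight layout
```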
@@ -127,14 +135,38 @@ def update_layer(
         if isinstance(base_layer, nn.Linear):
             shape = tuple(base_layer.weight.shape)
         elif isinstance(base_layer, nn.Conv2d):
-            use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
+            # For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
+            # Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
+            # they can be more efficiently handled with the flattened weight representation,
+            # similar to how Linear layers work. This optimization reduces computational cost
+            # without affecting the mathematical equivalence of the operation.
+            if base_layer.kernel_size == (1, 1):
+                use_effective_conv2d = False
+            else:
+                use_effective_conv2d = use_effective_conv2d
             if use_effective_conv2d:
                 shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
             else:
                 shape = (
                     base_layer.out_channels,
                     base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
                 )
+        elif isinstance(base_layer, nn.Conv1d):
+            # For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
+            # as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
+            # to a Linear layer applied across the channel dimension. Using flattened representation
+            # avoids unnecessary reshaping and improves computational efficiency.
+            if base_layer.kernel_size[0] == 1:
+                use_effective_conv2d = False
+            else:
+                use_effective_conv2d = use_effective_conv2d
+
+            if use_effective_conv2d:
+                shape = (base_layer.out_channels, base_layer.in_channels, base_layer.kernel_size[0])
+            else:
+                shape = (
+                    base_layer.out_channels,
+                    base_layer.in_channels * base_layer.kernel_size[0],
+                )
         else:
             raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}")
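A small self-contained check (my own illustration, not part of the diff) of the claim in the comments above that a 1x1 convolution is a pointwise matrix multiplication, so the flattened (out, in) weight carries the whole operation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 8, 8)          # (batch, in_channels, H, W)
w = torch.randn(16, 64, 1, 1)         # 1x1 Conv2d weight

y_conv = F.conv2d(x, w)                                   # (2, 16, 8, 8)
y_mat = torch.einsum("oi,bihw->bohw", w.flatten(1), x)    # same result via plain matmul

print(torch.allclose(y_conv, y_mat, atol=1e-4))  # True
```

The same reasoning applies to `nn.Conv1d` with `kernel_size=1`, which is why both branches fall back to the flattened representation.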
@@ -173,6 +205,9 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
         )

         base_layer = self.get_base_layer()

         # Reshape to match base layer shape
         weight = weight.reshape(base_layer.weight.shape)

         # Perform rank dropout during training - drop rows of addition weights
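For the newly supported Conv1d case, the reshape above maps the flattened delta back onto the base layer's weight layout; a tiny sketch with assumed sizes:

```python
import torch
import torch.nn as nn

base = nn.Conv1d(64, 16, kernel_size=1)
delta = torch.zeros(16, 64 * 1)              # flattened (out, in * k) delta from the LoHa factors
delta = delta.reshape(base.weight.shape)     # matches the base layer's weight layout
print(delta.shape)                           # torch.Size([16, 64, 1])
```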
@@ -290,6 +325,49 @@ def __repr__(self) -> str:
         return "loha." + rep


+class Conv1d(LoHaLayer):
+    """LoHa implemented in Conv1d layer"""
+
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        adapter_name: str = "default",
+        r: int = 0,
+        alpha: float = 0.0,
+        rank_dropout: float = 0.0,
+        module_dropout: float = 0.0,
+        use_effective_conv2d: bool = False,
+        init_weights: bool = True,
+        **kwargs,
+    ):
+        super().__init__(base_layer)
+
+        # Create adapter and set it active
+        self._active_adapter = adapter_name
+        self.update_layer(
+            adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
+        )
+
+    def _get_delta_activations(
+        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
+    ) -> torch.Tensor:
+        delta_weight = self.get_delta_weight(adapter_name)
+        # don't add bias here, because the bias is already included in the output of the base_layer
+        base_layer = self.get_base_layer()
+        return F.conv1d(
+            input,
+            delta_weight,
+            stride=base_layer.stride,
+            padding=base_layer.padding,
+            dilation=base_layer.dilation,
+            groups=base_layer.groups,
+        )
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "loha." + rep
+
+
 # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9
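A minimal usage sketch of the new capability (the toy model and config values are mine, not from the PR; it assumes a peft build that contains this change):

```python
import torch
import torch.nn as nn
from peft import LoHaConfig, get_peft_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(16, 16, kernel_size=1)   # 1x1 conv, handled via the flattened path
        self.conv1d = nn.Conv1d(16, 16, kernel_size=1)   # Conv1d, newly supported

    def forward(self, x):                                # x: (batch, 16, 8, 8)
        y = self.conv2d(x)
        return self.conv1d(y.flatten(2))                 # (batch, 16, 64)

config = LoHaConfig(r=4, alpha=4, target_modules=["conv2d", "conv1d"])
model = get_peft_model(Toy(), config)
print(model(torch.randn(2, 16, 8, 8)).shape)             # torch.Size([2, 16, 64])
```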
Changes to the LoKr layer module:
@@ -75,8 +75,8 @@ def create_adapter_parameters(
             self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r))
             self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0]))

-        if len(shape) == 4:
-            # Conv2d
+        # Handle both Conv2d and Conv1d
+        if len(shape) == 4:  # Conv2d
             if use_w2:
                 self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:]))
             elif use_effective_conv2d:

@@ -86,6 +86,16 @@ def create_adapter_parameters(
             else:
                 self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                 self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3]))
+        elif len(shape) == 3:  # Conv1d
+            if use_w2:
+                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], shape[2]))
+            elif use_effective_conv2d:  # Even for Conv1d, use the effective parameter for kernel dimension
+                self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
+                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1]))  # b, 1-mode
+                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))  # d, 2-mode
+            else:
+                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
+                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2]))
         else:
             # Linear
             if use_w2:
A maintainer (Member) commented on lines 92 to 100:

These lines are not covered by tests. I think we need to extend the test cases a bit:

diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 505840e1..62e33011 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -291,7 +291,9 @@ TEST_CASES = [
     ("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
     ("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
     ("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
-    ("Conv1d LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
+    ("Conv1d 1 LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
+    ("Conv1d 2 LOKR", "Conv1dBigger", LoKrConfig, {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False}),
+    ("Conv1d 3 LOKR", "Conv1dBigger", LoKrConfig, {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": True}),
     ("Conv2d 1x1 LOKR", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}),
     (
         "Conv2d 5 LOKR",
@@ -1193,6 +1195,29 @@ class ModelConv1D(nn.Module):
         return X


+class ModelConv1dBigger(nn.Module):
+    # TODO
+    def __init__(self):
+        super().__init__()
+        self.conv1d = nn.Conv1d(64, 16, 2)
+        self.relu = nn.ReLU()
+        self.flat = nn.Flatten()
+        self.lin0 = nn.Linear(144, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+        self.dtype = torch.float
+
+    def forward(self, X):
+        X = X.to(self.dtype)
+        X = X.reshape(-1, 1, 10)
+        X = torch.concat([X] * 64, dim=1)
+        X = self.conv1d(X)
+        X = self.relu(X)
+        X = self.flat(X)
+        X = self.lin0(X)
+        X = self.sm(X)
+        return X
+
+
 class ModelConv2D(nn.Module):
     def __init__(self, bias=True):
         super().__init__()
@@ -1426,6 +1451,9 @@ class MockTransformerWrapper:
         if model_id == "Conv1d":
             return ModelConv1D().to(torch_dtype)

+        if model_id == "Conv1dBigger":
+            return ModelConv1dBigger().to(torch_dtype)
+
         if model_id == "Conv2d":
             return ModelConv2D().to(torch_dtype)

However, with these changes, there are errors in some tests that involve merging. I'm not sure how trivial it is to fix those errors. The same argument applies to LoHa, where the analogous lines are not covered and tests would need to be extended.
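As a quick sanity check on the proposed `ModelConv1dBigger` (my own arithmetic, mirroring the forward pass quoted above): Conv1d(64, 16, kernel_size=2) over a length-10 sequence yields length 10 - 2 + 1 = 9, so the flattened feature size is 16 * 9 = 144, matching `nn.Linear(144, 2)`.

```python
import torch
import torch.nn as nn

x = torch.randn(5, 10).reshape(-1, 1, 10)
x = torch.concat([x] * 64, dim=1)       # (5, 64, 10)
x = nn.Conv1d(64, 16, 2)(x)             # (5, 16, 9)
print(nn.Flatten()(x).shape)            # torch.Size([5, 144])
```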
@@ -201,7 +211,33 @@ def update_layer(

             use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
             use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
-            use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
+            # For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
+            # Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
+            # they can be more efficiently handled with the flattened weight representation,
+            # similar to how Linear layers work. This optimization reduces computational cost
+            # without affecting the mathematical equivalence of the operation.
+            if base_layer.kernel_size == (1, 1):
+                use_effective_conv2d = False
+            else:
+                use_effective_conv2d = use_effective_conv2d
+
+        elif isinstance(base_layer, nn.Conv1d):
+            in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
+            k_size = (base_layer.kernel_size[0],)  # Convert to a tuple with single element
+
+            in_m, in_n = factorization(in_dim, decompose_factor)
+            out_l, out_k = factorization(out_dim, decompose_factor)
+            shape = ((out_l, out_k), (in_m, in_n), *k_size)  # ((a, b), (c, d), k)
+
+            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
+            use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
+            # For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
+            # as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
+            # to a Linear layer applied across the channel dimension. Using flattened representation
+            # avoids unnecessary reshaping and improves computational efficiency.
+            if base_layer.kernel_size[0] == 1:
+                use_effective_conv2d = False
+            else:
+                use_effective_conv2d = use_effective_conv2d
+
         else:
             raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")
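To make the shape bookkeeping concrete (assumed channel counts; the exact factors depend on `factorization` and `decompose_factor`):

```python
# For nn.Conv1d(in_channels=64, out_channels=16, kernel_size=3), a balanced factorization
# could give out_channels = 4 * 4 and in_channels = 8 * 8, so
shape = ((4, 4), (8, 8), 3)   # ((a, b), (c, d), k) as in the comment above
# The Kronecker product of an (a x c) factor with a (b x d [* k]) factor then rebuilds
# a (16, 64, 3)-shaped delta weight; see the sketch after the get_delta_weight hunk below.
```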
@@ -237,7 +273,13 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:

         # Make weights with Kronecker product
         weight = make_kron(w1, w2, self.scaling[adapter_name])
-        weight = weight.reshape(self.get_base_layer().weight.shape)
+
+        # Get base layer for reshaping
+        base_layer = self.get_base_layer()
+
+        # Regular reshape to match base layer shape
+        weight = weight.reshape(base_layer.weight.shape)

         # Perform rank dropout during training - drop rows of addition weights
         rank_dropout = self.rank_dropout[adapter_name]
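A hedged sketch (mine, not the PR's code) of what `make_kron` plus the reshape amount to for the flattened Conv1d branch, using the same assumed factorization as above; the alpha/r scaling is omitted:

```python
import torch

out_channels, in_channels, k = 16, 64, 3
out_l, out_k = 4, 4            # assumed factorization of out_channels
in_m, in_n = 8, 8              # assumed factorization of in_channels
r = 2

w1 = torch.randn(out_l, in_m)                          # lokr_w1 (or lokr_w1_a @ lokr_w1_b)
w2_a, w2_b = torch.randn(out_k, r), torch.randn(r, in_n * k)
w2 = w2_a @ w2_b                                       # flattened lokr_w2 branch

delta = torch.kron(w1, w2)                             # (out_l * out_k, in_m * in_n * k)
delta = delta.reshape(out_channels, in_channels, k)    # back to the Conv1d weight layout
print(delta.shape)                                     # torch.Size([16, 64, 3])
```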
@@ -356,6 +398,51 @@ def __repr__(self) -> str:
         return "lokr." + rep


+class Conv1d(LoKrLayer):
+    """LoKr implemented in Conv1d layer"""
+
+    def __init__(
+        self,
+        base_layer: nn.Module,
+        device: Optional[Union[str, torch.device]] = None,
+        dtype: Optional[torch.dtype] = None,
+        adapter_name: str = "default",
+        r: int = 0,
+        alpha: float = 0.0,
+        rank_dropout: float = 0.0,
+        module_dropout: float = 0.0,
+        use_effective_conv2d: bool = False,
+        init_weights: bool = True,
+        **kwargs,
+    ):
+        super().__init__(base_layer)
+
+        # Create adapter and set it active
+        self._active_adapter = adapter_name
+        self.update_layer(
+            adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
+        )
+
+    def _get_delta_activations(
+        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
+    ) -> torch.Tensor:
+        delta_weight = self.get_delta_weight(adapter_name)
+        # don't add bias here, because the bias is already included in the output of the base_layer
+        base_layer = self.get_base_layer()
+        return F.conv1d(
+            input,
+            delta_weight,
+            stride=base_layer.stride,
+            padding=base_layer.padding,
+            dilation=base_layer.dilation,
+            groups=base_layer.groups,
+        )
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lokr." + rep
+
+
 # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11
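And a usage sketch mirroring the reviewer's proposed LoKr test cases (assumed environment with this PR applied; the tiny module is mine):

```python
import torch
import torch.nn as nn
from peft import LoKrConfig, get_peft_model

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(64, 16, 2)

    def forward(self, x):              # x: (batch, 64, length)
        return self.conv1d(x)

for use_eff in (False, True):
    cfg = LoKrConfig(r=2, target_modules=["conv1d"], use_effective_conv2d=use_eff)
    model = get_peft_model(Tiny(), cfg)
    print(use_eff, model(torch.randn(5, 64, 10)).shape)   # (5, 16, 9) either way
```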
Review comment: No-op, can be removed. In fact, this snippet does the same as the one it replaces, right?