Merged (changes from 5 commits)
84 changes: 81 additions & 3 deletions src/peft/tuners/loha/layer.py
@@ -45,15 +45,23 @@ def _available_adapters(self) -> Set[str]:

def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...]):
# https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L130C9-L143C75
if len(shape) == 4:
if len(shape) == 4: # Conv2d
self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode

self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode
else:
elif len(shape) == 3: # Conv1d
self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode

self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode
else: # Linear
self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r))
self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))

@@ -127,14 +135,38 @@ def update_layer(
if isinstance(base_layer, nn.Linear):
shape = tuple(base_layer.weight.shape)
elif isinstance(base_layer, nn.Conv2d):
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
# For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
# Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
if base_layer.kernel_size == (1, 1):
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
Reviewer comment (Member):

No-op, can be removed. In fact, this snippet does the same as the one it replaces, right?
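
To make the reviewer's point concrete, here is a small self-contained check (illustrative only, not part of the diff) that the new if/else branch and the replaced one-liner always compute the same flag:

import torch.nn as nn

def old_flag(use_effective_conv2d, base_layer):
    # the one-liner this PR replaces
    return use_effective_conv2d and base_layer.kernel_size != (1, 1)

def new_flag(use_effective_conv2d, base_layer):
    # the new branch added in this PR
    if base_layer.kernel_size == (1, 1):
        return False
    return use_effective_conv2d

for layer in (nn.Conv2d(4, 8, kernel_size=1), nn.Conv2d(4, 8, kernel_size=3)):
    for flag in (True, False):
        assert old_flag(flag, layer) == new_flag(flag, layer)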

if use_effective_conv2d:
shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
else:
shape = (
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
)
elif isinstance(base_layer, nn.Conv1d):
# For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
Reviewer comment (Member):

Same, can be removed.

if use_effective_conv2d:
shape = (base_layer.out_channels, base_layer.in_channels, base_layer.kernel_size[0])
else:
shape = (
base_layer.out_channels,
base_layer.in_channels * base_layer.kernel_size[0],
)
else:
raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}")
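
As a concrete illustration of the two Conv1d shape variants above (hypothetical layer sizes, not taken from the PR): the effective form keeps the kernel dimension separate, while the flattened form folds it into the input dimension, just as for Linear layers.

import torch.nn as nn

base_layer = nn.Conv1d(in_channels=64, out_channels=16, kernel_size=2)

# effective form: kernel dimension kept separate
shape_effective = (base_layer.out_channels, base_layer.in_channels, base_layer.kernel_size[0])
# flattened form: kernel folded into the input dimension
shape_flat = (base_layer.out_channels, base_layer.in_channels * base_layer.kernel_size[0])

print(shape_effective)  # (16, 64, 2)
print(shape_flat)       # (16, 128)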

@@ -173,6 +205,9 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
)

base_layer = self.get_base_layer()


# Reshape to match base layer shape
weight = weight.reshape(base_layer.weight.shape)

# Perform rank dropout during training - drop rows of addition weights
@@ -290,6 +325,49 @@ def __repr__(self) -> str:
return "loha." + rep


class Conv1d(LoHaLayer):
"""LoHa implemented in Conv1d layer"""

def __init__(
self,
base_layer: nn.Module,
adapter_name: str = "default",
r: int = 0,
alpha: float = 0.0,
rank_dropout: float = 0.0,
module_dropout: float = 0.0,
use_effective_conv2d: bool = False,
init_weights: bool = True,
**kwargs,
):
super().__init__(base_layer)

# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
)

def _get_delta_activations(
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
# don't add bias here, because the bias is already included in the output of the base_layer
base_layer = self.get_base_layer()
return F.conv1d(
input,
delta_weight,
stride=base_layer.stride,
padding=base_layer.padding,
dilation=base_layer.dilation,
groups=base_layer.groups,
)

def __repr__(self) -> str:
rep = super().__repr__()
return "loha." + rep


# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9


3 changes: 2 additions & 1 deletion src/peft/tuners/loha/model.py
@@ -21,7 +21,7 @@

from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner

from .layer import Conv2d, Linear, LoHaLayer
from .layer import Conv1d, Conv2d, Linear, LoHaLayer


class LoHaModel(LycorisTuner):
@@ -85,6 +85,7 @@ class LoHaModel(LycorisTuner):
prefix: str = "hada_"
layers_mapping: Dict[Type[torch.nn.Module], Type[LoHaLayer]] = {
torch.nn.Conv2d: Conv2d,
torch.nn.Conv1d: Conv1d,
torch.nn.Linear: Linear,
}

95 changes: 91 additions & 4 deletions src/peft/tuners/lokr/layer.py
@@ -75,8 +75,8 @@ def create_adapter_parameters(
self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r))
self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0]))

if len(shape) == 4:
# Conv2d
# Handle both Conv2d and Conv1d
if len(shape) == 4: # Conv2d
if use_w2:
self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:]))
elif use_effective_conv2d:
@@ -86,6 +86,16 @@
else:
self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3]))
elif len(shape) == 3: # Conv1d
if use_w2:
self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], shape[2]))
elif use_effective_conv2d: # Even for Conv1d, use the effective parameter for kernel dimension
self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1])) # b, 1-mode
self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1])) # d, 2-mode
else:
self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2]))
Reviewer comment (Member) on lines 92 to 100:
These lines are not covered by tests. I think we need to extend the test cases a bit:

diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 505840e1..62e33011 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -291,7 +291,9 @@ TEST_CASES = [
     ("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
     ("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
     ("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
-    ("Conv1d LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
+    ("Conv1d 1 LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
+    ("Conv1d 2 LOKR", "Conv1dBigger", LoKrConfig, {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False}),
+    ("Conv1d 3 LOKR", "Conv1dBigger", LoKrConfig, {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": True}),
     ("Conv2d 1x1 LOKR", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}),
     (
         "Conv2d 5 LOKR",
@@ -1193,6 +1195,29 @@ class ModelConv1D(nn.Module):
         return X
 
 
+class ModelConv1dBigger(nn.Module):
+    # TODO
+    def __init__(self):
+        super().__init__()
+        self.conv1d = nn.Conv1d(64, 16, 2)
+        self.relu = nn.ReLU()
+        self.flat = nn.Flatten()
+        self.lin0 = nn.Linear(144, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+        self.dtype = torch.float
+
+    def forward(self, X):
+        X = X.to(self.dtype)
+        X = X.reshape(-1, 1, 10)
+        X = torch.concat([X] * 64, dim=1)
+        X = self.conv1d(X)
+        X = self.relu(X)
+        X = self.flat(X)
+        X = self.lin0(X)
+        X = self.sm(X)
+        return X
+
+
 class ModelConv2D(nn.Module):
     def __init__(self, bias=True):
         super().__init__()
@@ -1426,6 +1451,9 @@ class MockTransformerWrapper:
         if model_id == "Conv1d":
             return ModelConv1D().to(torch_dtype)
 
+        if model_id == "Conv1dBigger":
+            return ModelConv1dBigger().to(torch_dtype)
+
         if model_id == "Conv2d":
             return ModelConv2D().to(torch_dtype)
 

However, with these changes, there are errors in some tests that involve merging. I'm not sure how trivial it is to fix those errors.

The same argument applies to LoHa, where the analogous lines are not covered and tests would need to be extended.

else:
# Linear
if use_w2:
@@ -201,7 +211,33 @@ def update_layer(

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
# For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
# Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
if base_layer.kernel_size == (1, 1):
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
Reviewer comment (Member):

Same, can be removed.

elif isinstance(base_layer, nn.Conv1d):
in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
k_size = (base_layer.kernel_size[0],) # Convert to a tuple with single element

in_m, in_n = factorization(in_dim, decompose_factor)
out_l, out_k = factorization(out_dim, decompose_factor)
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), k)

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
# For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
Reviewer comment (Member):

Same, can be removed.
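
For intuition on the shapes above, a small sketch with hypothetical sizes (out_channels=16, in_channels=64, kernel_size=2, and assuming factorization picks factors near the square root, i.e. 16 -> (4, 4) and 64 -> (8, 8)): get_delta_weight combines such factors with make_kron, and plain torch.kron already shows the shape arithmetic, since the Kronecker product of a (4, 8) factor with a (4, 8, 2) factor recovers the full (16, 64, 2) Conv1d weight.

import torch

w1 = torch.randn(4, 8)     # (out_l, in_m)
w2 = torch.randn(4, 8, 2)  # (out_k, in_n, kernel_size)

weight = torch.kron(w1.unsqueeze(-1), w2)
print(weight.shape)  # torch.Size([16, 64, 2])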

else:
raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")

@@ -237,7 +273,13 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:

# Make weights with Kronecker product
weight = make_kron(w1, w2, self.scaling[adapter_name])
weight = weight.reshape(self.get_base_layer().weight.shape)

# Get base layer for reshaping
base_layer = self.get_base_layer()


# Regular reshape to match base layer shape
weight = weight.reshape(base_layer.weight.shape)

# Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name]
@@ -356,6 +398,51 @@ def __repr__(self) -> str:
return "lokr." + rep


class Conv1d(LoKrLayer):
"""LoKr implemented in Conv1d layer"""

def __init__(
self,
base_layer: nn.Module,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
adapter_name: str = "default",
r: int = 0,
alpha: float = 0.0,
rank_dropout: float = 0.0,
module_dropout: float = 0.0,
use_effective_conv2d: bool = False,
init_weights: bool = True,
**kwargs,
):
super().__init__(base_layer)

# Create adapter and set it active
self._active_adapter = adapter_name
self.update_layer(
adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
)

def _get_delta_activations(
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
) -> torch.Tensor:
delta_weight = self.get_delta_weight(adapter_name)
# don't add bias here, because the bias is already included in the output of the base_layer
base_layer = self.get_base_layer()
return F.conv1d(
input,
delta_weight,
stride=base_layer.stride,
padding=base_layer.padding,
dilation=base_layer.dilation,
groups=base_layer.groups,
)

def __repr__(self) -> str:
rep = super().__repr__()
return "lokr." + rep


# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11


3 changes: 2 additions & 1 deletion src/peft/tuners/lokr/model.py
@@ -21,7 +21,7 @@

from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner

from .layer import Conv2d, Linear, LoKrLayer
from .layer import Conv1d, Conv2d, Linear, LoKrLayer


class LoKrModel(LycorisTuner):
@@ -86,6 +86,7 @@ class LoKrModel(LycorisTuner):
prefix: str = "lokr_"
layers_mapping: Dict[Type[torch.nn.Module], Type[LoKrLayer]] = {
torch.nn.Conv2d: Conv2d,
torch.nn.Conv1d: Conv1d,
torch.nn.Linear: Linear,
}

2 changes: 1 addition & 1 deletion src/peft/tuners/lycoris_utils.py
@@ -258,7 +258,7 @@ def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn
else:
target_base_layer = target

if isinstance(target_base_layer, torch.nn.Conv2d):
if isinstance(target_base_layer, (torch.nn.Conv2d, torch.nn.Conv1d)):
new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs)
50 changes: 50 additions & 0 deletions tests/test_custom_models.py
@@ -230,6 +230,8 @@
("Conv2d 2 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 3 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
("Conv2d 4 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
("Conv1d LOHA", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LOHA", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"]}),
# LoKr
("Vanilla MLP 1 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0"]}),
@@ -252,6 +254,8 @@
("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
("Conv1d LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LOKR", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}),
(
"Conv2d 5 LOKR",
"Conv2d",
@@ -852,6 +856,46 @@ def forward(self, X):
return X


class ModelConv2D1x1(nn.Module):
def __init__(self):
super().__init__()
self.conv2d = nn.Conv2d(1, 10, kernel_size=(1, 1), padding=0)
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(10 * 5 * 5, 2)
self.sm = nn.LogSoftmax(dim=-1)
self.dtype = torch.float

def forward(self, X):
X = X.to(self.dtype)
X = X.reshape(-1, 1, 5, 5)
X = self.conv2d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X


class ModelConv1D(nn.Module):
def __init__(self):
super().__init__()
self.conv1d = nn.Conv1d(in_channels=3, out_channels=10, kernel_size=1)
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(10 * 10, 2)
self.dtype = torch.float

def forward(self, x):
x = x.to(self.dtype)
x = x.reshape(-1, 3, 10) # batch, channels, seq_len
x = self.conv1d(x)
x = self.relu(x)
x = self.flat(x)
x = self.lin0(x)
return x
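
A quick sanity check of the new test model's tensor flow (illustrative, not part of the diff): 30 input features are reshaped to (batch, 3, 10), the kernel-size-1 Conv1d maps them to (batch, 10, 10), flattening yields 100 features, and lin0 produces the 2-class output.

import torch

model = ModelConv1D()
X = torch.randn(4, 30)  # any batch of 3 * 10 = 30 features
out = model(X)
print(out.shape)  # torch.Size([4, 2])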


class ModelConv3D(nn.Module):
def __init__(self):
super().__init__()
@@ -912,6 +956,12 @@ def from_pretrained(cls, model_id, torch_dtype=None):

if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

if model_id == "Conv2d1x1":
return ModelConv2D1x1().to(torch_dtype)

if model_id == "Conv1d":
return ModelConv1D().to(torch_dtype)

if model_id == "Conv3d":
return ModelConv3D().to(torch_dtype)