Merged
make style linting
grewalsk committed May 28, 2025
commit cd93eb54148a339e1082e514f20ca679e31158c8
9 changes: 4 additions & 5 deletions src/peft/tuners/loha/layer.py
@@ -136,7 +136,7 @@ def update_layer(
shape = tuple(base_layer.weight.shape)
elif isinstance(base_layer, nn.Conv2d):
# For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
# Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
@@ -152,8 +152,8 @@ def update_layer(
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
)
elif isinstance(base_layer, nn.Conv1d):
# For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
@@ -205,8 +205,7 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
)

base_layer = self.get_base_layer()
# Reshape to match base layer shape
weight = weight.reshape(base_layer.weight.shape)

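The comments in the hunks above argue that a 1x1 convolution is a pointwise operation, so its weight can be flattened to a plain (out, in) matrix and applied like a Linear layer. A minimal standalone sketch of that equivalence (illustrative only, not part of this diff; the layer sizes are arbitrary):

# Sketch: a 1x1 Conv2d acts pointwise, so its (out, in, 1, 1) weight can be
# flattened to (out, in) and applied as a matrix multiply over the channel
# dimension, exactly like a Linear layer.
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=1, bias=False)
x = torch.randn(2, 8, 5, 5)  # (batch, channels, height, width)

out_conv = conv(x)

w_flat = conv.weight.reshape(4, 8)  # flatten (4, 8, 1, 1) -> (4, 8)
out_matmul = torch.einsum("oc,bchw->bohw", w_flat, x)

assert torch.allclose(out_conv, out_matmul, atol=1e-6)

Because the kernel covers a single position, skipping the effective_conv2d path in this case avoids reshaping the adapter weights into 4D without changing the result.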
11 changes: 5 additions & 6 deletions src/peft/tuners/lokr/layer.py
@@ -212,7 +212,7 @@ def update_layer(
use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
# For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
# Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
@@ -230,8 +230,8 @@ def update_layer(

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
# For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
@@ -273,11 +273,10 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:

# Make weights with Kronecker product
weight = make_kron(w1, w2, self.scaling[adapter_name])

# Get base layer for reshaping
base_layer = self.get_base_layer()
# Regular reshape to match base layer shape
weight = weight.reshape(base_layer.weight.shape)

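The Conv1d branch above makes the analogous claim: with kernel_size=1 there is no temporal context, so the layer behaves like a Linear applied across the channel dimension at every sequence position. A small sketch of that equivalence (illustrative only, not part of this diff):

# Sketch: Conv1d with kernel_size=1 has no temporal context, so it matches a
# Linear layer applied over the channel dimension at every position.
import torch
import torch.nn as nn

conv1d = nn.Conv1d(in_channels=8, out_channels=4, kernel_size=1, bias=False)
x = torch.randn(2, 8, 10)  # (batch, channels, length)

out_conv = conv1d(x)

linear = nn.Linear(8, 4, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv1d.weight.reshape(4, 8))  # (4, 8, 1) -> (4, 8)
out_linear = linear(x.transpose(1, 2)).transpose(1, 2)

assert torch.allclose(out_conv, out_linear, atol=1e-6)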
10 changes: 5 additions & 5 deletions tests/test_custom_models.py
@@ -974,12 +974,12 @@ def __init__(self):
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(10 * 5 * 5, 2)
self.sm = nn.LogSoftmax(dim=-1)
self.dtype = torch.float

def forward(self, X):
X = X.to(self.dtype)
X = X.reshape(-1, 1, 5, 5)
X = self.conv2d(X)
X = self.relu(X)
X = self.flat(X)
@@ -1094,10 +1094,10 @@ def from_pretrained(cls, model_id, torch_dtype=None):

if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

if model_id == "Conv2d1x1":
return ModelConv2D1x1().to(torch_dtype)

if model_id == "Conv1d":
return ModelConv1D().to(torch_dtype)
