feat(lokr, loha): Refine 1x1 Conv2d/Conv1d handling and test naming consistency

Clarify existing optimal handling for 1x1 Conv2d and Conv1d layers;
confirmed use_effective_conv2d is already correctly disabled in the layer.py files,
addressing reviewer feedback on potential dead code.

Standardize test case naming in tests/test_custom_models.py for consistency:
- Updated 'LoHa' to 'LOHA' and 'LoKr' to 'LOKR' (all caps).
- Aligned Conv1D/Conv1d naming with PyTorch conventions for clarity.
grewalsk committed May 27, 2025
commit 655ea628d9b950c45b12becbb4bb10d2b282f8e9
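The commit message's rationale, that a 1x1 convolution is a pointwise operation (a matrix multiply over channels), can be checked with a minimal sketch. This is not part of the PR and the shapes below are illustrative:

```python
import torch
import torch.nn as nn

# A 1x1 Conv2d applied to (N, C_in, H, W) input performs the same matrix multiply
# over the channel dimension at every spatial position as a Linear layer does.
x = torch.randn(2, 8, 5, 5)
conv = nn.Conv2d(8, 16, kernel_size=1, bias=False)

lin = nn.Linear(8, 16, bias=False)
with torch.no_grad():
    lin.weight.copy_(conv.weight.view(16, 8))  # (C_out, C_in, 1, 1) -> (C_out, C_in)

out_conv = conv(x)
out_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # channels-last, matmul, back

print(torch.allclose(out_conv, out_lin, atol=1e-6))  # True
```

This equivalence is why a flattened weight representation suffices and use_effective_conv2d can be disabled for the 1x1 case.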
18 changes: 9 additions & 9 deletions src/peft/tuners/loha/layer.py
@@ -135,10 +135,12 @@ def update_layer(
if isinstance(base_layer, nn.Linear):
    shape = tuple(base_layer.weight.shape)
elif isinstance(base_layer, nn.Conv2d):
    # Handle 1x1 convolutions differently
    # For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
    # Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
    # they can be more efficiently handled with the flattened weight representation,
    # similar to how Linear layers work. This optimization reduces computational cost
    # without affecting the mathematical equivalence of the operation.
    if base_layer.kernel_size == (1, 1):
        # For 1x1 convolutions, use a more direct shape without using effective_conv2d
        shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
        use_effective_conv2d = False
    else:
        use_effective_conv2d = use_effective_conv2d
@@ -150,10 +152,11 @@ def update_layer(
        base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
    )
elif isinstance(base_layer, nn.Conv1d):
    # Handle kernel_size=1 Conv1d differently
    # For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
    # as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
    # to a Linear layer applied across the channel dimension. Using flattened representation
    # avoids unnecessary reshaping and improves computational efficiency.
    if base_layer.kernel_size[0] == 1:
        # For kernel_size=1, use a more direct shape without using effective_conv2d
        shape = (base_layer.out_channels, base_layer.in_channels, base_layer.kernel_size[0])
        use_effective_conv2d = False
    else:
        use_effective_conv2d = use_effective_conv2d
@@ -203,9 +206,6 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:

base_layer = self.get_base_layer()

# Special optimization for 1x1 convolutions and kernel_size=1 Conv1d
is_1x1_conv2d = isinstance(base_layer, nn.Conv2d) and base_layer.kernel_size == (1, 1)
is_1_conv1d = isinstance(base_layer, nn.Conv1d) and base_layer.kernel_size[0] == 1

# Reshape to match base layer shape
weight = weight.reshape(base_layer.weight.shape)
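The same argument covers the Conv1d kernel_size=1 branch above: with no temporal context, the layer reduces to a Linear over the channel dimension. A corresponding sketch, again illustrative and not from the PR:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 10)  # (N, C_in, L)
conv = nn.Conv1d(8, 16, kernel_size=1, bias=False)

lin = nn.Linear(8, 16, bias=False)
with torch.no_grad():
    lin.weight.copy_(conv.weight.view(16, 8))  # (C_out, C_in, 1) -> (C_out, C_in)

out_conv = conv(x)
out_lin = lin(x.transpose(1, 2)).transpose(1, 2)  # apply the Linear per time step

print(torch.allclose(out_conv, out_lin, atol=1e-6))  # True
```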
16 changes: 9 additions & 7 deletions src/peft/tuners/lokr/layer.py
@@ -211,9 +211,12 @@ def update_layer(

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
# Handle 1x1 convolutions differently
# For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead.
# Since 1x1 convolutions are essentially pointwise operations (matrix multiplications),
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
if base_layer.kernel_size == (1, 1):
    # For 1x1 convolutions, always disable use_effective_conv2d
    use_effective_conv2d = False
else:
    use_effective_conv2d = use_effective_conv2d
@@ -227,9 +230,11 @@

use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
# Handle kernel_size=1 Conv1d differently
# For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
    # For kernel_size=1, always disable use_effective_conv2d
    use_effective_conv2d = False
else:
    use_effective_conv2d = use_effective_conv2d
@@ -272,9 +277,6 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
# Get base layer for reshaping
base_layer = self.get_base_layer()

# Special optimization for 1x1 convolutions and kernel_size=1 Conv1d
is_1x1_conv2d = isinstance(base_layer, nn.Conv2d) and base_layer.kernel_size == (1, 1)
is_1_conv1d = isinstance(base_layer, nn.Conv1d) and base_layer.kernel_size[0] == 1

# Regular reshape to match base layer shape
weight = weight.reshape(base_layer.weight.shape)
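For context on the use_w1/use_w2 lines surrounding these hunks, here is a small numeric illustration of the threshold logic. The shape pair and rank are made-up values, and reading shape as ((out_l, out_k), (in_m, in_n)) is an assumption, not something stated in the PR:

```python
# Hypothetical Kronecker-factorized dimensions and rank, purely for illustration.
shape = ((8, 8), (16, 16))
r = 4
decompose_both = True

# Same expressions as in the diff above, evaluated on the made-up numbers.
use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
use_w2 = r >= max(shape[0][1], shape[1][1]) / 2

print(use_w1, use_w2)  # False False: with these numbers, both factors fall below the threshold
```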
8 changes: 4 additions & 4 deletions tests/test_custom_models.py
@@ -230,8 +230,8 @@
("Conv2d 2 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 3 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
("Conv2d 4 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
("Conv1D LoHa", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LoHa", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"]}),
("Conv1d LOHA", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LOHA", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"]}),
# LoKr
("Vanilla MLP 1 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0"]}),
@@ -254,8 +254,8 @@
("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
("Conv1D LoKr", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LoKr", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}),
("Conv1d LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LOKR", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}),
(
    "Conv2d 5 LOKR",
    "Conv2d",
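For reference, a renamed case such as ("Conv1d LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}) corresponds roughly to the usage below. The toy model is an assumption rather than the test suite's actual Conv1d fixture, and it presumes a PEFT build that includes this PR's Conv1d support:

```python
import torch.nn as nn
from peft import LoKrConfig, get_peft_model

class ToyConv1dModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(8, 16, kernel_size=1)  # the kernel_size=1 case this PR clarifies
        self.lin0 = nn.Linear(16, 4)

    def forward(self, x):          # x: (N, 8, L)
        x = self.conv1d(x)         # (N, 16, L)
        x = x.mean(dim=-1)         # pool over the length dimension
        return self.lin0(x)        # (N, 4)

base_model = ToyConv1dModel()
config = LoKrConfig(target_modules=["conv1d"])
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```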