Address review comments
nemo committed Aug 27, 2025
commit c7ebc5837456423fc5119b498732ea5f767fd553
1 change: 1 addition & 0 deletions .gitignore
@@ -144,4 +144,5 @@ wandb
method_comparison/MetaMathQA/cancelled_results/
method_comparison/MetaMathQA/temporary_results/

# Coding agents
**/.claude/settings.local.json

Member:
Let's remove this.

Collaborator:
It's unrelated but probably not too uncommon anymore, so I'd vote to keep it and add a # Coding agents comment. WDYT?

Member:
Wouldn't this open the floodgates? I'd rather recommend folks to use a global .gitignore for their specific stuff.

Collaborator:
I've removed it for now. Let's see if we need to comment on this often in the future, then we can add it again.

8 changes: 6 additions & 2 deletions src/peft/tuners/loha/config.py
@@ -35,7 +35,8 @@ class LoHaConfig(LycorisConfig):
module_dropout (`float`):
The dropout probability for disabling LoHa modules during training.
use_effective_conv2d (`bool`):
Use parameter effective decomposition for Conv2d with ksize > 1 ("Proposition 3" from FedPara paper).
Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 ("Proposition 3" from FedPara
paper).
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
@@ -79,7 +80,10 @@ class LoHaConfig(LycorisConfig):
use_effective_conv2d: bool = field(
default=False,
metadata={
"help": 'Use parameter effective decomposition for Conv2d 3x3 with ksize > 1 ("Proposition 3" from FedPara paper)'
"help": (
"Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 "
'("Proposition 3" from FedPara paper)'
)
},
)
target_modules: Optional[Union[list[str], str]] = field(
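
Purely for illustration (not part of the diff), a minimal sketch of how this flag might be enabled for a Conv1d target via LoHaConfig; the module name "conv1d" and the base_model variable are placeholders, not names from this PR:

```python
from peft import LoHaConfig, get_peft_model

# Hypothetical usage sketch: request the effective (CP) decomposition for a
# Conv1d/Conv2d layer with kernel size > 1. "conv1d" is a placeholder name.
config = LoHaConfig(
    r=2,
    target_modules=["conv1d"],
    use_effective_conv2d=True,  # silently disabled when the kernel size is 1
)
peft_model = get_peft_model(base_model, config)  # base_model assumed to exist
```
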
14 changes: 4 additions & 10 deletions src/peft/tuners/loha/layer.py
@@ -54,11 +54,11 @@ def create_adapter_parameters(self, adapter_name: str, r: int, shape: tuple[int,
self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode
elif len(shape) == 3: # Conv1d
self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], 1))
self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode

self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], 1))
self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode
self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode
else: # Linear
@@ -140,10 +140,7 @@ def update_layer(
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
if base_layer.kernel_size == (1, 1):
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
if use_effective_conv2d:
shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size)
else:
@@ -156,10 +153,7 @@ def update_layer(
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size[0] != 1
if use_effective_conv2d:
shape = (base_layer.out_channels, base_layer.in_channels, base_layer.kernel_size[0])
else:
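
The comments in the hunks above rest on the equivalence between a kernel-size-1 convolution and a position-wise Linear layer over the channel dimension. A small standalone check of that claim (illustrative only, not part of the PR):

```python
import torch
import torch.nn as nn

# Illustrative check: a Conv1d with kernel_size=1 matches a Linear applied
# per position, which is why the effective decomposition is skipped for it.
conv = nn.Conv1d(8, 4, kernel_size=1)
lin = nn.Linear(8, 4)
lin.weight.data = conv.weight.data.squeeze(-1)  # (4, 8, 1) -> (4, 8)
lin.bias.data = conv.bias.data

x = torch.randn(2, 8, 10)  # (batch, channels, length)
out_conv = conv(x)                                # (2, 4, 10)
out_lin = lin(x.transpose(1, 2)).transpose(1, 2)  # same result via a Linear
print(torch.allclose(out_conv, out_lin, atol=1e-6))  # True
```
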
8 changes: 6 additions & 2 deletions src/peft/tuners/lokr/config.py
@@ -35,7 +35,8 @@ class LoKrConfig(LycorisConfig):
module_dropout (`float`):
The dropout probability for disabling LoKr modules during training.
use_effective_conv2d (`bool`):
Use parameter effective decomposition for Conv2d with ksize > 1 ("Proposition 3" from FedPara paper).
Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 ("Proposition 3" from FedPara
paper).
decompose_both (`bool`):
Perform rank decomposition of left kronecker product matrix.
decompose_factor (`int`):
@@ -85,7 +86,10 @@ class LoKrConfig(LycorisConfig):
use_effective_conv2d: bool = field(
default=False,
metadata={
"help": 'Use parameter effective decomposition for Conv2d 3x3 with ksize > 1 ("Proposition 3" from FedPara paper)'
"help": (
"Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 "
'("Proposition 3" from FedPara paper)'
)
},
)
decompose_both: bool = field(
14 changes: 5 additions & 9 deletions src/peft/tuners/lokr/layer.py
@@ -90,7 +90,9 @@ def create_adapter_parameters(
if use_w2:
self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], shape[2]))
elif use_effective_conv2d: # Even for Conv1d, use the effective parameter for kernel dimension
self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2]))
# We pass (r, r, kernel_size, 1) in order to be compatible with the 2d assumptions made
# in make_weight_cp (only relevant for the effective conv2d case).
self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], 1))
self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1])) # b, 1-mode
self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1])) # d, 2-mode
else:
@@ -216,10 +218,7 @@ def update_layer(
# they can be more efficiently handled with the flattened weight representation,
# similar to how Linear layers work. This optimization reduces computational cost
# without affecting the mathematical equivalence of the operation.
if base_layer.kernel_size == (1, 1):
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
elif isinstance(base_layer, nn.Conv1d):
in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
k_size = (base_layer.kernel_size[0],) # Convert to a tuple with single element
@@ -234,10 +233,7 @@ def update_layer(
# as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent
# to a Linear layer applied across the channel dimension. Using flattened representation
# avoids unnecessary reshaping and improves computational efficiency.
if base_layer.kernel_size[0] == 1:
use_effective_conv2d = False
else:
use_effective_conv2d = use_effective_conv2d
use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size[0] != 1
else:
raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")

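
The (r, r, kernel_size, 1) parameter shape works because a Conv1d can be viewed as a Conv2d whose second spatial dimension is 1, which is what keeps the 2d assumptions in make_weight_cp intact. A standalone sketch of that equivalence (illustrative, not PEFT-specific):

```python
import torch
import torch.nn.functional as F

# Illustrative: a 1d convolution is a 2d convolution with a (k, 1) kernel,
# so a 4d (r, r, k, 1) parameter is compatible with code written for 2d.
x = torch.randn(2, 8, 10)  # (batch, in_channels, length)
w = torch.randn(4, 8, 3)   # (out_channels, in_channels, kernel_size)

out_1d = F.conv1d(x, w)                                          # (2, 4, 8)
out_2d = F.conv2d(x.unsqueeze(-1), w.unsqueeze(-1)).squeeze(-1)  # same values
print(torch.allclose(out_1d, out_2d, atol=1e-6))  # True
```
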
54 changes: 53 additions & 1 deletion tests/test_custom_models.py
@@ -268,6 +268,20 @@
("Conv2d 3 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
("Conv2d 4 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
("Conv1d LOHA", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}),
("Conv1d LOHA 1", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}),
("Conv1d LOHA 2", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"], "r": 2}),
(
"Conv1d LOHA 3",
"Conv1dBigger",
LoHaConfig,
{"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": True},
),
(
"Conv1d LOHA 4",
"Conv1dBigger",
LoHaConfig,
{"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False},
),
("Conv2d 1x1 LOHA", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"]}),
# LoKr
("Vanilla MLP 1 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0"}),
@@ -287,11 +301,24 @@
),
("Vanilla MLP 7 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "rank_dropout": 0.5}),
("Vanilla MLP 8 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "decompose_both": True, "r": 1, "alpha": 1}),
("Conv1d LOKR 1", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
("Conv1d LOKR 2", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"], "r": 2}),
(
"Conv1d LOKR 3",
"Conv1dBigger",
LoKrConfig,
{"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": True},
),
(
"Conv1d LOKR 4",
"Conv1dBigger",
LoKrConfig,
{"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False},
),
("Conv2d 1 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"]}),
("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
("Conv1d LOKR", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}),
("Conv2d 1x1 LOKR", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}),
(
"Conv2d 5 LOKR",
@@ -1193,6 +1220,28 @@ def forward(self, X):
return X


class ModelConv1DBigger(nn.Module):
def __init__(self):
super().__init__()
self.conv1d = nn.Conv1d(64, 16, 2)
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(144, 2)
self.sm = nn.LogSoftmax(dim=-1)
self.dtype = torch.float

def forward(self, X):
X = X.to(self.dtype)
X = X.reshape(-1, 1, 10)
X = torch.concat([X] * 64, dim=1)
X = self.conv1d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X


class ModelConv2D(nn.Module):
def __init__(self, bias=True):
super().__init__()
@@ -1426,6 +1475,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
if model_id == "Conv1d":
return ModelConv1D().to(torch_dtype)

if model_id == "Conv1dBigger":
return ModelConv1DBigger().to(torch_dtype)

if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

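
As a side note, the 144 input features of lin0 in ModelConv1DBigger follow from the shapes in its forward pass: Conv1d(64, 16, kernel_size=2) on a length-10 sequence leaves 16 channels x (10 - 2 + 1) = 144 flattened features. A rough shape walk-through, assuming a (batch, 10) input as in the other custom test models:

```python
import torch

# Hypothetical shape check mirroring ModelConv1DBigger.forward; the
# (batch, 10) input shape is an assumption taken from the sibling models.
X = torch.randn(5, 10)
x = X.reshape(-1, 1, 10)           # (5, 1, 10)
x = torch.concat([x] * 64, dim=1)  # (5, 64, 10)
x = torch.nn.Conv1d(64, 16, 2)(x)  # (5, 16, 9) since 10 - 2 + 1 = 9
x = torch.nn.Flatten()(x)          # (5, 144) -> matches nn.Linear(144, 2)
print(x.shape)
```
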