Commit d936478
ENH Make OFT faster and more memory efficient (#2575)
Make OFT faster and more memory efficient. This new version of OFT is not backward compatible with older checkpoints, and checkpoints created with it cannot be loaded by older PEFT versions. To load older checkpoints, downgrade PEFT to 0.15.2 or lower.
1 parent e34852f commit d936478

File tree

18 files changed: +2049 −316 lines

docs/source/conceptual_guides/oft.md

63 additions, 5 deletions
@@ -16,9 +16,9 @@ rendered properly in your Markdown viewer.

# Orthogonal Finetuning (OFT and BOFT)

-This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280) and [BOFT](https://huggingface.co/papers/2311.06243), a parameter-efficient fine-tuning technique that utilizes orthogonal matrix to multiplicatively transform the pretrained weight matrices.
+This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280), [OFTv2](https://www.arxiv.org/abs/2506.19847) and [BOFT](https://huggingface.co/papers/2311.06243), a family of parameter-efficient fine-tuning techniques that use orthogonal matrices to multiplicatively transform the pretrained weight matrices.

-To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied to the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesnt receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied togethor.
+To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix that is multiplied with the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive any further adjustments. To produce the final result, the original and the adapted weights are multiplied together.

Orthogonal Butterfly (BOFT) generalizes OFT with Butterfly factorization and further improves its parameter efficiency and finetuning flexibility. In short, OFT can be viewed as a special case of BOFT. Different from LoRA that uses additive low-rank weight updates, BOFT uses multiplicative orthogonal weight updates. The comparison is shown below.
@@ -58,13 +58,25 @@ As with other methods supported by PEFT, to fine-tune a model using OFT or BOFT,

4. Train the `PeftModel` as you normally would train the base model.


+### OFT-specific parameters
+
+`OFTConfig` allows you to control how OFT is applied to the base model through the following parameters:
+
+- `r`: the OFT rank, i.e. the number of OFT blocks per injected layer. A **bigger** `r` results in sparser update matrices with **fewer** trainable parameters (see the configuration sketch after this list). **Note**: you can only specify either `r` or `oft_block_size`, but not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, you specify one of the two and the other is inferred. The default is `r = 0`; setting `oft_block_size` instead is recommended for clarity.
+- `oft_block_size`: the OFT block size across different layers. A **bigger** `oft_block_size` results in denser update matrices with **more** trainable parameters. **Note**: choose `oft_block_size` so that it divides the layer's input dimension (`in_features`), e.g., 4, 8, 16. You can only specify either `r` or `oft_block_size`, but not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, you specify one of the two and the other is inferred. The default is `oft_block_size = 32`.
+- `use_cayley_neumann`: whether to use the Cayley-Neumann parameterization (efficient but approximate) or the vanilla Cayley parameterization (exact but computationally expensive because of the matrix inverse). We recommend setting it to `True` for better efficiency, but performance may be slightly worse because of the approximation error. Please test both settings (`True` and `False`) depending on your needs. The default is `False`.
+- `module_dropout`: the multiplicative dropout probability, implemented by setting OFT blocks to the identity during training, similar to the dropout layer in LoRA.
+- `bias`: specifies whether the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"oft_only"`.
+- `target_modules`: the modules (for example, attention blocks) to inject the OFT matrices into.
+- `modules_to_save`: the list of modules, apart from the OFT matrices, to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task.
+
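To make the `r` / `oft_block_size` relationship concrete, here is a minimal configuration sketch; the 4096 hidden size and the module names are assumed examples, not part of this commit:

```py
from peft import OFTConfig

# Minimal sketch: for a Linear layer with in_features = 4096 (assumed), setting
# oft_block_size=32 implies r = 4096 / 32 = 128 orthogonal blocks per layer,
# because r * oft_block_size must equal the layer dimension.
config = OFTConfig(
    oft_block_size=32,                    # set this OR `r`, not both
    use_cayley_neumann=True,              # faster, approximate Cayley parameterization
    module_dropout=0.0,
    bias="none",
    target_modules=["q_proj", "v_proj"],  # assumed attention projection names
    task_type="CAUSAL_LM",
)
```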
### BOFT-specific parameters

-`BOFTConfig` allows you to control how OFT/BOFT is applied to the base model through the following parameters:
+`BOFTConfig` allows you to control how BOFT is applied to the base model through the following parameters:

-- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. Smaller block size results in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_size` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
+- `boft_block_size`: the BOFT matrix block size across different layers, expressed as an `int`. A **bigger** `boft_block_size` results in denser update matrices with **more** trainable parameters (see the sketch after this list). **Note**: choose `boft_block_size` so that it divides most layers' input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
specify either `boft_block_size` or `boft_block_num`, but not both simultaneously or leaving both to 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
-- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. Fewer blocks result in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_num` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
+- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed as an `int`. A **bigger** `boft_block_num` results in sparser update matrices with **fewer** trainable parameters. **Note**: choose `boft_block_num` so that it divides most layers' input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
specify either `boft_block_size` or `boft_block_num`, but not both simultaneously or leaving both to 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**, for `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT, for `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks become half.
- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"boft_only"`.
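A matching sketch for the BOFT parameters above; the 4096 input dimension and the module names are assumed examples:

```py
from peft import BOFTConfig

# Minimal sketch: for in_features = 4096 (assumed), boft_block_size=64 implies
# boft_block_num = 4096 / 64 = 64, since boft_block_size * boft_block_num must
# equal the layer's input dimension. boft_n_butterfly_factor=2 doubles the
# effective block size and halves the number of blocks.
config = BOFTConfig(
    boft_block_size=64,                   # set this OR `boft_block_num`, not both
    boft_n_butterfly_factor=2,
    bias="none",
    target_modules=["q_proj", "v_proj"],  # assumed module names
)
```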
@@ -74,6 +86,52 @@ specify either `boft_block_size` or `boft_block_num`, but not both simultaneously



+## OFT Example Usage
+
+For using OFT for quantized fine-tuning with [TRL](https://github.com/huggingface/trl) for `SFT`, `PPO`, or `DPO` fine-tuning, follow this outline:
+
+```py
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from trl import SFTTrainer
+from peft import OFTConfig
+
+# `use_quantization`, `ds`, `training_arguments` and `collator` are assumed to be defined elsewhere
+if use_quantization:
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_storage=torch.bfloat16,
+    )
+
+model = AutoModelForCausalLM.from_pretrained(
+    "model_name",
+    quantization_config=bnb_config,
+)
+tokenizer = AutoTokenizer.from_pretrained("model_name")
+
+# Configure OFT
+peft_config = OFTConfig(
+    oft_block_size=32,
+    use_cayley_neumann=True,
+    target_modules="all-linear",
+    bias="none",
+    task_type="CAUSAL_LM",
+)
+
+trainer = SFTTrainer(
+    model=model,
+    train_dataset=ds['train'],
+    peft_config=peft_config,
+    tokenizer=tokenizer,
+    args=training_arguments,
+    data_collator=collator,
+)
+
+trainer.train()
+```
+
+
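The example above sets `use_cayley_neumann=True`. To make the trade-off described under `use_cayley_neumann` concrete, here is a small numerical sketch of the underlying idea (an illustration only, not the PEFT implementation): the exact Cayley transform needs a matrix inverse, while the Neumann variant replaces that inverse with a short truncated series.

```py
import torch

torch.manual_seed(0)
A = 0.01 * torch.randn(32, 32)
Q = A - A.T                       # skew-symmetric generator
I = torch.eye(32)

# Exact Cayley transform: exactly orthogonal, but needs a matrix inverse.
R_exact = (I + Q) @ torch.linalg.inv(I - Q)

# Neumann approximation: (I - Q)^-1 ~ I + Q + Q^2 + ... + Q^4, cheaper but
# only approximately orthogonal.
neumann_inv = sum(torch.linalg.matrix_power(Q, k) for k in range(5))
R_approx = (I + Q) @ neumann_inv

print(torch.dist(R_exact, R_approx))         # small when Q has a small norm
print(torch.dist(R_approx.T @ R_approx, I))  # deviation from exact orthogonality
```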
## BOFT Example Usage

For an example of the BOFT method application to various downstream tasks, please refer to the following guides:

src/peft/tuners/oft/__init__.py

29 additions, 1 deletion
@@ -12,13 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
from peft.utils import register_peft_method

from .config import OFTConfig
+from .gptq import GPTQOFTLinear
from .layer import Conv2d, Linear, OFTLayer
from .model import OFTModel


-__all__ = ["Conv2d", "Linear", "OFTConfig", "OFTLayer", "OFTModel"]
+__all__ = [
+    "Conv2d",
+    "GPTQOFTLinear",
+    "Linear",
+    "OFTConfig",
+    "OFTLayer",
+    "OFTModel",
+]

register_peft_method(name="oft", config_cls=OFTConfig, model_cls=OFTModel)
+
+
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt
+
+        return Linear8bitLt
+
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit
+
+        return Linear4bit
+
+    if (name == "EetqOFTLinear") and is_eetq_available():
+        from .eetq import EetqOFTLinear
+
+        return EetqOFTLinear
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
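A note on the module-level `__getattr__` added above: the optional quantized OFT layer classes are imported lazily, only when first accessed, so importing `peft.tuners.oft` stays cheap when the corresponding backends are not installed. A small usage sketch (assumes `bitsandbytes` is installed; without it the same access raises `AttributeError`):

```py
from peft.tuners import oft

# Attribute access triggers oft.__getattr__("Linear4bit"), which lazily imports
# the bitsandbytes-backed class from peft.tuners.oft.bnb (PEP 562 module __getattr__).
Linear4bit = oft.Linear4bit
print(Linear4bit)
```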

src/peft/tuners/oft/aqlm.py (new file)

105 additions, 0 deletions
@@ -0,0 +1,105 @@ (new file; shown without leading "+" markers)
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch

from peft.import_utils import is_aqlm_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_aqlm_available():
    from aqlm import QuantizedLinear


class AqlmOFTLinear(torch.nn.Module, OFTLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        oft_block_size: int = 32,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        coft: bool = False,
        eps: float = 6e-5,
        block_share: bool = False,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        use_cayley_neumann: bool = False,
        num_cayley_neumann_terms: int = 5,
        **kwargs,
    ):
        super().__init__()
        OFTLayer.__init__(self, base_layer)

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            oft_block_size=oft_block_size,
            module_dropout=module_dropout,
            init_weights=init_weights,
            coft=coft,
            eps=eps,
            block_share=block_share,
            use_cayley_neumann=use_cayley_neumann,
            num_cayley_neumann_terms=num_cayley_neumann_terms,
        )

    def forward(self, x: torch.Tensor):
        # note: logic differs from default Linear because merging is not supported
        if self.disable_adapters:
            return self.base_layer(x)

        for active_adapter in self.active_adapters:
            if active_adapter not in self.oft_R.keys():
                continue
            oft_R = self.oft_R[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = x.dtype
                x = self._cast_input_dtype(x, oft_R.weight.dtype)

            x = oft_R(x)

        result = self.base_layer(x)
        if requires_conversion:
            result = result.to(expected_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "oft." + rep


def dispatch_aqlm(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
        new_module = AqlmOFTLinear(target, adapter_name, **kwargs)
        target.qweight = target_base_layer.codes

    return new_module
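The `forward` above rotates the input (`x = oft_R(x)`) and only then calls the frozen quantized base layer, since merging the orthogonal matrix into AQLM weights is not supported. A small, self-contained sanity check in plain PyTorch (no quantization; transpose conventions may differ from the actual implementation) showing why rotating the input is equivalent to a multiplicative weight update:

```py
import torch

torch.manual_seed(0)
W = torch.randn(16, 16)               # stands in for the frozen base weight
A = torch.randn(16, 16)
R = torch.linalg.matrix_exp(A - A.T)  # matrix exponential of a skew-symmetric matrix is orthogonal
x = torch.randn(4, 16)

y_rotate_input = (x @ R.T) @ W.T      # rotate the input, then apply the frozen layer
y_adapt_weight = x @ (W @ R).T        # multiplicatively adapted weight W @ R

print(torch.allclose(y_rotate_input, y_adapt_weight, atol=1e-5))  # True
```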

src/peft/tuners/oft/awq.py (new file)

119 additions, 0 deletions
@@ -0,0 +1,119 @@ (new file; shown without leading "+" markers)
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata as importlib_metadata
from typing import Any, Optional

import packaging.version
import torch

from peft.import_utils import is_auto_awq_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer


class AwqOFTLinear(torch.nn.Module, OFTLayer):
    def __init__(
        self,
        base_layer,
        adapter_name,
        r: int = 0,
        oft_block_size: int = 32,
        module_dropout: float = 0.0,
        coft: bool = False,
        eps: float = 6e-5,
        block_share: bool = False,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        init_weights: bool = True,
        use_cayley_neumann: bool = False,
        num_cayley_neumann_terms: int = 5,
        **kwargs,
    ):
        super().__init__()
        OFTLayer.__init__(self, base_layer)

        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
        # for backwards compatibility
        self.quant_linear_module = base_layer

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            oft_block_size=oft_block_size,
            module_dropout=module_dropout,
            coft=coft,
            eps=eps,
            block_share=block_share,
            init_weights=init_weights,
            use_cayley_neumann=use_cayley_neumann,
            num_cayley_neumann_terms=num_cayley_neumann_terms,
        )

    def forward(self, x: torch.Tensor):
        if self.disable_adapters:
            result = self.quant_linear_module(x)
            return result

        for active_adapter in self.active_adapters:
            if active_adapter not in self.oft_R.keys():
                continue
            oft_R = self.oft_R[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = x.dtype
                x = self._cast_input_dtype(x, oft_R.weight.dtype)

            x = oft_R(x)
            if requires_conversion:
                x = x.to(expected_dtype)

        result = self.quant_linear_module(x)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "oft." + rep


def dispatch_awq(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_auto_awq_available():
        from awq.modules.linear import WQLinear_GEMM

        if isinstance(target_base_layer, WQLinear_GEMM):
            # Raise the error only at the dispatch level
            AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
            version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))

            if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
                raise ImportError(
                    f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
                    f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
                )

            new_module = AwqOFTLinear(target, adapter_name, **kwargs)
            target.qweight = target_base_layer.qweight

    return new_module
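`dispatch_awq` is meant to be called during module replacement when a targeted module wraps an AWQ `WQLinear_GEMM` layer. A hedged end-to-end sketch of how a user would reach this code path (the checkpoint name is a placeholder, and `autoawq >= 0.2.0` must be installed):

```py
import torch
from transformers import AutoModelForCausalLM
from peft import OFTConfig, get_peft_model

# Placeholder name for an AWQ-quantized checkpoint on the Hub.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model-awq",
    torch_dtype=torch.float16,
    device_map="auto",
)

config = OFTConfig(
    oft_block_size=32,
    target_modules=["q_proj", "v_proj"],  # assumed module names
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)  # AWQ linears get wrapped by AwqOFTLinear via dispatch_awq
model.print_trainable_parameters()
```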

0 commit comments