In this commit, we:
1. add a unit test: tests/test_lorafa.py
2. add docs in docs/source/developer_guides/lora.md -> #optimizers
3. add a working example for fine-tuning meta-llama/Meta-Llama-3-8B on meta-math/MetaMathQA-40K with the LoRA-FA optimizer
4. delete kwargs
5. apply ruff style fixes
AaronZLT committed Apr 3, 2025
commit 496945178435b055a2d433e5acae812f93660504
35 changes: 34 additions & 1 deletion docs/source/developer_guides/lora.md
@@ -271,7 +271,40 @@ The same logic applies to `alpha_pattern`. If you're in doubt, don't try to get

## Optimizers

LoRA training can optionally include special purpose optimizers. Currently the only such optimizer is LoRA+.
LoRA training can optionally include special purpose optimizers. Currently PEFT supports LoRA-FA and LoRA+.

### LoRA-FA Optimizer

LoRA training can be more effective and efficient using LoRA-FA, as described in [LoRA-FA](https://arxiv.org/abs/2308.03303). LoRA-FA reduces activation memory consumption by fixing the matrix A and only tuning the matrix B. During training, the gradient of B is optimized to approximate the full parameter fine-tuning gradient.

```py
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoModelForCausalLM, Trainer, get_cosine_schedule_with_warmup

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

config = LoraConfig(...)
model = get_peft_model(base_model, config)

optimizer = create_lorafa_optimizer(
    model=model,
    r=16,
    lora_alpha=32,
    learning_rate=7e-5,
)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

trainer = Trainer(
    ...,
    optimizers=(optimizer, scheduler),
)
```
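
If you want to convince yourself that A really stays fixed, a quick sanity check like the following can help. This is a sketch, not part of the PEFT API; it assumes the default `lora_A`/`lora_B` parameter naming and that the optimizer leaves A untouched, as described above.

```py
import torch

# snapshot the A matrices before training
a_before = {n: p.detach().clone() for n, p in model.named_parameters() if "lora_A" in n}

trainer.train()  # or a single manual optimizer step

# the frozen A matrices should be unchanged, while the B matrices have moved
for n, p in model.named_parameters():
    if "lora_A" in n:
        assert torch.equal(a_before[n], p.detach()), f"{n} changed unexpectedly"
```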

### LoRA+ optimized LoRA

93 changes: 93 additions & 0 deletions examples/lorafa_finetune/README.md
@@ -0,0 +1,93 @@
# LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning

## Introduction

[LoRA-FA](https://arxiv.org/abs/2308.03303) is a novel parameter-efficient fine-tuning method that freezes the projection-down layer (matrix A) during LoRA training, which reduces GPU memory consumption by eliminating the need to store the input activations (X). Furthermore, LoRA-FA narrows the gap between the weight update produced by low-rank fine-tuning and the one produced by full fine-tuning. In short, LoRA-FA reduces memory consumption and achieves superior performance compared with vanilla LoRA.
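
The memory saving comes from the backward pass: to update B you only need the rank-`r` projection `A @ x`, whereas updating A would require keeping the full input activation `x`. The toy snippet below illustrates this; it is a sketch for intuition only, not the PEFT implementation.

```python
import torch

d, r, n = 4096, 16, 8                      # hidden size, LoRA rank, number of token vectors
x = torch.randn(n, d)                      # input activations
A = torch.randn(r, d)                      # frozen projection-down matrix (no grad)
B = torch.zeros(d, r, requires_grad=True)  # trainable projection-up matrix

h = (x @ A.T) @ B.T                        # LoRA update path: B(Ax)
h.sum().backward()

# Only B gets a gradient; autograd only had to keep the small (n, r) tensor x @ A.T,
# not the full (n, d) activation x, because A is not being trained.
print(B.grad.shape)  # torch.Size([4096, 16])
```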

## Quick start

```python
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

lora_rank = 16
lora_alpha = 32

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    bias="none",
)
peft_model = get_peft_model(model, lora_config)
optimizer = create_lorafa_optimizer(
    model=peft_model,
    r=lora_rank,
    lora_alpha=lora_alpha,
    learning_rate=7e-5,
)
# A learning-rate scheduler is optional; we recommend get_cosine_schedule_with_warmup
# from transformers for better model performance.
scheduler = None

# Tokenize the raw text and let the collator build the language-modeling labels.
dataset = dataset.map(
    lambda examples: tokenizer(examples["text"], truncation=True, max_length=2048),
    batched=True,
    remove_columns=dataset.column_names,
)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=peft_model,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler),
)
trainer.train()
peft_model.save_pretrained("lorafa-llama-3-8b-inst")
```

Compared to a standard LoRA training script, the only LoRA-FA-specific change is creating the optimizer and passing it to the trainer. Do not forget `from peft.optimizers import create_lorafa_optimizer`!

This directory also contains a small example of fine-tuning with the LoRA-FA optimizer. Run the fine-tuning script with:

```bash
accelerate launch examples/lorafa_finetune/lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
```

By default, this loads the model, wraps it with the LoRA config through PEFT, and trains it with the LoRA-FA optimizer. `accelerate launch` automatically configures single-GPU or multi-GPU training for you.
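
For example, on a multi-GPU machine you can pin the number of processes explicitly (a sketch assuming a default `accelerate` setup; `--multi_gpu` and `--num_processes` are standard `accelerate launch` flags):

```bash
accelerate launch --multi_gpu --num_processes 2 examples/lorafa_finetune/lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
```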

LoRA-FA also supports quantization. To use bitsandbytes 4-bit NF4 quantization, run:

```bash
accelerate launch examples/lorafa_finetune/lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa --quantize
```

## Use the model
You can load the fine-tuned adapter and use it like any other 🤗 PEFT model.
```python
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("lorafa-llama-3-8b-inst")
```

## Best practices for fine-tuning Llama on MetaMath with LoRA-FA

When fine-tuning with LoRA-FA, we recommend using a larger LoRA rank, such as 64 or, at most, 128. LoRA-FA can reach 57.3% on GSM8K just by fine-tuning Llama-2-7b-chat-hf on meta-math/MetaMathQA-40K for 3 epochs! For more best practices, check the [LoRA-FA examples](https://github.com/AaronZLT/lorafa).
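
A higher-rank setup would look like the sketch below, reusing `model` from the quick start; the values are illustrative, following the alpha = 2 × rank pattern used above.

```python
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer

lora_rank = 128   # larger rank than the default 16
lora_alpha = 256  # assumption: keep alpha = 2 * rank, as in the quick start

lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, bias="none")
peft_model = get_peft_model(model, lora_config)
optimizer = create_lorafa_optimizer(model=peft_model, r=lora_rank, lora_alpha=lora_alpha, learning_rate=7e-5)
```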

## LoRA-FA vs. LoRA

Despite its advantages, LoRA-FA remains inherently constrained by its low-rank approximation and by potential catastrophic forgetting. In addition, because LoRA-FA has fewer trainable parameters than LoRA, it may converge more slowly and typically requires a larger LoRA rank and a more fine-grained hyperparameter search (mainly over the learning rate). Addressing these limitations, particularly approximation accuracy and forgetting, is a promising direction for future work.
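
You can see the reduced trainable-parameter count directly. This is a sketch that relies on PEFT's `lora_A`/`lora_B` naming convention and the `peft_model` from the quick start.

```python
lora_a = sum(p.numel() for n, p in peft_model.named_parameters() if "lora_A" in n)
lora_b = sum(p.numel() for n, p in peft_model.named_parameters() if "lora_B" in n)
print(f"LoRA trains A + B = {lora_a + lora_b} parameters; LoRA-FA only updates B = {lora_b}")
```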

## Citation
```
@misc{zhang2023lorafamemoryefficientlowrankadaptation,
title={LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning},
author={Longteng Zhang and Lin Zhang and Shaohuai Shi and Xiaowen Chu and Bo Li},
year={2023},
eprint={2308.03303},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2308.03303},
}
```
195 changes: 195 additions & 0 deletions examples/lorafa_finetune/lorafa_finetuning.py
@@ -0,0 +1,195 @@
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.optimizers import create_lorafa_optimizer


def train_model(
    base_model_name_or_path: str,
    dataset_name_or_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    learning_rate: float,
    cutoff_len: int,
    quantize: bool,
    eval_step: int,
    save_step: int,
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    lora_target_modules: str,
    lorafa: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    print("In this example we only support GPU training.")

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    # QLoRA-FA (quantized LoRA-FA): load the base model in 4-bit if quantization is requested
    if quantize:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=(
                    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
                ),
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
            torch_dtype=torch.bfloat16,
        )
        # setup for quantized training
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, torch_dtype=torch.bfloat16)
    # LoRA config for the PEFT model
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=(
            lora_target_modules.split(",")
            if lora_target_modules
            else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        ),
        lora_dropout=lora_dropout,
        bias="none",
    )

    # get the peft model with LoRA config
    model = get_peft_model(model, lora_config)

    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_dataset(dataset_name_or_path)

    def tokenize_function(examples):
        inputs = tokenizer(examples["query"], padding="max_length", truncation=True, max_length=cutoff_len)
        outputs = tokenizer(examples["response"], padding="max_length", truncation=True, max_length=cutoff_len)
        inputs["labels"] = outputs["input_ids"].copy()
        return inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
    dataset = tokenized_datasets["train"].train_test_split(test_size=0.1, shuffle=True, seed=42)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        gradient_accumulation_steps=1,
        fp16=True,
        learning_rate=learning_rate,
    )

    # Clear CUDA cache to free memory
    torch.cuda.empty_cache()

    # Initialize the LoRA-FA optimizer.
    # After this, all adapter A matrices are fixed and only the adapter B matrices are trainable.
    if lorafa:
        optimizer = create_lorafa_optimizer(
            model=model, r=lora_rank, lora_alpha=lora_alpha, learning_rate=learning_rate
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            optimizers=(optimizer, None),
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )

    # Start model training
    trainer.train()

    # Save the model and tokenizer locally
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune Meta-Llama-3-8B-Instruct with LoRA-FA and PEFT")
    parser.add_argument(
        "--base_model_name_or_path",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Base model name or path",
    )
    parser.add_argument(
        "--dataset_name_or_path", type=str, default="meta-math/MetaMathQA-40K", help="Dataset name or path"
    )
    parser.add_argument(
        "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
    parser.add_argument("--quantize", action="store_true", help="Use quantization")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    parser.add_argument(
        "--lora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA"
    )
    parser.add_argument("--lorafa", action="store_true", help="Use LoRA-FA Optimizer")
    args = parser.parse_args()

    train_model(
        base_model_name_or_path=args.base_model_name_or_path,
        dataset_name_or_path=args.dataset_name_or_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        quantize=args.quantize,
        eval_step=args.eval_step,
        save_step=args.save_step,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lorafa=args.lorafa,
    )
1 change: 1 addition & 0 deletions src/peft/optimizers/__init__.py
@@ -15,4 +15,5 @@
from .lorafa import create_lorafa_optimizer
from .loraplus import create_loraplus_optimizer


__all__ = ["create_lorafa_optimizer", "create_loraplus_optimizer"]