Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix conflict
  • Loading branch information
AaronZLT committed Apr 7, 2025
commit 113152f9f718a1bff889315b48b663c92b6863a8
36 changes: 31 additions & 5 deletions examples/lorafa_finetune/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,48 @@ peft_model.save_pretrained("lorafa-llama-3-8b-inst")

The only change in your code is to pass the LoRA-FA optimizer to the trainer (if training with trainer). Do not forget `from peft.optimizers import create_lorafa_optimizer`!

In this dir, we also provide you a simple example for fine-tuning with LoRA-FA optimizer. Run the finetuning script simply by running:
## Example

In this dir, we also provide you a simple example for fine-tuning with LoRA-FA optimizer.

### Run on CPU, single-GPU or multi-GPU

This 👇 by default will load the model in peft set up with LoRA config, and train the model with LoRA-FA optimizer.

0. CPU

You can simply run LoRA-FA as below:

```bash
python lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
```

1. Single-GPU

Run the finetuning script on 1 GPU:

```bash
CUDA_VISIBLE_DEVICES=0 python lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
```

2. Multi-GPU

LoRA-FA can also be run on multi-GPU, with 🤗 Accelerate:

```bash
accelerate launch examples/lorafa_finetuning/lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
```

This 👆🏻 by default will load the model in peft set up with LoRA config, and train the model with LoRA-FA optimizer. The `accelerate launch` will automatically configure single-GPU or multi-GPU for you.
The `accelerate launch` will automatically configure multi-GPU for you. You can also utilize `accelerate launch` in single-GPU scenario.

## Use the model from 🤗
### Use the model from 🤗
You can load and use the model as any other 🤗 models.
```python
from transformers import AutoModel
model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
```

### Best practice in fine-tuning Llama on Metamath using LoRA-FA: the hyper-params
## Best practice in fine-tuning Llama using LoRA-FA: the hyper-params

Sometimes, achieving optimal LoRA fine-tuning can be challenging due to the larger number of hyperparameters to consider compared to full fine-tuning. For instance, not only do we need to adjust the commonly used learning rate, but the ideal LoRA rank may also vary depending on the specific model and task. Additionally, there are other factors to consider, such as LoRA alpha and sequence length. To assist with this, we have created a repository of reproducible best practices in the [LoRA-FA examples](https://github.com/AaronZLT/lorafa) for reference. This resource showcases the optimal LoRA-FA fine-tuning hyperparameters for different models across various datasets. By doing so, we significantly reduce the time and effort spent on hyperparameter tuning, and it may also provide insights for tuning other training hyperparameters. We encourage you to experiment and fine-tune on your own downstream tasks as well.

Expand Down
44 changes: 24 additions & 20 deletions examples/lorafa_finetune/lorafa_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,38 +48,43 @@ def train_model(
):
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("In this example we only spport GPU training.")
compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
device_map = "cuda" if torch.cuda.is_available() else None

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

# QLoRA-FA (quantized LoRA-FA): IF YOU WANNA QUANTIZE THE MODEL
# load model
if quantize:
model = AutoModelForCausalLM.from_pretrained(
base_model_name_or_path,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=(
torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
),
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
),
torch_dtype=torch.bfloat16,
torch_dtype=compute_dtype,
device_map=device_map,
)
# setup for quantized training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
else:
model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, torch_dtype=compute_dtype, device_map=device_map)

# LoRA config for the PEFT model
if lora_target_modules is not None:
if lora_target_modules == "all-linear":
target_modules = "all-linear"
else:
target_modules = lora_target_modules.split(",")
else:
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=(
lora_target_modules.split(",")
if lora_target_modules
else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
),
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
)
Expand Down Expand Up @@ -120,13 +125,11 @@ def tokenize_function(examples):
save_steps=save_step,
save_total_limit=2,
gradient_accumulation_steps=1,
fp16=True,
lr=lr,
bf16 = True if compute_dtype == torch.bfloat16 else False,
fp16=True if compute_dtype == torch.float16 else False,
learning_rate=lr,
)

# Clear CUDA cache to free memory
torch.cuda.empty_cache()

# Here we initialize the LoRA-FA Optimizer
# After this, all adapter A will be fixed, only adapter B will be trainable
if lorafa:
Expand Down Expand Up @@ -175,9 +178,9 @@ def tokenize_function(examples):
"--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
parser.add_argument("--lr", type=float, default=7e-5, help="Learning rate")
parser.add_argument("--cutoff_len", type=int, default=1024, help="Cutoff length for tokenization")
parser.add_argument("--quantize", action="store_true", help="Use quantization")
parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
Expand All @@ -188,6 +191,7 @@ def tokenize_function(examples):
"--lora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA"
)
parser.add_argument("--lorafa", action="store_true", help="Use LoRA-FA Optimizer")

args = parser.parse_args()

train_model(
Expand Down