[Feature Request] Add LoRA-FA to PEFT #2468
Merged

Changes from 1 commit (22 commits)
b43bd99  adding lorafa (AaronZLT)
fa733f0  Update src/peft/optimizers/lorafa.py (AaronZLT)
c617cc0  Merge branch 'huggingface:main' into lorafa (AaronZLT)
4969451  In this commit, we (AaronZLT)
51daa2c  Update examples/lorafa_finetune/README.md (AaronZLT)
109ed82  Update examples/lorafa_finetune/README.md (AaronZLT)
1fb50a8  Merge branch 'huggingface:main' into lorafa (AaronZLT)
676eb21  rename learning_rate in args to lr (AaronZLT)
a228d88  delete is_same, add description closure in LoraFAOptimizer. (AaronZLT)
1782769  delete nf4 part in LoRA-FA example readme.md (AaronZLT)
608549f  more precise with best practices (AaronZLT)
a1819c1  refactor limitations (AaronZLT)
2dc19ee  update copyright (AaronZLT)
113152f  fix conflict (AaronZLT)
33fd485  fix test_LoraFAOptimizer_step in test_lorafa.py (AaronZLT)
8035ddc  add introduction to LoRA-FA (AaronZLT)
1058683  ruff format (AaronZLT)
9d1a92b  refactor lorafa_finetuning.py in examples (AaronZLT)
eba6b9c  add bf16 check in lorafa optimizer step (AaronZLT)
eb5da9b  Merge branch 'huggingface:main' into lorafa (AaronZLT)
839a2b4  rebase and make style (AaronZLT)
32e5dc0  check cuda to avoid Torch not compiled with CUDA enabled error (AaronZLT)
In this commit, we
1. add unit test tests/test_lorafa.py
2. add docs in docs/source/developer_guides/lora.md -> #optimizers
3. a working example for fine-tuning meta-llama/Meta-Llama-3-8B on meta-math/MetaMathQA-40K using LoRA-FA optimizer
4. delete kwargs
5. ruff style

commit 496945178435b055a2d433e5acae812f93660504
examples/lorafa_finetune/README.md
@@ -0,0 +1,93 @@
# LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning

## Introduction

[LoRA-FA](https://arxiv.org/abs/2308.03303) is a novel parameter-efficient fine-tuning method that freezes the projection-down layer (matrix A) during LoRA training. This reduces GPU memory consumption, because the activations of the input tensors (X) no longer need to be stored for the backward pass. Furthermore, LoRA-FA narrows the gap between the weight updates produced by low-rank fine-tuning and those of full fine-tuning. In short, LoRA-FA reduces memory consumption and achieves superior performance compared with vanilla LoRA.
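To make the memory argument concrete, here is a minimal illustrative sketch of a LoRA-FA style update (not the PEFT implementation; dimensions are arbitrary). A is frozen at initialization and only B is trained, so the backward pass for B only needs the r-dimensional projection of the input instead of the full input activation:

```python
import torch

# Minimal LoRA-FA sketch (illustrative only, not the PEFT implementation).
d_in, d_out, r, alpha, batch = 32, 32, 4, 8, 2

W = torch.randn(d_out, d_in)                   # frozen pre-trained weight
A = torch.randn(r, d_in) / r**0.5              # projection-down matrix, frozen in LoRA-FA
B = torch.zeros(d_out, r, requires_grad=True)  # projection-up matrix, the only trainable tensor

x = torch.randn(batch, d_in)
# Forward: W x + (alpha / r) * B (A x). Because A is frozen, the gradient of B
# only needs the r-dimensional projection (x @ A.T), not the full d_in-dimensional input.
out = x @ W.T + (alpha / r) * (x @ A.T) @ B.T
out.sum().backward()
print(B.grad.shape)  # torch.Size([32, 4]) -- the only gradient that is computed
```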

## Quick start

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
# tokenize the "text" column so the dataset can be fed to Trainer directly
dataset = dataset.map(
    lambda samples: tokenizer(samples["text"], truncation=True, max_length=2048),
    batched=True,
    remove_columns=dataset.column_names,
)

lora_rank = 16
lora_alpha = 32

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    bias="none",
)
peft_model = get_peft_model(model, lora_config)
optimizer = create_lorafa_optimizer(
    model=peft_model,
    r=lora_rank,
    lora_alpha=lora_alpha,
    learning_rate=7e-5,
)
# you can also use a scheduler; we recommend get_cosine_schedule_with_warmup from transformers
# for better model performance
scheduler = None

trainer = Trainer(
    model=peft_model,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler),
)
trainer.train()
peft_model.save_pretrained("lorafa-llama-3-8b-inst")
```

The only change in your code is to pass the LoRA-FA optimizer to the trainer (when training with a `Trainer`). Do not forget `from peft.optimizers import create_lorafa_optimizer`!
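As recommended in the comment inside the quick start, a cosine schedule with warmup can be created and passed alongside the optimizer. The step counts below are placeholders; derive them from your dataloader length and epoch count:

```python
from transformers import get_cosine_schedule_with_warmup

# `optimizer` is the LoRA-FA optimizer created in the quick start above.
# Illustrative values: set num_training_steps to len(train_dataloader) * num_epochs.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)
# then pass both to the trainer: optimizers=(optimizer, scheduler)
```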
In this directory, we also provide a small toy example of fine-tuning with the LoRA-FA optimizer. Simply run the fine-tuning script with:

```bash
accelerate launch examples/lorafa_finetuning/lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa
```

This 👆🏻 by default loads the model as a PEFT model set up with the LoRA config and trains it with the LoRA-FA optimizer. `accelerate launch` automatically configures single-GPU or multi-GPU training for you.

LoRA-FA also supports quantization. To use bitsandbytes NF4 4-bit quantization, try:

```bash
accelerate launch examples/lorafa_finetuning/lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --lorafa --quantize
```
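Roughly, the `--quantize` flag makes the example script load the base model in 4-bit NF4 before applying LoRA and the LoRA-FA optimizer, as in the sketch below (this mirrors the script further down; the values shown are the script's defaults):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)
# enable gradient checkpointing and other tweaks needed for k-bit training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
```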
## Use the model from 🤗

You can load and use the model like any other 🤗 model. For example, to load Llama-2-7b-chat-hf (used in the best-practice run below):
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
```
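If you ran the quick start above, you can also reload the saved LoRA-FA adapter together with its base model. This is a sketch that assumes the adapter was saved to `lorafa-llama-3-8b-inst` as in the quick start and reuses the base model's tokenizer:

```python
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# load the adapter saved by the quick start, together with its base model
model = AutoPeftModelForCausalLM.from_pretrained("lorafa-llama-3-8b-inst")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

inputs = tokenizer("What is 12 * 7?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```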
### Best practice in fine-tuning Llama on MetaMath using LoRA-FA

When fine-tuning with LoRA-FA, we recommend using a larger LoRA rank, such as 64 or 128 (at most). LoRA-FA can achieve 57.3% on GSM8K just by fine-tuning Llama-2-7b-chat-hf on meta-math/MetaMathQA-40K for 3 epochs! For best practices, see the [LoRA-FA examples](https://github.com/AaronZLT/lorafa).
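As a concrete starting point for such a run, the configuration below follows the rank recommendation above. The alpha value simply keeps the alpha = 2 * rank ratio used in the quick start, and the target modules are the example script's defaults; both are assumptions rather than values prescribed by the paper:

```python
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

lora_rank = 128   # larger rank, as recommended above for LoRA-FA
lora_alpha = 256  # assumption: keeps the alpha = 2 * rank ratio from the quick start

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    # default target modules used by the example script in this directory
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
)
peft_model = get_peft_model(model, lora_config)
optimizer = create_lorafa_optimizer(
    model=peft_model,
    r=lora_rank,
    lora_alpha=lora_alpha,
    learning_rate=7e-5,
)
```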
## LoRA-FA vs. LoRA

Despite its advantages, LoRA-FA remains inherently constrained by its low-rank approximation nature and by potential catastrophic forgetting. In addition, since LoRA-FA has fewer trainable parameters than LoRA, it may converge more slowly and requires a larger LoRA rank and a more fine-grained hyperparameter search (mainly over the learning rate). Addressing these limitations, particularly approximation accuracy and the forgetting phenomenon, is a promising direction for future work.
## Citation
```
@misc{zhang2023lorafamemoryefficientlowrankadaptation,
      title={LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning},
      author={Longteng Zhang and Lin Zhang and Shaohuai Shi and Xiaowen Chu and Bo Li},
      year={2023},
      eprint={2308.03303},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2308.03303},
}
```
examples/lorafa_finetuning/lorafa_finetuning.py
@@ -0,0 +1,195 @@
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.optimizers import create_lorafa_optimizer


def train_model(
    base_model_name_or_path: str,
    dataset_name_or_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    learning_rate: float,
    cutoff_len: int,
    quantize: bool,
    eval_step: int,
    save_step: int,
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    lora_target_modules: str,
    lorafa: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    print("In this example we only support GPU training.")

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    # QLoRA-FA (quantized LoRA-FA): IF YOU WANT TO QUANTIZE THE MODEL
    if quantize:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=(
                    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
                ),
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
            torch_dtype=torch.bfloat16,
        )
        # setup for quantized training
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, torch_dtype=torch.bfloat16)

    # LoRA config for the PEFT model
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=(
            lora_target_modules.split(",")
            if lora_target_modules
            else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        ),
        lora_dropout=lora_dropout,
        bias="none",
    )

    # get the peft model with LoRA config
    model = get_peft_model(model, lora_config)

    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_dataset(dataset_name_or_path)

    def tokenize_function(examples):
        # tokenize the queries as inputs and the responses as labels,
        # both padded/truncated to cutoff_len
        inputs = tokenizer(examples["query"], padding="max_length", truncation=True, max_length=cutoff_len)
        outputs = tokenizer(examples["response"], padding="max_length", truncation=True, max_length=cutoff_len)
        inputs["labels"] = outputs["input_ids"].copy()
        return inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
    dataset = tokenized_datasets["train"].train_test_split(test_size=0.1, shuffle=True, seed=42)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        gradient_accumulation_steps=1,
        bf16=True,  # the model is loaded in bfloat16, so use bf16 mixed precision instead of fp16
        learning_rate=learning_rate,
    )

    # Clear CUDA cache to free memory
    torch.cuda.empty_cache()

    # Here we initialize the LoRA-FA Optimizer
    # After this, all adapter A will be fixed, only adapter B will be trainable
    if lorafa:
        optimizer = create_lorafa_optimizer(
            model=model, r=lora_rank, lora_alpha=lora_alpha, learning_rate=learning_rate
        )
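        # Optional sanity check: per the comment above, the adapter A weights are fixed
        # once the LoRA-FA optimizer is created, so only lora_B should remain trainable.
        num_trainable_b = sum(
            p.numel() for n, p in model.named_parameters() if "lora_B" in n and p.requires_grad
        )
        print(f"Trainable lora_B parameters: {num_trainable_b}")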
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            optimizers=(optimizer, None),
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )

    # Start model training
    trainer.train()

    # Save the model and tokenizer locally
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune Meta-Llama-3-8B-Instruct with LoRA-FA and PEFT")
    parser.add_argument(
        "--base_model_name_or_path",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Base model name or path",
    )
    parser.add_argument(
        "--dataset_name_or_path", type=str, default="meta-math/MetaMathQA-40K", help="Dataset name or path"
    )
    parser.add_argument(
        "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
    parser.add_argument("--quantize", action="store_true", help="Use quantization")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    parser.add_argument(
        "--lora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA"
    )
    parser.add_argument("--lorafa", action="store_true", help="Use LoRA-FA Optimizer")
    args = parser.parse_args()

    train_model(
        base_model_name_or_path=args.base_model_name_or_path,
        dataset_name_or_path=args.dataset_name_or_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        quantize=args.quantize,
        eval_step=args.eval_step,
        save_step=args.save_step,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lorafa=args.lorafa,
    )