[Feature Request] Add LoRA-FA to PEFT#2468

Merged
BenjaminBossan merged 22 commits intohuggingface:mainfrom
AaronZLT:lorafa
Apr 10, 2025
Conversation

@AaronZLT
Contributor

No description provided.

@AaronZLT
Contributor Author

Issue ref to [https://github.com//issues/2469].

Member
@BenjaminBossan left a comment

Thanks a lot for creating this PR to add the LoRA-FA optimizer to PEFT. This is a welcome addition. I only did a quick review for now, as I had some questions before going fully in depth. Please check my comments.

Before we can merge the PR, we also need to add a few things:

  1. unit tests: check the loraplus tests for inspiration
  2. docs: let's add a section to the optimizer docs
  3. nice to have: having a working example is great for users to get started quickly
  4. method comparison: we now have a framework for testing the performance of different PEFT methods; LoRA-FA could be added there (I can work on this after merging the PR, it's not required for merging)

AaronZLT and others added 3 commits April 1, 2025 00:16
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
1. add unit test tests/test_lorafa.py
2. add docs in docs/source/developer_guides/lora.md -> #optimizers
3. a working example for fine-tuning meta-llama/Meta-Llama-3-8B on meta-math/MetaMathQA-40K using LoRA-FA optimizer
4. delete kwargs
5. ruff style
@AaronZLT AaronZLT requested a review from BenjaminBossan April 3, 2025 19:06
Member
@BenjaminBossan left a comment

Thanks for the updates, this is looking very good and should only require a few more changes before it's ready to merge. Please check my comments.

On top of this, I have some questions regarding the implementation. AFAICT you're not one of the paper authors, but maybe you can still answer them (or we can ping one of the authors):

  1. Regarding table 4 of the paper, I'm a bit surprised that LoRA memory requirements are often very close to full fine-tuning. Do you have any idea why that would be? I would expect a larger gap.

  2. Regarding the performance vs normal LoRA, I added LoRA-FA to the method comparison framework we have added to PEFT recently (code changes from #2479 required). Similar to what you have in your repo, this trains/evals on MetaMathQA/GSM8K, but only trains on 20000 samples to speed things up.

Here is the summary for LoRA with rank 32, alpha 64, no dropout, no rslora, and the meta-llama/Llama-3.2-3B base model (no quantization, bf16):

| metric | vanilla LoRA | LoRA-FA |
| --- | --- | --- |
| CUDA memory max | 22.3 GB | 20.2 GB |
| CUDA memory reserved avg | 11.9 GB | 11.1 GB |
| CUDA memory reserved 99th percentile | 17.7 GB | 16.2 GB |
| train loss | 0.607 | 0.651 |
| test accuracy | 47.8% | 44.0% |

So we do indeed see a memory saving from LoRA-FA, though it's not as big as the paper suggests. On the other hand, the loss and accuracy are markedly worse. Do you have any idea why that is? Perhaps you can suggest better hyper-params from your experience (hyper-params we could use as defaults).
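For reference, here is a minimal sketch of how such a LoRA setup could be reproduced with PEFT, assuming the settings described above (rank 32, alpha 64, no dropout, no rslora); the target modules and model-loading details are assumptions, not the exact comparison script:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Hypothetical reconstruction of the comparison setup above; the actual
# method-comparison scripts in PEFT may differ (e.g. in target_modules).
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B", torch_dtype=torch.bfloat16
)
config = LoraConfig(
    r=32,                         # rank 32
    lora_alpha=64,                # alpha 64
    lora_dropout=0.0,             # no dropout
    use_rslora=False,             # no rslora
    target_modules="all-linear",  # assumption: not stated above
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```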

AaronZLT and others added 2 commits April 6, 2025 16:11
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@AaronZLT
Contributor Author

AaronZLT commented Apr 6, 2025

Hi, @BenjaminBossan, thanks for the update!

Let me summarize the current issues and clarify some facts.

Author & Recent Progress

I am the first author of this paper (I won't mention names here for certain reasons). The current version on arXiv, v1, was released in 2023 and is outdated. Since then, we've conducted numerous experiments, proofs, and improvements to the LoRA-FA algorithm; the arXiv version hasn't been updated yet, but we plan to update it next week.

Regarding the progress of open-sourcing: this is the first PR submitted to PEFT, and the current open-source repository AaronZLT/lorafa is just an initial version. We plan to add the detailed experimental part to the open-source repo after LoRA-FA is officially merged into PEFT. The current LoRA-FA, including the algorithm and experimental results, is therefore quite different from the arXiv v1 version, so we can set v1 aside for now.

Why does LoRA consume similar memory to Full-FT?

This is actually a counterintuitive phenomenon. In practice, compared to Full-FT, LoRA only reduces the gradient and optimizer-state memory, while other components, including parameters and activations, may even increase. The increase in parameters comes from the LoRA adapter, but this part is negligible. The key is the activations. Since LoRA A is trainable, we need the activation X to compute its gradient; this X is the same input that feeds the frozen base weight W, so it is shared between them. Additionally, LoRA B stores its activation XA, so a LoRA linear layer requires X + XA of activation memory, whereas Full-FT only has to save X. These activations grow dramatically with the input size.
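To make the accounting concrete, here is a rough, hedged back-of-envelope sketch (not from the paper or this PR) that only counts the matmul inputs a single linear layer saves for backward, assuming bf16 (2 bytes per element) and ignoring attention buffers and framework overhead:

```python
def activation_bytes_per_linear(batch, seq_len, in_features, rank, dtype_bytes=2):
    """Rough per-layer estimate of the inputs saved for the backward pass.

    Full-FT:  X is saved to compute the gradient of W.
    LoRA:     X is saved for the gradient of A (shared with the frozen W),
              plus XA is saved for the gradient of B.
    LoRA-FA:  A is frozen, so X no longer needs to be kept; only XA remains.
    """
    x = batch * seq_len * in_features * dtype_bytes  # X
    xa = batch * seq_len * rank * dtype_bytes        # X @ A^T
    return {"full_ft": x, "lora": x + xa, "lora_fa": xa}

# Illustrative numbers only (hidden size 4096, rank 128):
print(activation_bytes_per_linear(batch=8, seq_len=1024, in_features=4096, rank=128))
```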

LoRA Memory vs. LoRA-FA Memory

The result you've provided is expected. Since LoRA-FA mainly reduces activation memory, it only shows a significant difference when the input size is large. Here is an example with pure-bf16 training: batch size = 8, sequence length = 1024, model = Llama-2-7b-chat-hf. LoRA-FA requires 36 GB of memory to store activations, which allows it to run successfully on an 80 GB GPU, whereas LoRA requires more than 60 GB of memory for activations, leading to an out-of-memory (OOM) error.

Accuracy and Loss

It's great to see a framework that can compare all PEFT methods, and I appreciate your effort! I am very willing to help maintain it. But let me first address the current issues.

The results you've provided show accuracies of 47.8% for LoRA and 44.0% for LoRA-FA when fine-tuning meta-llama/Llama-3.2-3B on MetaMath and evaluating on GSM8K. However, AFAIK, the vanilla Llama-3.2-3B-Instruct (let's focus on the instruct version of the Llama models; I'll explain why later) achieves 65.2% on 0-shot GSM8K; we've tested and validated this result before. Therefore, if the fine-tuned performance is lower than the vanilla result (fine-tuned 47.8% vs. vanilla 65.2%), the comparison makes little sense, let alone comparing the methods against each other.

In fact, verifying the performance of fine-tuning methods is a challenging task, and we have made some progress, which I'll share here. Let's start discussing how to fine-tune Llama-3.2-3B-Instruct and achieve improvements on 0-shot GSM8K.

  1. Dataset Selection: We recommend using the meta-math/MetaMathQA dataset. However, to save time, we can choose the meta-math/MetaMathQA-40K dataset instead of randomly sampling from MetaMathQA. Not sampling is for two reasons: better convergence and ensuring the same dataset is used for training across different methods.

  2. Dataset Processing: In the era of Llama-2, fine-tuning was relatively easy; simply providing some instruction sequences could achieve good results. Now this situation has completely changed. Because Llama-2 models inherently perform poorly, they tend to converge easily; conversely, Llama-3 typically performs well, making improvements difficult. Check this out: Fine tuning LLaMA 3 is a total disaster! : r/LocalLLaMA. It turns out that not only should we use the instruct version of the models, but it's also crucial to format the training dataset to match the model's instruct template (see the sketch after this list). We do this in our toolkit here: llm-toolkit/llmtoolkit/dataset.py at main · AaronZLT/llm-toolkit. Also, don't forget to apply the same formatting to the evaluation dataset, which is GSM8K here. By the way, adjust the sequence length carefully, because a sequence length that is too short may truncate tokens; it should cover at least 99% of the training and evaluation sets.

  3. Hyperparameter Search: Using LoRA as an example, the most important parameters are rank and learning rate. In practice, tuning these parameters isn't difficult.
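To make point 2 above concrete, here is a hedged sketch of formatting the training data with the instruct model's own chat template (the "query"/"response" column names are assumptions based on the MetaMathQA datasets; this is not the actual llm-toolkit code):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
dataset = load_dataset("meta-math/MetaMathQA-40K", split="train")

def format_example(example):
    # Wrap each sample in the model's chat template so training matches inference.
    messages = [
        {"role": "user", "content": example["query"]},
        {"role": "assistant", "content": example["response"]},
    ]
    example["text"] = tokenizer.apply_chat_template(messages, tokenize=False)
    return example

dataset = dataset.map(format_example)
```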

We also conducted LoRA-FA fine-tuning experiments on Llama-3.2-3B-Instruct. Here are the specific parameters:

# LoRA-FA with global batch size 32
--bf16 True\
--tf32 True\
--num_train_epochs 3\
--learning_rate 9e-5\
--source_max_len 512\
--target_max_len 512\
--peft lorafa\
--lora_rank 128\
--lora_scale 2.0

Ultimately, we achieved 71.4% accuracy on 0-shot GSM8K. I'll find time to upload this experiment as a best practice to lorafa.
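For PEFT users, wiring these settings up with the LoRA-FA optimizer added in this PR could look roughly like the sketch below. The exact argument names of create_lorafa_optimizer should be checked against the merged docs, and lora_alpha=256 is only my reading of --lora_scale 2.0 (scale = alpha / rank):

```python
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
peft_model = get_peft_model(base_model, LoraConfig(r=128, lora_alpha=256))

# LoRA-FA optimizer: A stays frozen, only B is updated with the projected gradient.
optimizer = create_lorafa_optimizer(model=peft_model, r=128, lora_alpha=256, lr=9e-5)

trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(output_dir="out", num_train_epochs=3, bf16=True),
    train_dataset=dataset,         # e.g. the formatted MetaMathQA-40K dataset from above
    optimizers=(optimizer, None),  # pass the custom optimizer; Trainer builds the scheduler
)
trainer.train()
```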

About Method Comparison

I haven't closely examined the method comparison yet, but based on the fine-tuning results, it seems that some improvements are needed. We could utilize the dataset and evaluation functions from our toolkit, but let's focus on this PR for now.

@BenjaminBossan
Member

Thanks for the detailed response. First of all, let me clarify that the reason I reported the results above was not to criticize LoRA-FA. My goal was to ensure that it runs successfully and to verify with you that the results are in line with expectations. The question of whether to accept the PR does not depend on those results.

> I am the first author of this paper

Ah, thanks, good to know.

> Since then, we've conducted numerous experiments, proofs, and improvements to the LoRA-FA algorithm

Okay, just to clarify: The code in this PR is based on these new findings?

> In practice, compared to Full-FT, LoRA only reduces the gradient and optimizer state, while other components, including parameters and activations, may even increase.

This is clear to me; the memory gains will strongly depend on the size of the base model and on the data. When the base model is small and the data contains long sequences, the memory reduction through LoRA will be small. I guess this applies to RoBERTa and T5-small and thus explains the results in Table 4.

> Here, we provide an example with pure-bf16 training: batch size = 8, sequence length = 1024, model = Llama-2-7b-chat-hf. LoRA-FA requires 36GB of memory to store activations, which allows it to run successfully on an 80GB GPU, but LoRA requires at least 60+GB of memory for activations, leading to an Out of Memory (OOM) error.

Nice. I think the memory savings should be highlighted much more in the docs for LoRA-FA, maybe you can even include this example.

> However, AFAIK, the vanilla Llama-3.2-3B-Instruct (let's focus on the instruct version of Llama models; I'll explain why later) achieves 65.2% on 0-shot GSM8K. We've tested and validated this result before. Therefore, if the fine-tuned performance is lower than the vanilla result (fine-tuned 47.8% vs. vanilla 65.2%), it makes no sense, let alone comparing them horizontally.

To explain: The purpose of the method comparison suite is not to achieve the best possible results on the benchmark. Instead, the main goal is to allow comparing the different PEFT methods on a more or less realistic task that can run in reasonable time and on 24 GB of VRAM. We purposefully chose the non-instruct base model so that we have room for the model to learn. The goal is thus different from a typical paper, where we want to achieve the best possible score.

We could always achieve better scores, e.g. by training on more tokens or by choosing a stronger base model. But this would most likely benefit each PEFT method equally, so it would not help us make a better choice between PEFT methods, and it would make experimentation slower.

> 1. However, to save time, we can choose the meta-math/MetaMathQA-40K dataset instead of randomly sampling from MetaMathQA. Not sampling is for two reasons: better convergence and ensuring the same dataset is used for training across different methods.

Note that in our script, we always pick the same samples and train on only a subset of all data to save time.

> 2. By the way, adjust the sequence length carefully, because a sequence length that is too short may truncate tokens. The sequence length should at least cover 99% of the training set and evaluation set.

We remove the 6% longest sequences of MetaMathQA, leaving 94% for training. This was necessary because some PEFT methods would require too much memory to run on 24 GB VRAM otherwise.

> We also conducted LoRA-FA fine-tuning experiments on Llama-3.2-3B-Instruct. Here are the specific parameters:

I'll re-run the experiments with your settings and will report the results later.

@AaronZLT
Contributor Author

AaronZLT commented Apr 7, 2025

Oh, thanks for the explanation, I didn't quite understand it initially. Now I get it. : )
Yes, the code in this pull request is based on these new findings. I haven't thoroughly tested the differences between the instruct and non-instruct versions of the model during fine-tuning. While conducting these experiments, I aimed to provide the best conditions for PEFT to prevent other factors from affecting convergence. Indeed, doing it this way is very time-consuming.

@BenjaminBossan
Member

Okay, then I think we've clarified every open question, great. Please ping me when the PR is ready for review.

Here are the results from running LoRA FA with the settings you provided above. I think the most notable change is the higher rank, which affords the model a higher learning capacity. I'm pleasantly surprised that despite the rank being 4x higher, there is hardly any increase in memory requirements -- in fact it's still lower than vanilla LoRA. At the same time, the performance of vanilla LoRA could be matched. There is a marked increase in runtime, but this will still be a worthwhile tradeoff for many users.

| metric | vanilla LoRA rank 32 | LoRA-FA rank 32 | LoRA-FA rank 128 |
| --- | --- | --- | --- |
| CUDA memory max | 22.3 GB | 20.2 GB | 20.4 GB |
| CUDA memory reserved avg | 11.9 GB | 11.1 GB | 11.3 GB |
| CUDA memory reserved 99th percentile | 17.7 GB | 16.2 GB | 16.5 GB |
| train loss | 0.607 | 0.651 | 0.608 |
| test accuracy GSM8K | 47.8% | 44.0% | 48.3% |
| total time (sec) | 1847 | 1804 | 2403 |

@AaronZLT
Contributor Author

AaronZLT commented Apr 7, 2025

Great, we've addressed all the open questions. The PR is now ready for review, so please feel free to take a look.

@AaronZLT AaronZLT requested a review from BenjaminBossan April 7, 2025 16:49
Member
@BenjaminBossan left a comment

Thanks a lot for the updates, we are almost finished. There are only a few smaller comments/questions left. Please check.

On top of that, could you please run:

make style
ruff format examples/lorafa_finetune/

Inline review comment on this snippet:

AA_T + delta * torch.eye(A.shape[0]).to(A.device)
)

with autocast(dtype=torch.bfloat16):
Member

For my understanding, could you please explain why bf16? Is it just that you experimentally found that it is a good dtype for precision vs memory?

Contributor Author

Computing grad_B_orin @ AA_T_inv in bf16 may give a speedup (not significant). I will add a bf16 check here.
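For context, a hedged sketch of what that bf16 path with the added check might look like (variable names follow the snippet above; not necessarily the exact code that landed in the PR):

```python
import torch

def project_grad_B(grad_B_orin: torch.Tensor, AA_T_inv: torch.Tensor) -> torch.Tensor:
    # Only autocast to bf16 on CUDA; elsewhere fall back to the default dtype,
    # which avoids the "Torch not compiled with CUDA enabled" issue on macOS.
    if grad_B_orin.is_cuda and torch.cuda.is_bf16_supported():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            out = grad_B_orin @ AA_T_inv
        return out.to(grad_B_orin.dtype)
    return grad_B_orin @ AA_T_inv
```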

@AaronZLT AaronZLT requested a review from BenjaminBossan April 9, 2025 02:39
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member

@AaronZLT Could you please merge with or rebase on the latest main, then run make style again?

@AaronZLT
Contributor Author

AaronZLT commented Apr 9, 2025

> @AaronZLT Could you please merge with or rebase on the latest main, then run make style again?

make style is OK now

@AaronZLT
Contributor Author

AaronZLT commented Apr 9, 2025

It seems the Windows CI has some network issues. Meanwhile, I added a check to avoid the "Torch not compiled with CUDA enabled" error on macOS.

@BenjaminBossan
Member

> It seems the Windows CI has some network issues.

Just some network flakiness, don't worry about it, I'll just restart the runs.

> Meanwhile, I added a check to avoid the "Torch not compiled with CUDA enabled" error on macOS.

Makes sense.

Member
@BenjaminBossan left a comment

Great work, everything LGTM. This is a very promising addition to PEFT. I'll merge the PR when the CI is green.

@BenjaminBossan BenjaminBossan merged commit 0c2bdbb into huggingface:main Apr 10, 2025
14 checks passed
@AaronZLT AaronZLT deleted the lorafa branch April 23, 2025 08:50
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Adds LoRA with frozen A (LoRA-FA) to PEFT.

Paper: https://arxiv.org/abs/2308.03303
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
Adds LoRA with frozen A (LoRA-FA) to PEFT.

Paper: https://arxiv.org/abs/2308.03303