[Feature Request] Add LoRA-FA to PEFT #2468
Conversation
Issue ref to https://github.com//issues/2469.
BenjaminBossan left a comment:
Thanks a lot for creating this PR to add the LoRA-FA optimizer to PEFT. This is a welcome addition. I only did a quick review for now, as I had some questions before going fully in depth. Please check my comments.
Before we can merge the PR, we also need to add a few things:
- unit tests: check the loraplus tests for inspiration
- docs: let's add a section to the optimizer docs
- nice to have: having a working example is great for users to get started quickly
- method comparison: we now have a framework for testing the performance of different PEFT methods, LoRA FA could be added there (I can work on this after merging the PR, it's not required to merge it)
1. add unit tests in tests/test_lorafa.py
2. add docs in docs/source/developer_guides/lora.md -> #optimizers
3. add a working example for fine-tuning meta-llama/Meta-Llama-3-8B on meta-math/MetaMathQA-40K with the LoRA-FA optimizer (see the sketch below)
4. delete kwargs
5. ruff style
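For item 3, here is a minimal sketch of what such an example could look like, assuming the PR exposes a helper along the lines of `create_lorafa_optimizer`; the name, signature, and hyper-parameter values below are placeholders and may differ from the final code:

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer  # name/signature assumed from this PR
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
)
model = get_peft_model(
    model,
    LoraConfig(
        task_type="CAUSAL_LM",
        r=32,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    ),
)

# LoRA-FA keeps lora_A frozen and only updates lora_B with a corrected gradient,
# so the layer input X no longer has to be stored for the LoRA branch.
optimizer = create_lorafa_optimizer(model=model, r=32, lora_alpha=64, lr=7e-5)  # lr is a placeholder
```

The resulting optimizer would then be passed to `transformers.Trainer` via `optimizers=(optimizer, None)`, together with a tokenized meta-math/MetaMathQA-40K dataset.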
Thanks for the updates, this is looking very good and should only require a few more changes before it's ready to merge. Please check my comments.
On top of this, I have some questions regarding the implementation. AFAICT you're not one of the paper authors, but maybe you can still answer them (or we can ping one of the authors):
- Regarding table 4 of the paper, I'm a bit surprised that LoRA memory requirements are often very close to full fine-tuning. Do you have any idea why that would be? I would expect a larger gap.
- Regarding the performance vs. normal LoRA, I added LoRA-FA to the method comparison framework we recently added to PEFT (code changes from #2479 required). Similar to what you have in your repo, this trains/evals on MetaMathQA/GSM8K, but only trains on 20000 samples to speed things up.

Here is the summary for LoRA with rank 32, alpha 64, no dropout, no rslora, meta-llama/Llama-3.2-3B base model (no quantization, bf16):
| metric | vanilla LoRA | LoRA FA |
|---|---|---|
| cuda memory max | 22.3 GB | 20.2 GB |
| cuda memory reserved avg | 11.9 GB | 11.1 GB |
| cuda memory reserved 99th percentile | 17.7 GB | 16.2 GB |
| train loss | 0.607 | 0.651 |
| test accuracy | 47.8% | 44.0% |
So we do indeed see a memory saving from LoRA-FA, though it's not as big as the paper suggests. On the other hand, the loss and accuracy are markedly worse. Do you have any idea why that is? Perhaps you can suggest better default hyper-params from your experience.
Hi, @BenjaminBossan, thanks for the update! Let me summarize the current issues and clarify some facts.
I am the first author of this paper (I won't mention names here for certain reasons). The current version on arXiv, v1, was released in 2023 and is outdated. Since then, we've run numerous experiments and made proofs and improvements to the LoRA-FA algorithm. However, the arXiv version hasn't been updated yet; we plan to update it next week. Regarding the progress of open-sourcing: this is the first PR submitted to PEFT, and the current open-source repository AaronZLT/lorafa is just an initial version. We plan to update the detailed experimental part of the open-source repo after LoRA-FA is officially merged into PEFT. Therefore, the current LoRA-FA, including the algorithm and experimental results, is quite different from the arXiv v1 version, so we can set v1 aside for now.
This is actually a counterintuitive phenomenon. In practice, compared to Full-FT, LoRA only reduces the memory for gradients and optimizer states, while other components, including parameters and activations, may even increase. The increase in parameters comes from the LoRA adapter, but that part is negligible. The focus is on activations. Since LoRA A is trainable, we need the activation X to calculate its gradient; LoRA A shares this activation X with the non-trainable base model weight W. Additionally, LoRA B also stores the activation XA, so a LoRA linear layer requires X + XA of activation memory, whereas Full-FT only has to save X. These activations grow dramatically with the input size.
The result you've provided is expected. Since LoRA-FA mainly reduces activation memory, it only shows a significant difference when the input size is large. Here is an example with pure bf16 training: batch size = 8, sequence length = 1024, model = Llama-2-7b-chat-hf. LoRA-FA requires 36 GB of memory to store activations, which allows it to run on an 80 GB GPU, whereas LoRA requires more than 60 GB of memory for activations, leading to an out-of-memory (OOM) error.
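To illustrate the arithmetic, here is a rough back-of-envelope estimator (my own sketch, not part of the PR) for the activation memory of the LoRA branch of a single linear layer. It counts only the X and XA tensors in bf16 and ignores attention, layer norms, gradient checkpointing, and framework overhead, so it will not reproduce the exact 36 GB / 60+ GB figures above; it only shows why freezing A removes the dominant X term:

```python
def lora_branch_activation_bytes(
    batch: int,
    seq: int,
    in_features: int,
    rank: int,
    freeze_a: bool = False,
    dtype_bytes: int = 2,  # bf16
) -> int:
    """Rough activation bytes kept for backward by the LoRA branch of one linear layer.

    Vanilla LoRA stores X (batch, seq, in_features) for the gradient of A and
    XA (batch, seq, rank) for the gradient of B; LoRA-FA freezes A, so only
    XA has to be kept.
    """
    x_bytes = batch * seq * in_features * dtype_bytes
    xa_bytes = batch * seq * rank * dtype_bytes
    return xa_bytes if freeze_a else x_bytes + xa_bytes


# One Llama-2-7b-sized projection (hidden size 4096), rank 32, batch 8, seq 1024:
vanilla = lora_branch_activation_bytes(8, 1024, 4096, 32)
lorafa = lora_branch_activation_bytes(8, 1024, 4096, 32, freeze_a=True)
print(f"per layer: vanilla LoRA {vanilla / 2**20:.1f} MiB vs. LoRA-FA {lorafa / 2**20:.1f} MiB")
```

Summed over all LoRA-targeted projections and layers, the X term is what dominates the vanilla LoRA activation footprint.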
It's great to see a framework that can compare all PEFT methods, and I appreciate your effort! I am very willing to help maintain it. But let me first address the current issues. The results you've provided show accuracies of 47.8% for LoRA and 44.0% for LoRA-FA when fine-tuning meta-llama/Llama-3.2-3B on MetaMath and evaluating on GSM8K. However, AFAIK, the vanilla Llama-3.2-3B-Instruct (let's focus on the instruct version of the Llama models; I'll explain why later) achieves 65.2% on 0-shot GSM8K. We have tested and validated this result before. If the fine-tuned performance is lower than the vanilla result (fine-tuned 47.8% vs. vanilla 65.2%), the absolute numbers are hard to interpret, let alone compare across methods. In fact, verifying the performance of fine-tuning methods is a challenging task, and we have made some progress that I'll share here. Let's start by discussing how to fine-tune Llama-3.2-3B-Instruct and achieve improvements on 0-shot GSM8K.
We also conducted LoRA-FA fine-tuning experiments on Llama-3.2-3B-Instruct. Here are the specific parameters: Ultimately, we achieved 71.4% accuracy on 0-shot GSM8K. I'll find time to upload this experiment as a best practice to lorafa.
I haven't closely examined the method comparison, but based on the fine-tuning results, it seems that some improvements are needed. We could use the dataset and evaluation functions from our toolkit, but let's focus on this PR for now.
Thanks for the detailed response. First of all, let me clarify that I did not report the results above to criticize LoRA-FA. My goal was to ensure that it runs successfully and to verify with you that the results are in line with expectations. The question of whether to accept the PR does not depend on those results.
Ah, thanks, good to know.
Okay, just to clarify: The code in this PR is based on these new findings?
This is clear to me, the memory gains will strongly depend on the size of the base model and on the data. When the base model is small and the data contains long sequences, the memory reduction through LoRA will be small. I guess this applies to Roberta and T5-small and thus explains the results in Table 4.
Nice. I think the memory savings should be highlighted much more in the docs for LoRA-FA, maybe you can even include this example.
To explain: The purpose of the method comparison suite is not to achieve the best possible results on the benchmark. Instead, the main goal is to allow comparing the different PEFT methods on a more or less realistic task that runs in reasonable time and on 24 GB of VRAM. We purposefully chose the non-instruct base model so that the model has room to learn. The goal is thus different from a typical paper, where we want to achieve the best possible score. We could always achieve better scores, e.g. by training on more tokens or by choosing a stronger base model. But this would most likely benefit each PEFT method equally, so it would not help with choosing between PEFT methods, and it would make experimentation slower.
Note that in our script, we always pick the same samples and train on only a subset of all data to save time.
We remove the 6% longest sequences of MetaMathQA, leaving 94% for training. This was necessary because some PEFT methods would require too much memory to run on 24 GB VRAM otherwise.
I'll re-run the experiments with your settings and will report the results later.
Oh, thanks for the explanation, I didn't quite understand it initially. Now I get it. :)
Okay, then I think we have clarified every open question, great. Please ping me when the PR is ready for review. Here are the results from running LoRA-FA with the settings you provided above. I think the most notable change is the higher rank, which affords the model a higher learning capacity. I'm pleasantly surprised that despite the rank being 4x higher, there is hardly any increase in memory requirements -- in fact, it's still lower than with vanilla LoRA. At the same time, the performance of vanilla LoRA could be matched. There is a marked increase in runtime, but this will still be a worthwhile tradeoff for many users.
Great, we've addressed all the open questions. The PR is now ready for review, so please feel free to take a look.
BenjaminBossan left a comment:
Thanks a lot for the updates, we are almost finished. There are only a few smaller comments/questions left. Please check.
On top of that, could you please run:
make style
ruff format examples/lorafa_finetune/
src/peft/optimizers/lorafa.py (outdated)
        AA_T + delta * torch.eye(A.shape[0]).to(A.device)
    )
    ...
    with autocast(dtype=torch.bfloat16):
For my understanding, could you please explain why bf16? Is it just that you experimentally found that it is a good dtype for precision vs memory?
Computing grad_B_orin @ AA_T_inv in bf16 may give a speedup (not significant). I will add a bf16 check here.
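For readers following the thread, here is a minimal sketch of the step under discussion, reconstructed only from the excerpt above; the function name, the assumed shapes, and the `delta` default are mine, not the PR's actual code:

```python
import torch
from torch.amp import autocast


def corrected_grad_B(grad_B: torch.Tensor, A: torch.Tensor, delta: float = 1e-8) -> torch.Tensor:
    """Sketch only: scale the gradient of lora_B by (A A^T + delta * I)^-1.

    Assumed shapes: A is (r, in_features), grad_B is (out_features, r).
    """
    # Regularize A @ A^T before inverting, mirroring the excerpt above.
    AA_T_inv = torch.linalg.inv(
        A @ A.T + delta * torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
    )
    # The bf16 autocast in question: the (out_features, r) @ (r, r) matmul can
    # run in bf16 on CUDA for a minor speedup, as noted above.
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        return grad_B @ AA_T_inv
```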
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@AaronZLT Could you please merge with or rebase on the latest main, then run make style?
make style ok now
It seems the Windows CI has some network issues. Meanwhile, I've added a check to avoid [...]
Just some network flakiness, don't worry about it, I'll just restart the runs.
Makes sense.
BenjaminBossan left a comment:
Great work, everything LGTM. This is a very promising addition to PEFT. I'll merge the PR when the CI is green.
Adds LoRA with frozen A (LoRA-FA) to PEFT. Paper: https://arxiv.org/abs/2308.03303