Optimize for proper flash attn causal handling #2315
Conversation
Can we try to merge these two: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py?

I don't follow. This PR is an update to that file; it takes advantage of the fixed causal behavior in the latest flash-attn so that flash attention can be used at inference time.
Force-pushed from bf7aa7e to a81a04c
Do we also need to apply these changes to llama2_flash_attn_monkey_patch.py?

That's the only file I have added it to. I submitted the initial implementation to deal with past KV through brute-force padding of the q tensor (which of course carries a significant performance penalty). Now that flash-attn 2.1.0 supports correct causal handling of the q_len != kv_len case, I have updated the implementation to remove the padding. To prevent people from getting incorrect results, I have added an assert on the flash attention version to make sure the corrected behavior is available.
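A minimal sketch of the kind of version guard described above, assuming the patch imports flash-attn directly; the exact check in the PR may differ (for example, it may compare raw version strings rather than parsed versions):

```python
# Sketch of a version guard before monkey-patching attention (not the PR's exact code).
# Assumes flash-attn is installed and exposes __version__.
from packaging import version
import flash_attn

assert version.parse(flash_attn.__version__) >= version.parse("2.1.0"), (
    "flash-attn >= 2.1.0 is required: earlier releases mis-handle the causal "
    "mask when q_len != kv_len, which would give incorrect results with a KV cache."
)
```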
Why do you store past_key_value in a transposed layout? I feel it introduces some redundant memory movement.
Very sorry for losing track of this question. The reason is that there is code in LlamaModel (in the transformers implementation) that assumes a particular memory layout when looking up the past KV length, which is why we have to store it in this order. Fortunately this is a zero-copy operation: the transpose just reorganizes the tensor metadata and does not result in any memory movement.
I was just recently testing another repo that depends on this and hit this error, so it would be great to have this merged so that other people don't run into it as well.
Again, sorry about not answering this sooner.
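As a quick illustration of the zero-copy point above (not code from the patch): in PyTorch, transpose returns a strided view over the same storage, so reordering the cache layout does not move any data. Shapes here are made up for the example.

```python
import torch

# Hypothetical cache entry stored as (batch, seq_len, n_heads, head_dim).
kv = torch.randn(2, 16, 32, 128)

# Reorder to (batch, n_heads, seq_len, head_dim); only strides/metadata change.
kv_t = kv.transpose(1, 2)

print(kv_t.data_ptr() == kv.data_ptr())  # True: same underlying memory
print(kv_t.is_contiguous())              # False: it is a strided view, not a copy
```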
Force-pushed from 14c0818 to e4758da
@siddartha-RE Thanks! It is merged.

Looks like the merge commit never made it to the main branch?

Not sure what happened here. Possibly the main branch was rewritten and the commit was dropped. I will try raising the PR again.
Enable use_cache support with Flash-Attention
As of flash-attn 2.1.0, the library handles the causal case of q_len != kv_len, so flash attention can be used
even when past key-values are supplied, without padding q. This optimizes the inference use case, making the
patch useful at both inference time and training time.
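For illustration only, a sketch of the decode-time case this enables, assuming flash-attn >= 2.1.0, a CUDA device, and fp16 inputs; shapes and sizes are made up for the example.

```python
import torch
from flash_attn import flash_attn_func  # requires a CUDA GPU and fp16/bf16 tensors

batch, n_heads, head_dim = 2, 32, 128
past_len, q_len = 511, 1  # one new token attending over a populated KV cache

# flash_attn_func expects tensors shaped (batch, seqlen, n_heads, head_dim).
q = torch.randn(batch, q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, past_len + q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, past_len + q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")

# Since flash-attn 2.1.0 the causal mask is aligned to the bottom-right corner when
# q_len != kv_len, so the single query attends to all cached keys plus itself,
# with no padding of q required.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 1, 32, 128])
```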
Checks
Ran format.sh to lint the changes in this PR.