Conversation

@siddartha-RE
Contributor

@siddartha-RE siddartha-RE commented Aug 25, 2023

Enable use_cache support with Flash-Attention

As of flash-attn 2.1.0, the library handles the case of q_len != kv_len with causal attention, so flash attention can be used even when past key-values are supplied, without padding q. This covers the inference use case, so the patch is now useful at both inference time and training time.
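A minimal sketch of the behavior this relies on (not the PR's code; the shapes and sizes below are illustrative only): with flash-attn >= 2.1, causal=True aligns the mask to the bottom-right when q_len != kv_len, so a single decoded token can attend to the entire KV cache plus itself without padding q up to kv_len.

```python
# Sketch only: decoding one new token against a KV cache with flash-attn >= 2.1.
# Requires a CUDA device and fp16/bf16 tensors; shapes here are made up.
import torch
from flash_attn import flash_attn_func

batch, n_heads, head_dim = 1, 32, 128
past_len = 512  # tokens already held in the KV cache

q = torch.randn(batch, 1, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, past_len + 1, n_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, past_len + 1, n_heads, head_dim, device="cuda", dtype=torch.float16)

# With flash-attn < 2.1 this required padding q up to kv_len; with >= 2.1 the
# causal mask is bottom-right aligned, so the new token attends to the whole
# cache plus itself, which is exactly what decoding needs.
out = flash_attn_func(q, k, v, causal=True)  # (batch, 1, n_heads, head_dim)
```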

Checks

  • I've run format.sh to lint the changes in this PR.
  • I've included any doc changes needed.
  • I've made sure the relevant tests are passing (if applicable).

@merrymercy
Member

merrymercy commented Aug 27, 2023

Can we try to merge these two https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py?
@siddartha-RE
Contributor Author

> Can we try to merge these two https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py?

I don't follow. This PR is an update to that file; it takes advantage of the fixed causal behavior in the latest flash-attn so that flash attention can also be used at inference time.

@merrymercy force-pushed the main branch 2 times, most recently from bf7aa7e to a81a04c on August 28, 2023 at 01:36
@merrymercy
Member

Do we also need to apply these changes to llama2_flash_attn_monkey_patch.py?

@siddartha-RE
Contributor Author

> Do we also need to apply these changes to llama2_flash_attn_monkey_patch.py?

That's the only file I have added it to. I submitted the initial implementation to handle past KV through brute-force padding of the q tensor (which of course carries a significant performance penalty). Now that flash-attn 2.1.0 supports the correct causal handling of q_len != kv_len, I have updated the implementation to remove the padding. To prevent people from getting incorrect results, I have added an assert on the flash-attention version to make sure the corrected behavior is available.
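The version guard could be as simple as the following sketch (the exact check in the patch may differ):

```python
# Sketch of a version assert: refuse to use the patch unless flash-attn is new
# enough to handle causal attention with q_len != kv_len correctly.
from packaging.version import Version
import flash_attn

assert Version(flash_attn.__version__) >= Version("2.1.0"), (
    "flash-attn >= 2.1.0 is required for correct causal attention with a KV cache"
)
```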

Member

Why do you store past_key_value in a transposed layout? I feel it introduces some redundant memory movement.

Contributor Author

Very sorry for losing track of this question. The reason is that code in the LlamaModel (in the transformers implementation) assumes a particular memory layout when looking up the past KV length; that's why we have to store it in this order. Fortunately, this is a zero-copy operation: the transpose only reorganizes the tensor metadata and does not cause any memory movement.
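To illustrate the zero-copy point (a standalone sketch, not the patch code), a transpose only rewrites strides; the underlying storage is shared:

```python
# transpose() returns a view: same storage, different strides/metadata.
import torch

kv = torch.randn(2, 8, 16, 64)   # e.g. (batch, seq_len, n_heads, head_dim)
kv_t = kv.transpose(1, 2)        # (batch, n_heads, seq_len, head_dim)

print(kv.data_ptr() == kv_t.data_ptr())  # True: same underlying memory
print(kv_t.is_contiguous())              # False: only the view metadata changed
```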

I was recently testing another repo that depends on this and hit this error, so it would be great to have this merged so that other people don't run into it as well.

Again sorry about not answering this sooner.

@merrymercy merrymercy merged this pull request into lm-sys:main Sep 29, 2023
@merrymercy
Member

@siddartha-RE Thanks! It is merged.

@arvindsun

Looks like the merge commit never made it to the main branch?

491818f

@siddartha-RE
Contributor Author

Not sure what happened here. Possibly the main branch was rewritten and the commit was dropped. I will try raising the PR again.
