53 changes: 16 additions & 37 deletions fastchat/train/llama2_flash_attn_monkey_patch.py
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple
 
 import torch
+from flash_attn import __version__ as flash_attn_version
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
@@ -36,6 +37,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
         warnings.warn(
@@ -58,50 +60,31 @@ def forward(
     kv_seq_len = k.shape[1]
     past_kv_len = 0
     if past_key_value is not None:
-        past_kv_len = past_key_value[0].shape[1]
+        past_kv_len = past_key_value[0].shape[2]
         kv_seq_len += past_kv_len
 
     cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
     q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
 
     if past_key_value is not None:
+        assert (
+            flash_attn_version >= "2.1.0"
+        ), "past_key_value support requires flash-attn >= 2.1.0"
         # reuse k, v
-        k = torch.cat([past_key_value[0], k], dim=1)
-        v = torch.cat([past_key_value[1], v], dim=1)
-
-    past_key_value = (k, v) if use_cache else None
-
-    key_padding_mask = attention_mask
-    # Ideally we could just do this:
-    #   q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
-    # but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
-    # which then breaks the causality logic. Probably if q_len >> past_kv_len we should
-    # just skip flash attention. Leaving this in for now to demonstrate correctness of
-    # flash attention information even when q needs padding.
-    # TODO(siddartha): delegate back to original implementation on this condition.
-    if past_kv_len > 0:
-        q = torch.cat(
-            (
-                torch.full(
-                    (bsz, past_kv_len, self.num_heads, self.head_dim),
-                    0.0,
-                    dtype=q.dtype,
-                    device=q.device,
-                ),
-                q,
-            ),
-            dim=1,
-        )
+        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
+        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
[Inline review comment]

Member:
Why do you store past_key_value in transposed layout? I feel it introduces some redundant memory movement.

Contributor (Author):
Very sorry for losing track of this question. The reason is that there is code in LlamaModel (in the transformers implementation) that assumes a particular memory layout when looking up the past kv length. That's why we have to store it in this order. Fortunately, this is a zero-copy operation: the transpose only reorganizes the tensor metadata and does not result in any memory movement.

I was recently testing another repo that depends on this and hit this error, so it would be great to have this merged so that other people don't run into it as well.

Again, sorry about not answering this sooner.
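An illustrative aside (not part of the patch): a minimal PyTorch sketch of the point above. transpose(1, 2) returns a view that shares the original tensor's storage, and with the cache held in (bsz, num_heads, seq_len, head_dim) layout the cached sequence length is read from dim 2, which is what the past_key_value[0].shape[2] change reflects. The tensor sizes below are made up for the example.

import torch

bsz, seq_len, num_heads, head_dim = 2, 5, 8, 64

# k in the layout flash-attn consumes: (bsz, seq_len, num_heads, head_dim)
k = torch.randn(bsz, seq_len, num_heads, head_dim)

# Cache layout assumed by transformers' LlamaModel: (bsz, num_heads, seq_len, head_dim)
k_cache = k.transpose(1, 2)

# transpose() only rewrites strides/metadata; no data is copied or moved.
assert k_cache.data_ptr() == k.data_ptr()

# In this layout the cached sequence length lives in dim 2.
past_kv_len = k_cache.shape[2]
print(past_kv_len)  # 5

# Transposing back for the flash-attn call is likewise a free view.
assert k_cache.transpose(1, 2).shape == k.shape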


+
+    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
 
-    if key_padding_mask is None:
+    if attention_mask is None:
         output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
-            bsz, q_len + past_kv_len, -1
+            bsz, q_len, -1
         )
     else:
-        q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
+        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
         # We can skip concat and call unpad twice but seems better to call unpad only once.
         kv, _, cu_k_lens, max_k = unpad_input(
-            torch.stack((k, v), dim=2), key_padding_mask
+            torch.stack((k, v), dim=2), attention_mask
         )
         output_unpad = flash_attn_varlen_kvpacked_func(
             q,
@@ -115,11 +98,7 @@ def forward(
             causal=True,
         )
         output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
-        output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)
-
-        # Need to strip off the zero query outputs.
-        if past_kv_len > 0:
-            output = output[:, past_kv_len:, ...]
+        output = pad_input(output_unpad, indices, bsz, q_len)
 
     return self.o_proj(output), None, past_key_value
 
@@ -245,7 +224,7 @@ def test():
             use_cache=True,
         )
         parts.append(part)
-        past_kv_len = past_kv[0].shape[1]
+        past_kv_len = past_kv[0].shape[2]
 
     print(
         f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
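An illustrative aside (not from the repository): the test above checks that one-shot causal attention matches chunk-by-chunk computation with a KV cache. The sketch below reproduces that consistency check with plain PyTorch attention so it runs without flash-attn or a Llama checkpoint; the names and sizes (causal_attn, split, and so on) are invented for the example. The cached chunk uses a bottom-right-aligned causal mask, which appears to be the behavior the flash-attn >= 2.1.0 assertion in the patch relies on.

import torch

torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)


def causal_attn(q, k, v, past_len=0):
    # Bottom-right-aligned causal mask: query i (offset by past_len) may attend
    # to key positions <= past_len + i.
    q_len, k_len = q.shape[2], k.shape[2]
    idx_q = torch.arange(q_len).unsqueeze(1) + past_len
    idx_k = torch.arange(k_len).unsqueeze(0)
    mask = idx_k <= idx_q  # (q_len, k_len)
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


# One shot over the full sequence.
oneshot = causal_attn(q, k, v)

# Two chunks with a "KV cache": the second call sees all cached keys/values.
split = 5
part1 = causal_attn(q[:, :, :split], k[:, :, :split], v[:, :, :split])
part2 = causal_attn(q[:, :, split:], k, v, past_len=split)
chunked = torch.cat([part1, part2], dim=2)

print(torch.allclose(oneshot, chunked, atol=1e-6))  # True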