forked from kohya-ss/sd-scripts
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattention.py
More file actions
260 lines (225 loc) · 10.9 KB
/
attention.py
File metadata and controls
260 lines (225 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
# Unified attention function supporting various implementations
from dataclasses import dataclass
import torch
from typing import Optional, Union
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.flash_attn_interface import flash_attn_func
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
flash_attn_func = None
try:
from sageattention import sageattn_varlen, sageattn
except ImportError:
sageattn_varlen = None
sageattn = None
try:
import xformers.ops as xops
except ImportError:
xops = None
@dataclass
class AttentionParams:
    """Backend selection plus mask / sequence-length metadata for ``attention``.

    Instances are normally built through one of the two static factories below
    rather than by calling the constructor directly.
    """

    # Backend name. The factories special-case "xformers" and "torch"; any other
    # value keeps the mask as a plain padded tensor.
    attn_mode: Optional[str] = None
    # When True, each batch element is meant to be processed on its own, so no
    # cumulative sequence lengths are computed.
    split_attn: bool = False
    # Number of image tokens placed in front of the text tokens of every sequence.
    img_len: Optional[int] = None
    # Text-token mask as given (split mode), or the backend-specific mask / bias
    # covering image + text tokens (non-split mode).
    attention_mask: Optional[torch.Tensor] = None
    # Valid tokens per sequence, image tokens included. Shape [B], int32.
    seqlens: Optional[torch.Tensor] = None
    # Boundaries of the valid and padded segments of every sequence inside the
    # flattened [B * max_seqlen] token axis. Shape [2 * B + 1], int32.
    cu_seqlens: Optional[torch.Tensor] = None
    # Padded sequence length, image tokens included.
    max_seqlen: Optional[int] = None

    @staticmethod
    def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
        """Build parameters for the mask-free case (every token is treated as valid)."""
        return AttentionParams(attn_mode, split_attn)

    @staticmethod
    def create_attention_params_from_mask(
        attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
    ) -> "AttentionParams":
        """Derive sequence lengths and a backend-specific mask from a text-token mask.

        Args:
            attn_mode: Backend name, stored as-is; "xformers" and "torch" also
                select the mask representation in non-split mode.
            split_attn: If True, only ``seqlens`` and ``max_seqlen`` are computed
                and the mask is stored unchanged.
            img_len: Number of image tokens preceding the text tokens. ``None``
                is treated as 0 (no image tokens).
            attention_mask: Mask of shape [B, L] over the text tokens only, with
                non-zero entries marking valid tokens, or ``None``.

        Returns:
            The populated ``AttentionParams``. With ``attention_mask=None`` only
            ``attn_mode`` and ``split_attn`` are set.
        """
        if attention_mask is None:
            # No attention mask provided: assume all tokens are valid
            return AttentionParams(attn_mode, split_attn, None, None, None, None, None)

        # Note: attention_mask is only for text tokens, not including image tokens
        prefix_len = 0 if img_len is None else img_len
        seqlens = attention_mask.sum(dim=1).to(torch.int32) + prefix_len  # [B]
        max_seqlen = attention_mask.shape[1] + prefix_len

        if split_attn:
            # cu_seqlens is not needed for split attention
            return AttentionParams(attn_mode, split_attn, prefix_len, attention_mask, seqlens, None, max_seqlen)

        # Convert the mask to cumulative sequence lengths for the varlen kernels.
        # Each sequence contributes two boundaries: the end of its valid tokens
        # and the end of its padded slot.
        batch_size = attention_mask.shape[0]
        device = attention_mask.device
        slot_starts = torch.arange(batch_size, dtype=torch.int32, device=device) * max_seqlen
        cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=device)
        cu_seqlens[1::2] = slot_starts + seqlens  # end of valid tokens
        cu_seqlens[2::2] = slot_starts + max_seqlen  # end of all tokens

        # Expand attention mask to include image tokens (always valid, placed first)
        attention_mask = torch.nn.functional.pad(attention_mask, (prefix_len, 0), value=1)  # [B, img_len + L]

        if attn_mode == "xformers":
            # NOTE(review): a BlockDiagonalMask normally goes with inputs packed as
            # [1, sum(seqlens), H, D] -- confirm the shapes the caller passes.
            seqlens_list = seqlens.cpu().tolist()
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                seqlens_list, seqlens_list, device=attention_mask.device
            )
        elif attn_mode == "torch":
            attention_mask = attention_mask[:, None, None, :].to(torch.bool)  # [B, 1, 1, img_len + L]

        return AttentionParams(attn_mode, split_attn, prefix_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
def _seqlens_to_list(seqlens) -> list:
    """Return per-sequence lengths as Python ints (a single host transfer for tensors)."""
    if isinstance(seqlens, torch.Tensor):
        return seqlens.tolist()
    return list(seqlens)


def _pad_seq_dim(x: torch.Tensor, pad_to: int, seq_dim: int) -> torch.Tensor:
    """Zero-pad ``x`` at the end of its sequence axis (``seq_dim`` is -2 or -3) up to ``pad_to``."""
    # F.pad consumes (left, right) pairs starting from the last dimension.
    pad = (0, 0) * (-seq_dim - 1) + (0, pad_to - x.shape[seq_dim])
    return torch.nn.functional.pad(x, pad, value=0)


def _require_backend(backend, attn_mode: str, package: str):
    """Return ``backend``, or raise a clear error if its optional package failed to import."""
    if backend is None:
        raise ImportError(f"attn_mode '{attn_mode}' requires the '{package}' package, but it could not be imported")
    return backend


def _attend_single(attn_mode: str, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, drop_rate: float) -> torch.Tensor:
    """Run unmasked attention over one sequence that has already been cut to its valid length."""
    if attn_mode == "torch":
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=drop_rate)
    if attn_mode == "xformers":
        return _require_backend(xops, attn_mode, "xformers").memory_efficient_attention(q, k, v, p=drop_rate)
    if attn_mode == "sageattn":
        return _require_backend(sageattn, attn_mode, "sageattention")(q, k, v)  # no dropout support
    if attn_mode == "flash":
        return _require_backend(flash_attn_func, attn_mode, "flash_attn")(q, k, v, drop_rate)
    raise ValueError(f"Unsupported attention mode: {attn_mode}")


def _attend_batch(
    attn_mode: str,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask,
    cu_seqlens: Optional[torch.Tensor],
    max_seqlen: Optional[int],
    drop_rate: float,
) -> torch.Tensor:
    """Run attention over the whole padded batch in a single backend call."""
    if attn_mode == "torch":
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=drop_rate)
    if attn_mode == "xformers":
        return _require_backend(xops, attn_mode, "xformers").memory_efficient_attention(
            q, k, v, attn_bias=attention_mask, p=drop_rate
        )

    if attn_mode == "sageattn":
        if cu_seqlens is None:  # all tokens are valid
            return _require_backend(sageattn, attn_mode, "sageattention")(q, k, v)  # no dropout support
        varlen_fn = _require_backend(sageattn_varlen, attn_mode, "sageattention")
        extra_args = ()  # no dropout support
    elif attn_mode == "flash":
        if cu_seqlens is None:  # all tokens are valid
            return _require_backend(flash_attn_func, attn_mode, "flash_attn")(q, k, v, drop_rate)
        varlen_fn = _require_backend(flash_attn_varlen_func, attn_mode, "flash_attn")
        extra_args = (drop_rate,)
    else:
        raise ValueError(f"Unsupported attention mode: {attn_mode}")

    # Variable-length kernels take tokens flattened to [B * L, H, D].
    # reshape (rather than view) also accepts non-contiguous inputs.
    batch_size, seqlen = q.shape[0], q.shape[1]
    q = q.reshape(batch_size * seqlen, *q.shape[2:])
    k = k.reshape(k.shape[0] * k.shape[1], *k.shape[2:])
    v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
    # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
    x = varlen_fn(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, *extra_args)
    return x.reshape(batch_size, seqlen, x.shape[-2], x.shape[-1])  # [B, L, H, D]


def attention(
    qkv_or_q: Union[torch.Tensor, list],
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    attn_params: Optional["AttentionParams"] = None,
    drop_rate: float = 0.0,
) -> torch.Tensor:
    """
    Compute scaled dot-product attention through the backend selected in ``attn_params``.

    Supported ``attn_mode`` values are "torch", "xformers", "sageattn" and "flash".
    Variable sequence lengths are handled in one of three ways:

    * ``split_attn=True``: every batch element is cut to its own length, attended
      separately and zero-padded back to ``max_seqlen``.
    * ``split_attn=False`` with a mask and identical lengths for all sequences:
      q/k/v are cut to that common length, attended without a mask, and the output
      is zero-padded back to ``max_seqlen``. This keeps the first ``seqlen`` tokens,
      so it presumably relies on valid tokens forming a prefix -- verify against callers.
    * otherwise the mask (torch / xformers) or ``cu_seqlens`` (sageattn / flash)
      is handed to the backend.

    Args:
        qkv_or_q: Query tensor [B, L, H, D], or a list ``[q, k, v]`` of such tensors.
            A list is emptied in place so the caller does not keep the tensors alive.
        k: Key tensor [B, L, H, D]. Required when ``qkv_or_q`` is a tensor.
        v: Value tensor [B, L, H, D]. Required when ``qkv_or_q`` is a tensor.
        attn_params: Attention parameters including mask and sequence lengths.
            ``None`` means torch SDPA with no mask and no splitting.
        drop_rate: Attention dropout rate. Not applied by the sageattn backend.

    Returns:
        Attention output tensor [B, L, H*D].

    Raises:
        ValueError: If ``attn_mode`` is not one of the supported backends.
        ImportError: If the selected backend's package could not be imported.
    """
    if isinstance(qkv_or_q, list):
        q, k, v = qkv_or_q
        qkv_or_q.clear()
    else:
        q = qkv_or_q
        assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
    del qkv_or_q

    # Work on local copies of the parameters so the caller's object is never modified.
    if attn_params is None:
        attn_mode, split_attn = "torch", False
        attention_mask = seqlens = cu_seqlens = max_seqlen = None
    else:
        attn_mode = attn_params.attn_mode
        split_attn = attn_params.split_attn
        attention_mask = attn_params.attention_mask
        seqlens = attn_params.seqlens
        cu_seqlens = attn_params.cu_seqlens
        max_seqlen = attn_params.max_seqlen

    # If split attn is False, attention mask is provided and all sequence lengths are same, we can trim the sequence
    seqlen_trimmed = False
    if not split_attn and attention_mask is not None and seqlens is not None:
        lengths = _seqlens_to_list(seqlens)
        if len(set(lengths)) == 1:
            common_len = lengths[0]
            q = q[:, :common_len]
            k = k[:, :common_len]
            v = v[:, :common_len]
            # Everything left is valid, so no mask is needed. max_seqlen is kept for padding back.
            attention_mask = seqlens = cu_seqlens = None
            seqlen_trimmed = True

    # Determine tensor layout based on attention implementation:
    # [B, H, L, D] for SDPA and for sageattn without cu_seqlens, [B, L, H, D] otherwise.
    heads_first = attn_mode == "torch" or (attn_mode == "sageattn" and (split_attn or cu_seqlens is None))
    seq_dim = -2 if heads_first else -3

    def to_backend_layout(x: torch.Tensor) -> torch.Tensor:
        # Swapping dims 1 and 2 is its own inverse, so this also converts back.
        return x.transpose(1, 2) if heads_first else x

    if split_attn:
        # Process each batch element with its valid sequence length
        if seqlens is None:
            # If no seqlens provided, assume all tokens are valid
            lengths = [q.shape[1]] * q.shape[0]
            max_seqlen = q.shape[1]
        else:
            lengths = _seqlens_to_list(seqlens)
        q = [to_backend_layout(q[i : i + 1, : lengths[i]]) for i in range(q.shape[0])]
        k = [to_backend_layout(k[i : i + 1, : lengths[i]]) for i in range(k.shape[0])]
        v = [to_backend_layout(v[i : i + 1, : lengths[i]]) for i in range(v.shape[0])]

        if attn_mode not in ("torch", "xformers", "sageattn", "flash"):
            raise ValueError(f"Unsupported attention mode: {attn_mode}")

        outputs = []
        for i in range(len(q)):
            x_i = _attend_single(attn_mode, q[i], k[i], v[i], drop_rate)
            # Release each input slice as soon as it has been consumed.
            q[i] = None
            k[i] = None
            v[i] = None
            outputs.append(_pad_seq_dim(x_i, max_seqlen, seq_dim))
        x = torch.cat(outputs, dim=0)
        del outputs
    else:
        if attn_mode not in ("torch", "xformers", "sageattn", "flash"):
            raise ValueError(f"Unsupported attention mode: {attn_mode}")
        q = to_backend_layout(q)
        k = to_backend_layout(k)
        v = to_backend_layout(v)
        x = _attend_batch(attn_mode, q, k, v, attention_mask, cu_seqlens, max_seqlen, drop_rate)
    del q, k, v

    x = to_backend_layout(x)  # [B, L, H, D]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, L, H*D]
    if seqlen_trimmed:
        x = torch.nn.functional.pad(x, (0, 0, 0, max_seqlen - x.shape[1]), value=0)  # pad back to max_seqlen
    return x