4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -910,7 +910,9 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
+                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, int local_window_size) {
+  ORT_ENFORCE(local_window_size == -1, "Sliding window is not supported yet in FlashAttention.");
+
   ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value));
   const int present_sequence_length = parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_;
   if (parameters.sequence_length_ > 1) {
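The new trailing local_window_size argument is rejected up front: ORT_ENFORCE aborts with the given message when its condition is false, so a caller that reaches the FlashAttention path while requesting a real sliding window fails fast instead of silently computing full attention. Below is a minimal sketch of the same fail-fast pattern; the ENFORCE macro and ApplyFlashAttentionStub are hypothetical stand-ins, not ONNX Runtime's implementation (the real ORT_ENFORCE also records file and line information).

#include <sstream>
#include <stdexcept>

// Hypothetical stand-in for ORT_ENFORCE: throw when the condition is false.
#define ENFORCE(cond, msg)                      \
  do {                                          \
    if (!(cond)) {                              \
      std::ostringstream oss;                   \
      oss << "Enforce failed: " << (msg);       \
      throw std::runtime_error(oss.str());      \
    }                                           \
  } while (0)

// Hypothetical stub mirroring the guard at the top of ApplyFlashAttention.
void ApplyFlashAttentionStub(int local_window_size = -1) {
  ENFORCE(local_window_size == -1, "Sliding window is not supported yet in FlashAttention.");
  // ... kernel dispatch would follow here ...
}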
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -134,7 +134,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD
 
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                            Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
-                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
+                           const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, int local_window_size = -1);
 
 bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                             const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
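Defaulting the new parameter to -1 in the declaration keeps every existing call site source-compatible: callers that predate this change continue to compile and get the "no sliding window" behavior, while new callers can pass a window explicitly (the GroupQueryAttention call in the next hunk still omits the argument). A small self-contained illustration of that pattern, using a hypothetical WindowedLength function rather than the real signature:

#include <iostream>

// Hypothetical: a defaulted trailing parameter added to an existing function.
int WindowedLength(int total_sequence_length, int local_window_size = -1) {
  // -1 is the sentinel for "no sliding window", matching the PR's default.
  return local_window_size == -1 ? total_sequence_length : local_window_size;
}

int main() {
  std::cout << WindowedLength(4096) << "\n";       // pre-existing call site: unchanged, prints 4096
  std::cout << WindowedLength(4096, 256) << "\n";  // new call site: prints 256
}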
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -198,9 +198,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_value = context.Output(2, present_kv_shape);
   parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
 
+  bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.seqlen_present_kv_cache_ && local_window_size_ < parameters.total_sequence_length_);
   if (!do_rotary_ &&
       head_sink == nullptr && !use_smooth_softmax_ &&
-      local_window_size_ == -1 &&
+      !use_sliding_window &&
       CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
                                present_value, parameters, context);
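The rewritten guard only treats the window as active when it would actually mask tokens: with local_window_size_ of -1 the feature is off, and a window at least as long as both the present KV-cache length and the total sequence length already covers every token, so FlashAttention remains eligible in those cases. A self-contained sketch of the predicate, written as a free function with plain parameters standing in for the member variables above:

#include <cassert>

// Mirrors the use_sliding_window condition in GroupQueryAttention::ComputeInternal:
// true only when a sliding window would actually exclude tokens.
bool UseSlidingWindow(int local_window_size, int seqlen_present_kv_cache, int total_sequence_length) {
  return local_window_size != -1 &&
         local_window_size < seqlen_present_kv_cache &&
         local_window_size < total_sequence_length;
}

int main() {
  assert(!UseSlidingWindow(-1, 4096, 4096));   // sentinel: window disabled
  assert(!UseSlidingWindow(4096, 1024, 1024)); // window covers the whole cache: nothing masked
  assert(UseSlidingWindow(256, 4096, 4096));   // window shorter than the cache: masking applies
  return 0;
}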