diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index 40d46cc3fba59..8b7b257dd2852 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -198,9 +198,12 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_value = context.Output(2, present_kv_shape);
   parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
+  ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length.");
+  // Use a sliding window if the total sequence exceeds the window's length.
+  bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_);
   if (!do_rotary_ && head_sink == nullptr && !use_smooth_softmax_ &&
-      local_window_size_ == -1 &&
+      !use_sliding_window &&
       CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context);
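
For reviewers, here is a minimal standalone sketch (not ONNX Runtime code; struct and function names are hypothetical) of the gating condition this diff introduces: flash attention remains eligible when the local window is disabled (`-1`) or at least as wide as the total sequence, and is only bypassed when the window actually truncates attention.

```cpp
// sliding_window_gate_sketch.cc -- illustrative only, assumes the field names below
// mirror parameters.total_sequence_length_, parameters.seqlen_present_kv_cache_,
// and local_window_size_ from the diff.
#include <cassert>
#include <cstdint>
#include <iostream>

struct GateParams {
  int64_t total_sequence_length;    // total tokens (past + new)
  int64_t seqlen_present_kv_cache;  // capacity of the present KV cache
  int64_t local_window_size;        // -1 disables the sliding window
};

// Returns true when the sliding-window path must be taken instead of flash attention.
bool UseSlidingWindow(const GateParams& p) {
  // Same invariant as the new ORT_ENFORCE: the cache must cover the total sequence.
  assert(p.total_sequence_length <= p.seqlen_present_kv_cache);
  return p.local_window_size != -1 && p.local_window_size < p.total_sequence_length;
}

int main() {
  // Window disabled: flash attention stays eligible.
  std::cout << UseSlidingWindow({128, 256, -1}) << "\n";     // 0
  // Window wider than the sequence: nothing is masked out yet, still eligible.
  std::cout << UseSlidingWindow({128, 256, 512}) << "\n";    // 0
  // Window narrower than the sequence: the sliding-window path is required.
  std::cout << UseSlidingWindow({1024, 2048, 256}) << "\n";  // 1
  return 0;
}
```

The key behavioral change is that a configured but not-yet-binding window (window size >= total sequence length) no longer disqualifies the flash attention path, whereas the old `local_window_size_ == -1` check did.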