
Commit 32ea3f5

derdeljan-msft authored and snnn committed
[CPU] Optimize GQA attention bias application for FP16 (#25871)
### Description

When the GQA op receives an attention bias input in FP16, platforms without native FP16 math support must cast the bias to FP32, which requires a temporary buffer to hold the converted values. The problem was that this temporary buffer was allocated and deallocated inside a loop, once for every processed token. This change refactors the implementation so the allocation happens only once. Phi model throughput increased by 15%.
1 parent 1a743ae commit 32ea3f5
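The diff below boils down to hoisting a scratch allocation out of a per-token loop. The following is a minimal, self-contained sketch of that pattern, not the actual kernel: the names (`HalfToFloat`, `ApplyBiasSlow`, `ApplyBiasFast`) and the `std::vector` scratch are illustrative only; the real code converts with `MlasConvertHalfToFloatBuffer` and allocates through an onnxruntime allocator wrapped in a `BufferUniquePtr`.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using fp16_t = uint16_t;  // bit pattern of an IEEE 754 half; stand-in for MLFloat16

// Minimal half -> float conversion (normal values and zero only; no NaN/Inf/
// denormal handling). Stands in for MlasConvertHalfToFloatBuffer.
inline float HalfToFloat(fp16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t mant = h & 0x3FFu;
  uint32_t bits = (h & 0x7FFFu) == 0
                      ? sign  // signed zero
                      : sign | ((exp + 112u) << 23) | (mant << 13);
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Before: the fp32 scratch buffer is allocated and freed once per token.
void ApplyBiasSlow(const fp16_t* bias, float* scores,
                   size_t seq_len, size_t total_len) {
  for (size_t tok = 0; tok < seq_len; ++tok) {
    std::vector<float> scratch(total_len);  // allocation inside the loop
    for (size_t i = 0; i < total_len; ++i)
      scratch[i] = HalfToFloat(bias[tok * total_len + i]);
    for (size_t i = 0; i < total_len; ++i)
      scores[tok * total_len + i] += scratch[i];
  }  // deallocation inside the loop
}

// After: the scratch buffer is hoisted out of the loop and reused.
void ApplyBiasFast(const fp16_t* bias, float* scores,
                   size_t seq_len, size_t total_len) {
  std::vector<float> scratch(total_len);  // one allocation up front
  for (size_t tok = 0; tok < seq_len; ++tok) {
    for (size_t i = 0; i < total_len; ++i)
      scratch[i] = HalfToFloat(bias[tok * total_len + i]);
    for (size_t i = 0; i < total_len; ++i)
      scores[tok * total_len + i] += scratch[i];
  }
}
```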

File tree

1 file changed: +12 −3 lines


onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 12 additions & 3 deletions
```diff
@@ -280,6 +280,18 @@ class GQAAttentionBase {
                     output, static_cast<int>(present_buffer_sequence_length), nullptr);
       }

+      // Pre-allocate buffer for attention mask to avoid allocating it for every processed token
+      float* attention_bias_thread_fp32 = nullptr;
+      if (attention_bias_thread != nullptr) {
+        if constexpr (!std::is_same_v<U, T>) {
+          static_assert(std::is_same_v<U, float> && std::is_same_v<T, MLFloat16>);
+
+          size_t bytes = attention_total_seqlen * sizeof(float);
+          attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
+        }
+      }
+      BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator));
+
       // compute Softmax
       U* output_softmax = output;
       for (size_t seq = 0; seq < sequence_length; seq++) {
@@ -316,9 +328,6 @@
                               static_cast<int>(window_size));
         } else {
           static_assert(std::is_same_v<U, float> && std::is_same_v<T, MLFloat16>);
-          size_t bytes = window_size * sizeof(float);
-          auto attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
-          BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator));

           MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size);
           ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast<int>(window_size));
```
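As the diff shows, the hoisted allocation is only performed when `U` and `T` differ (FP16 input with float math, as the `static_assert` documents), and ownership of the single scratch buffer is handed to a `BufferUniquePtr` with a `BufferDeleter`, so it is released once when the enclosing scope exits rather than once per token.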
