@@ -62,14 +62,14 @@ void make_copy<MLFloat16, MLFloat16>(MLFloat16* mask_data, const MLFloat16* mask
 template <>
 void make_copy<float, bool>(float* mask_data, const bool* mask_index, size_t size) {
   for (size_t i = 0; i < size; ++i) {
-    mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits<float>::lowest();
+    mask_data[i] = mask_index[i] ? 0.0f : negative_infinity<float>();
   }
 }
 
 template <>
 void make_copy<MLFloat16, bool>(MLFloat16* mask_data, const bool* mask_index, size_t size) {
   for (size_t i = 0; i < size; ++i) {
-    mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits<MLFloat16>::lowest();
+    mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity<MLFloat16>();
   }
 }
 
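The definition of negative_infinity<T>() is not part of these hunks. A minimal sketch of what such a helper could look like, assuming it returns -inf for float and that MLFloat16 can be constructed from a float -inf just as it is constructed from 0.f in the diff; the helper name comes from the diff, the body below is hypothetical:

#include <limits>

#include "core/framework/float16.h"  // assumed header providing MLFloat16

namespace onnxruntime {

// Hypothetical sketch, not the actual implementation behind this diff.
template <typename T>
inline T negative_infinity() {
  return -std::numeric_limits<T>::infinity();
}

template <>
inline MLFloat16 negative_infinity<MLFloat16>() {
  // The diff constructs MLFloat16 from a float literal (MLFloat16(0.f)), so we
  // assume the same conversion maps float -inf to the fp16 -inf bit pattern (0xFC00).
  return MLFloat16(-std::numeric_limits<float>::infinity());
}

}  // namespace onnxruntime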
@@ -251,7 +251,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     mask_data = static_cast<T*>(allocated_ptr);
     for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
       for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-        mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits<T>::lowest();
+        mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity<T>();
       }
     }
     delete_mask_data = true;
@@ -277,7 +277,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     for (int i = 0; i < n_iter; ++i) {
       for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
         for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-          mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits<T>::lowest();
+          mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity<T>();
         }
       }
     }
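Both loops above fill the same causal pattern: for query row s_i, every key position m_i beyond past_sequence_length + s_i is set to negative infinity so the softmax zeroes it out. A standalone illustration of that pattern, with made-up sizes rather than anything taken from the diff:

#include <cstdio>
#include <limits>
#include <vector>

int main() {
  // Illustrative sizes only.
  const int q_sequence_length = 3;
  const int past_sequence_length = 1;
  const int total_sequence_length = past_sequence_length + q_sequence_length;  // 4

  std::vector<float> mask(q_sequence_length * total_sequence_length, 0.0f);
  for (int s_i = 0; s_i < q_sequence_length; ++s_i) {
    for (int m_i = past_sequence_length + s_i + 1; m_i < total_sequence_length; ++m_i) {
      mask[s_i * total_sequence_length + m_i] = -std::numeric_limits<float>::infinity();
    }
  }

  // Prints 0 for attendable positions and -inf for masked (future) ones:
  //   row 0: 0 0 -inf -inf
  //   row 1: 0 0    0 -inf
  //   row 2: 0 0    0    0
  for (int s_i = 0; s_i < q_sequence_length; ++s_i) {
    for (int m_i = 0; m_i < total_sequence_length; ++m_i) {
      std::printf("%5g ", mask[s_i * total_sequence_length + m_i]);
    }
    std::printf("\n");
  }
  return 0;
}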
@@ -332,7 +332,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
     }
 
     // handling GQA
-    std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
+    std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
+    std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
     const T* k = K + k_input_chunk_length * ki;
 
     if (nullptr != present_key) {
@@ -362,7 +363,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
                                 alpha,
                                 Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
                                 parameters.head_size * parameters.q_num_heads,  // lda
-                                transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+                                transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
                                 transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size,  // ldb
                                 beta,
                                 output,
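The switch from head_i % parameters.kv_num_heads to head_i * parameters.kv_num_heads / parameters.q_num_heads changes how query heads are assigned to KV heads under GQA: consecutive query heads now share one KV head (grouped) instead of being interleaved round-robin. A small standalone check of the two mappings, with illustrative head counts that are not taken from the diff:

#include <cstdio>

int main() {
  // Illustrative sizes only: 8 query heads sharing 2 KV heads (group size 4).
  const int q_num_heads = 8;
  const int kv_num_heads = 2;
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    const int old_ki = head_i % kv_num_heads;                // 0,1,0,1,... (interleaved)
    const int new_ki = head_i * kv_num_heads / q_num_heads;  // 0,0,0,0,1,1,1,1 (grouped)
    std::printf("query head %d -> kv head (old) %d, kv head (new) %d\n", head_i, old_ki, new_ki);
  }
  return 0;
}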
@@ -568,7 +569,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
     // handling GQA
     std::ptrdiff_t batch_i = i / num_heads;
     std::ptrdiff_t head_i = i % num_heads;
-    std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
+    std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
+    std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
     const T* v = V + v_input_chunk_length * vi;
 
     if (nullptr != present_value) {
@@ -592,16 +594,15 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
     // V is transposed but not QK. We use GemmEx with a different value for ldb.
     math::GemmEx<T, ThreadPool>(CblasNoTrans,
                                 CblasNoTrans,
-                                sequence_length,                            // M
-                                v_head_size,                                // N
-                                total_sequence_length,                      // K
-                                1.f,                                        // alpha
-                                attention_probs + attention_probs_offset,  // QK
-                                total_sequence_length,                      // lda
-                                transposed_v ? V + (head_i % kv_num_heads) * v_head_size + v_input_chunk_length * kv_num_heads * batch_i
-                                             : v,
-                                transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
-                                0.f,                                        // beta
+                                sequence_length,                           // M
+                                v_head_size,                               // N
+                                total_sequence_length,                     // K
+                                1.f,                                       // alpha
+                                attention_probs + attention_probs_offset,  // QK
+                                total_sequence_length,                     // lda
+                                transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v,  // V
+                                transposed_v ? v_head_size * kv_num_heads : v_head_size,  // ldb
+                                0.f,                                       // beta
                                 output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
                                 v_head_size * num_heads,  // ldc
                                 nullptr);