From d15898481addf36c6ce06067130568f50e11eb22 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 13:29:48 +0300 Subject: [PATCH 1/8] metal : add BS=1 kernel for flash attention (wip) --- ggml-metal.m | 117 +++++++++---- ggml-metal.metal | 416 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 501 insertions(+), 32 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f4800e3b0e90a..bf6277d3893c5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -179,6 +179,12 @@ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -613,12 +619,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2509,19 +2521,36 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ASSERT(false && "add template specialization for this size"); - } + if (ne01 > 1) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } } // TODO: extend if necessary @@ -2555,24 +2584,48 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + // half8x8 kernel + if (ne01 > 1) { + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! 
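Note: nqptg and ncpsg here mirror the Q and C template arguments of the kernel, which is why the comments insist they stay in sync with the .metal source. The threadgroup memory requested a few lines below is sized from them; a minimal standalone model of that formula follows, with illustrative values for the head size and simdgroup count (assumptions for the example, not taken from this patch):

    // Model of: smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2)
    // sizeof(float)/2 == sizeof(half) -- the scratch buffers hold halves.
    #include <cstdio>

    int main() {
        const long nqptg = 8;   // queries per threadgroup (Q)
        const long ncpsg = 32;  // cache values per simdgroup (C)
        const long ne00  = 128; // head size D (assumed for the example)
        const long nsg   = 4;   // simdgroups per threadgroup (assumed)

        // per query: D halves of Q data, plus C attention scores and a
        // Q-wide diagonal rescale block for each simdgroup
        const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);

        printf("threadgroup memory: %zu bytes\n", smem); // 4608 for these values
        return 0;
    }
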
- GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); - // simdgroups per threadgroup (a.k.a. warps) - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. 
warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsg = 1; + + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 77004db85bfd7..e4be0f69e5c97 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2457,6 +2457,422 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + nsg*SH; // shared memory size per query in (half) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[Q][D4]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) 
{ + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short j = 0; j < Q; ++j) { + for (short i = 0; i < D4; ++i) { + lo[j][i] = 0.0h; + } + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH/4; i += NW) { + ss4[j*T4 + i] = 0.0h; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... Q-1] = -INFINITY }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[Q][D4]; + + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + //simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + mq[j][i] = sq4[j*T4 + i]; + } + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + //simdgroup_half8x8 mscale(scale); + half mscale(scale); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C; ++cc) { + half mqk[Q]; + for (short j = 0; j < Q; ++j) { + mqk[j] = 0.0h; + } + + //device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = tiisg; i < D4; i += NW) { + //simdgroup_half8x8 mk; + half4 mk; + //simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + mk = pk4[i]; + + for (short j = 0; j < Q; ++j) { + //simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + mqk[j] += dot(mq[j][i], mk); + } + } + + // reduce the results from the threads in the simdgroup + simdgroup_barrier(mem_flags::mem_none); + + for (short i = NW/2; i > 0; i /= 2) { + if (tiisg < i) { + for (short j = 0; j < Q; ++j) { + mqk[j] += simd_shuffle_down(mqk[j], i); + } + } + + simdgroup_barrier(mem_flags::mem_none); + } + + // mqk = mqk*scale + mask + if (tiisg == 0) { + for (short j = 0; j < Q; ++j) { + //simdgroup_half8x8 mm; + //simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); + //simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); + + //simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + + half mm = mp[j*(nb31/sizeof(half)) + ic + cc]; + mqk[j] = mqk[j]*mscale + mm; + + ss[j*T + cc] = mqk[j]; + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // used to detect blocks full of -INF + half smax = -INFINITY; + + // online softmax + if (C == 32) { + half ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const half m = M[j]; + const half s = ss[j*T + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } + } else { + half ms[Q]; + + for (short j = 0; j < Q; ++j) { + const half m = M[j]; + + for (short p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = max(smax, s); + M[j] = max(M[j], s); + } + + smax = simd_max(smax); + M[j] = simd_max(M[j]); + + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + + // local sum + half ls = 0.0h; + + for (short p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // O = diag(ms)*O + for (short j = 0; j < Q; ++j) { + //simdgroup_half8x8 mm; + //simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); + half mm(ss[j*T + C + j]); + + for (short i = tiisg; i < D4; i += NW) { + //simdgroup_multiply(lo[j][i], mm, lo[j][i]); + lo[j][i] = lo[j][i]*mm; + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C; ++cc) { + //device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); + + //for (short i = 0; i < D8; ++i) { + // simdgroup_half8x8 mk; + // simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + // for (short j = 0; j < Q8; ++j) { + // simdgroup_half8x8 mv; + // simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); + + // simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + // } + //} + + for (short i = tiisg; i < D4; i += NW) { + half4 mk = pv4[i]; + + for (short j = 0; j < Q; ++j) { + lo[j][i] += mk*ss[j*T + cc]; + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + //for (short sg = 1; sg < nsg; ++sg) { + // half S = { 0.0h }; + // half M = { -INFINITY }; + + // threadgroup_barrier(mem_flags::mem_threadgroup); + + // // each simdgroup stores its output to shared memory, reusing sq + // if (sgitg == sg) { + // for (short j = 0; j < Q8; ++j) { + // for (short i = 0; i < D8; ++i) { + // simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + // } + // } + // } + + // threadgroup_barrier(mem_flags::mem_threadgroup); + + // // the first simdgroup accumulates the results from the other simdgroups + // if (sgitg == 0) { + // for (short j = 0; j < Q; ++j) { + // const half S0 = ss[j*T + 0]; + // const half S1 = ss[j*T + sg*SH + 0]; + + // const half M0 = ss[j*T + 1]; + // const half M1 = ss[j*T + sg*SH + 1]; + + // M = max(M0, M1); + + // const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + // const half ms1 = M1 == -INFINITY ? 
0.0h : exp(M1 - M); + + // S = S0*ms0 + S1*ms1; + + // if (tiisg == 0) { + // ss[j*T + 0] = S; + // ss[j*T + 1] = M; + + // ss[j*T + C + j ] = ms0; + // ss[j*T + C + j + sg*SH] = ms1; + // } + // } + + // // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + // for (short j = 0; j < Q8; ++j) { + // simdgroup_half8x8 t; + // simdgroup_half8x8 ms0; + // simdgroup_half8x8 ms1; + + // simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + // simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); + + // for (short i = 0; i < D8; ++i) { + // simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); + // simdgroup_multiply(t, ms1, t); + + // simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); + // } + // } + // } + //} + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + sq4[j*T4 + i] = lo[j][i]; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<64, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<80, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<96, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<112, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 5eab7454dd272927bfdc2a5fe849e10535bb621e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 13:41:00 +0300 Subject: [PATCH 2/8] metal : support more than 1 warps --- ggml-metal.m | 4 +-- ggml-metal.metal | 94 +++++++++++++++++++++++------------------------- 2 files changed, 46 insertions(+), 52 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index bf6277d3893c5..d942b673f6dd7 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2615,8 +2615,8 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. 
warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - const int64_t nsg = 1; + const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + //const int64_t nsg = 1; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index e4be0f69e5c97..d7ce102744dd6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2608,9 +2608,8 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short i = tiisg; i < D4; i += NW) { //simdgroup_half8x8 mk; - half4 mk; //simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - mk = pk4[i]; + half4 mk = pk4[i]; for (short j = 0; j < Q; ++j) { //simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); @@ -2779,66 +2778,61 @@ kernel void kernel_flash_attn_ext_vec_f16( } // reduce the warps sequentially - //for (short sg = 1; sg < nsg; ++sg) { - // half S = { 0.0h }; - // half M = { -INFINITY }; - - // threadgroup_barrier(mem_flags::mem_threadgroup); - - // // each simdgroup stores its output to shared memory, reusing sq - // if (sgitg == sg) { - // for (short j = 0; j < Q8; ++j) { - // for (short i = 0; i < D8; ++i) { - // simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - // } - // } - // } + for (short sg = 1; sg < nsg; ++sg) { + half S = { 0.0h }; + half M = { -INFINITY }; - // threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // // the first simdgroup accumulates the results from the other simdgroups - // if (sgitg == 0) { - // for (short j = 0; j < Q; ++j) { - // const half S0 = ss[j*T + 0]; - // const half S1 = ss[j*T + sg*SH + 0]; + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + sq4[j*T4 + i] = lo[j][i]; + } + } + } - // const half M0 = ss[j*T + 1]; - // const half M1 = ss[j*T + sg*SH + 1]; + threadgroup_barrier(mem_flags::mem_threadgroup); - // M = max(M0, M1); + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; - // const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - // const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; - // S = S0*ms0 + S1*ms1; + M = max(M0, M1); - // if (tiisg == 0) { - // ss[j*T + 0] = S; - // ss[j*T + 1] = M; + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 
0.0h : exp(M1 - M); - // ss[j*T + C + j ] = ms0; - // ss[j*T + C + j + sg*SH] = ms1; - // } - // } + S = S0*ms0 + S1*ms1; - // // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - // for (short j = 0; j < Q8; ++j) { - // simdgroup_half8x8 t; - // simdgroup_half8x8 ms0; - // simdgroup_half8x8 ms1; + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; - // simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); - // simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } - // for (short i = 0; i < D8; ++i) { - // simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); - // simdgroup_multiply(t, ms1, t); + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + half4 t = sq4[j*T4 + i]; + half ms0 = ss[j*T + C + j]; + half ms1 = ss[j*T + C + j + sg*SH]; - // simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); - // } - // } - // } - //} + lo[j][i] = lo[j][i]*ms0 + t*ms1; + } + } + } + } // store result to shared memory (reuse sq) if (sgitg == 0) { From 8d2a61f0686a2bdad824231a4c5c1335f08cdabf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 13:57:54 +0300 Subject: [PATCH 3/8] metal : opts --- ggml-metal.m | 1 - ggml-metal.metal | 35 +++++------------------------------ 2 files changed, 5 insertions(+), 31 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d942b673f6dd7..0f405b1126e2b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2616,7 +2616,6 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - //const int64_t nsg = 1; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index d7ce102744dd6..282ec3eb6df21 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2529,7 +2529,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // zero out lo for (short j = 0; j < Q; ++j) { - for (short i = 0; i < D4; ++i) { + for (short i = tiisg; i < D4; i += NW) { lo[j][i] = 0.0h; } } @@ -2648,10 +2648,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // used to detect blocks full of -INF - half smax = -INFINITY; + //threadgroup_barrier(mem_flags::mem_threadgroup); // online softmax if (C == 32) { @@ -2663,7 +2660,6 @@ kernel void kernel_flash_attn_ext_vec_f16( const half m = M[j]; const half s = ss[j*T + p]; - smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); @@ -2688,11 +2684,9 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - smax = max(smax, s); M[j] = max(M[j], s); } - smax = simd_max(smax); M[j] = simd_max(M[j]); ms[j] = m == -INFINITY ? 
0.0h : exp(m - M[j]); @@ -2720,12 +2714,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); + //threadgroup_barrier(mem_flags::mem_threadgroup); // O = diag(ms)*O for (short j = 0; j < Q; ++j) { @@ -2742,26 +2731,12 @@ kernel void kernel_flash_attn_ext_vec_f16( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C; ++cc) { - //device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); - //for (short i = 0; i < D8; ++i) { - // simdgroup_half8x8 mk; - // simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - - // for (short j = 0; j < Q8; ++j) { - // simdgroup_half8x8 mv; - // simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - - // simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); - // } - //} - + half vsum[Q]; for (short i = tiisg; i < D4; i += NW) { - half4 mk = pv4[i]; - for (short j = 0; j < Q; ++j) { - lo[j][i] += mk*ss[j*T + cc]; + lo[j][i] += pv4[i]*ss[j*T + cc]; } } } From 5733b00e5344734f9d21868c20cf8680db4debd3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 14:26:28 +0300 Subject: [PATCH 4/8] metal : opt --- ggml-metal.metal | 102 +++++++++++++---------------------------------- 1 file changed, 28 insertions(+), 74 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 282ec3eb6df21..533b9fef60816 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2581,7 +2581,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); // prepare diagonal scale matrix //simdgroup_half8x8 mscale(scale); @@ -2597,23 +2597,23 @@ kernel void kernel_flash_attn_ext_vec_f16( // Q*K^T { - for (short cc = 0; cc < C; ++cc) { - half mqk[Q]; + for (short cc = 0; cc < C/4; ++cc) { + half4 mqk[Q]; for (short j = 0; j < Q; ++j) { mqk[j] = 0.0h; } - //device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (short i = tiisg; i < D4; i += NW) { - //simdgroup_half8x8 mk; - //simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - half4 mk = pk4[i]; + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; for (short j = 0; j < Q; ++j) { - //simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); - mqk[j] += dot(mq[j][i], mk); + mqk[j] += mq[j][i] * mk; } } @@ -2633,85 +2633,40 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask if (tiisg == 0) { for (short j = 0; j < Q; ++j) { - //simdgroup_half8x8 mm; - //simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); - //simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - - //simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); - - half mm = mp[j*(nb31/sizeof(half)) + ic + cc]; + half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; mqk[j] = mqk[j]*mscale + mm; - ss[j*T + cc] = mqk[j]; + ss4[j*T4 + cc] = mqk[j]; } } 
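Note: throughout this series the Q*K^T path has each of the 32 lanes accumulate a partial half4 dot product, which is then folded into lane 0 with simd_shuffle_down (PATCH 1 uses a halving loop; PATCH 7 unrolls it into explicit offsets 16, 8, 4, 2, 1). A rough CPU-side model of that tree reduction, not Metal code:

    // Each "lane" holds a partial sum; after log2(32) halving steps the
    // full sum sits in lane 0, like the simd_shuffle_down loop above.
    #include <cstdio>

    int main() {
        const int NW = 32;      // simdgroup width
        float lane[NW];
        for (int t = 0; t < NW; ++t) lane[t] = 1.0f; // stand-in partials

        for (int off = NW/2; off > 0; off /= 2) {
            // simd_shuffle_down(x, off): lane t reads lane t+off's value;
            // ascending t means lane[t+off] is still this round's input
            for (int t = 0; t + off < NW; ++t) {
                lane[t] += lane[t + off];
            }
        }

        printf("sum in lane 0: %.1f\n", lane[0]); // 32.0
        return 0;
    }
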
} } - //threadgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_barrier(mem_flags::mem_threadgroup); // online softmax - if (C == 32) { - half ms[Q]; - - for (short j = 0; j < Q; ++j) { - const short p = tiisg; - - const half m = M[j]; - const half s = ss[j*T + p]; - - M[j] = simd_max(max(M[j], s)); - - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - - S[j] = S[j]*ms[j] + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; - } - } else { - half ms[Q]; + half ms[Q]; - for (short j = 0; j < Q; ++j) { - const half m = M[j]; - - for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; - - M[j] = max(M[j], s); - } - - M[j] = simd_max(M[j]); - - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // local sum - half ls = 0.0h; + for (short j = 0; j < Q; ++j) { + const short p = tiisg; - for (short p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const half m = M[j]; + const half s = ss[j*T + p]; - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + M[j] = simd_max(max(M[j], s)); - ls += vs; + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } + S[j] = S[j]*ms[j] + simd_sum(vs); - S[j] = S[j]*ms[j] + simd_sum(ls); - } + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*T + C + tiisg] = ms[tiisg]; - } + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } //threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2733,7 +2688,6 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short cc = 0; cc < C; ++cc) { device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); - half vsum[Q]; for (short i = tiisg; i < D4; i += NW) { for (short j = 0; j < Q; ++j) { lo[j][i] += pv4[i]*ss[j*T + cc]; From e51778de5e67d867cc39802536cb995b245b73e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 16:10:15 +0300 Subject: [PATCH 5/8] metal : switch to parallel reduce --- ggml-metal.m | 16 +++- ggml-metal.metal | 196 +++++++++++++++++++++++++---------------------- 2 files changed, 119 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0f405b1126e2b..6106bc7e360b2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2615,13 +2615,23 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. 
warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsg = 8; - const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + // require power of 2 + //{ + // int64_t nsgm = 1; + // while (nsgm < nsg) { + // nsgm *= 2; + // } + // GGML_ASSERT(nsg == nsgm); + //} + + const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } diff --git a/ggml-metal.metal b/ggml-metal.metal index 533b9fef60816..404bd16e08d54 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2457,6 +2457,8 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; +#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. + template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( device const char * q, @@ -2500,6 +2502,7 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iq1 = tgpig[0]*Q; const short D4 = D/4; + const short D8 = D/8; const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) @@ -2510,6 +2513,7 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) half4 lo[Q][D4]; @@ -2545,7 +2549,7 @@ kernel void kernel_flash_attn_ext_vec_f16( { half S[Q] = { [0 ... Q-1] = 0.0h }; - half M[Q] = { [0 ... Q-1] = -INFINITY }; + half M[Q] = { [0 ... 
Q-1] = -HALF_MAX_HALF }; // assume K and V are same shape const short ne22 = ne12; @@ -2571,21 +2575,21 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - half4 mq[Q][D4]; + simdgroup_half8x8 mq[Q][D8]; for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); - mq[j][i] = sq4[j*T4 + i]; + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); } } // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + //device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - //simdgroup_half8x8 mscale(scale); - half mscale(scale); + simdgroup_half8x8 mscale(scale); + //half mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2595,55 +2599,83 @@ kernel void kernel_flash_attn_ext_vec_f16( break; } + // Q*K^T + //{ + // for (short cc = 0; cc < C/4; ++cc) { + // half4 mqk[Q]; + // for (short j = 0; j < Q; ++j) { + // mqk[j] = 0.0h; + // } + + // device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + // for (short i = tiisg; i < D4; i += NW) { + // half4x4 mk; + // mk[0] = pk4[i + 0*(nb11/8)]; + // mk[1] = pk4[i + 1*(nb11/8)]; + // mk[2] = pk4[i + 2*(nb11/8)]; + // mk[3] = pk4[i + 3*(nb11/8)]; + + // for (short j = 0; j < Q; ++j) { + // mqk[j] += mq[j][i] * mk; + // } + // } + + // // reduce the results from the threads in the simdgroup + // simdgroup_barrier(mem_flags::mem_none); + + // for (short i = NW/2; i > 0; i /= 2) { + // if (tiisg < i) { + // for (short j = 0; j < Q; ++j) { + // mqk[j] += simd_shuffle_down(mqk[j], i); + // } + // } + + // simdgroup_barrier(mem_flags::mem_none); + // } + + // // mqk = mqk*scale + mask + // if (tiisg == 0) { + // for (short j = 0; j < Q; ++j) { + // half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; + // mqk[j] = mqk[j]*mscale + mm; + + // ss4[j*T4 + cc] = mqk[j]; + // } + // } + // } + //} + // Q*K^T { - for (short cc = 0; cc < C/4; ++cc) { - half4 mqk[Q]; + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk[Q]; for (short j = 0; j < Q; ++j) { - mqk[j] = 0.0h; + mqk[j] = make_filled_simdgroup_matrix(0.h); } - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (short i = tiisg; i < D4; i += NW) { - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose for (short j = 0; j < Q; ++j) { - mqk[j] += mq[j][i] * mk; - } - } - - // reduce the results from the threads in the simdgroup - simdgroup_barrier(mem_flags::mem_none); - - for (short i = NW/2; i > 0; i /= 2) { - if (tiisg < i) { - for (short j = 0; j < Q; ++j) { - mqk[j] += simd_shuffle_down(mqk[j], i); - } + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); } - - simdgroup_barrier(mem_flags::mem_none); } // mqk = mqk*scale + mask - if (tiisg == 0) { - for (short j = 0; j < Q; ++j) { - half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; - mqk[j] = 
mqk[j]*mscale + mm; + for (short j = 0; j < Q; ++j) { + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - ss4[j*T4 + cc] = mqk[j]; - } + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); } } } - simdgroup_barrier(mem_flags::mem_threadgroup); - // online softmax half ms[Q]; @@ -2655,8 +2687,8 @@ kernel void kernel_flash_attn_ext_vec_f16( M[j] = simd_max(max(M[j], s)); - ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = exp(m - M[j]); + const half vs = exp(s - M[j]); S[j] = S[j]*ms[j] + simd_sum(vs); @@ -2706,75 +2738,59 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - half S = { 0.0h }; - half M = { -INFINITY }; - - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq - if (sgitg == sg) { - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - sq4[j*T4 + i] = lo[j][i]; - } - } + // store results to shared memory + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < D4; i += NW) { + sr4[i] = lo[j][i]; } + } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + if (tiisg == 0) { + for (short j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + r*SH + 0]; - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + r*SH + 1]; - M = max(M0, M1); + const half M = max(M0, M1); - const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); - const half ms1 = M1 == -INFINITY ? 
0.0h : exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; - if (tiisg == 0) { ss[j*T + 0] = S; ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + r*SH] = ms1; } } + } - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg < r) { for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - half4 t = sq4[j*T4 + i]; - half ms0 = ss[j*T + C + j]; - half ms1 = ss[j*T + C + j + sg*SH]; + const half ms0 = ss[j*T + C + j]; + const half ms1 = ss[j*T + C + j + r*SH]; - lo[j][i] = lo[j][i]*ms0 + t*ms1; + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short i = tiisg; i < D4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; } } } - } - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); - sq4[j*T4 + i] = lo[j][i]; - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); } - threadgroup_barrier(mem_flags::mem_threadgroup); - device float4 * dst4 = (device float4 *) dst; // final rescale with 1/S and store to global memory @@ -2783,7 +2799,7 @@ kernel void kernel_flash_attn_ext_vec_f16( const half S = ss[j*T + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S; } } } From c4dff1ec910a2057a3c17b170028cb9c1d418865 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 16:24:10 +0300 Subject: [PATCH 6/8] metal : reduce registers --- ggml-metal.m | 14 +------------- ggml-metal.metal | 18 +++++++++--------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6106bc7e360b2..07535828dc70c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -179,10 +179,6 @@ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, @@ -625,10 +621,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); 
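Note: the parallel reduce introduced in PATCH 5 above merges per-simdgroup partial attention states in log2(nsg) steps. Each simdgroup carries a running max M, a running softmax denominator S, and an unnormalized output O; because the merge rule is associative, pairs can be combined in any order. A scalar float sketch of that rule (the kernel does the same in half precision, with O spread across D/4 half4 values in sr4):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    struct State { float M, S, O; }; // running max, softmax denom, output

    // mirrors: M = max(M0,M1); ms0 = exp(M0-M); ms1 = exp(M1-M);
    //          S = S0*ms0 + S1*ms1; O = O0*ms0 + O1*ms1
    static State merge(State a, State b) {
        const float M   = std::max(a.M, b.M);
        const float ms0 = std::exp(a.M - M);
        const float ms1 = std::exp(b.M - M);
        return { M, a.S*ms0 + b.S*ms1, a.O*ms0 + b.O*ms1 };
    }

    int main() {
        State s0 = { 1.0f, 2.0f, 0.50f }; // made-up partials, simdgroup 0
        State s1 = { 3.0f, 1.5f, 0.25f }; // made-up partials, simdgroup 1
        State m  = merge(s0, s1);
        printf("M=%.3f S=%.3f O=%.3f\n", m.M, m.S, m.O);
        return 0;
    }
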
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); @@ -2521,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - if (ne01 > 1) { + if (ne01 > 1 || (ne00%128 != 0)) { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; @@ -2538,10 +2530,6 @@ static enum ggml_status ggml_metal_graph_compute( } } else { switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; default: diff --git a/ggml-metal.metal b/ggml-metal.metal index 404bd16e08d54..7709865c94d94 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2516,7 +2516,7 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[Q][D4]; + half4 lo[Q][D4/NW]; // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { @@ -2534,7 +2534,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // zero out lo for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < D4; i += NW) { - lo[j][i] = 0.0h; + lo[j][i/NW] = 0.0h; } } @@ -2711,7 +2711,7 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short i = tiisg; i < D4; i += NW) { //simdgroup_multiply(lo[j][i], mm, lo[j][i]); - lo[j][i] = lo[j][i]*mm; + lo[j][i/NW] = lo[j][i/NW]*mm; } } @@ -2722,7 +2722,7 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short i = tiisg; i < D4; i += NW) { for (short j = 0; j < Q; ++j) { - lo[j][i] += pv4[i]*ss[j*T + cc]; + lo[j][i/NW] += pv4[i]*ss[j*T + cc]; } } } @@ -2743,7 +2743,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // store results to shared memory for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < D4; i += NW) { - sr4[i] = lo[j][i]; + sr4[i] = lo[j][i/NW]; } } @@ -2805,10 +2805,10 @@ kernel void kernel_flash_attn_ext_vec_f16( } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<64, 1, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<80, 1, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<96, 1, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<112, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>; +template 
[[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>; From f8d709f01a91752f7c04cd8fd0906e0e69d78c7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 16:29:29 +0300 Subject: [PATCH 7/8] metal : simplify --- ggml-metal.m | 11 ++-- ggml-metal.metal | 130 +++++++++++++++++------------------------------ 2 files changed, 54 insertions(+), 87 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 07535828dc70c..204ccea1b2795 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2573,7 +2573,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // half8x8 kernel - if (ne01 > 1) { + if (ne01 > 1 || (ne00%128 != 0)) { const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! @@ -2603,8 +2603,13 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - //const int64_t nsg = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - const int64_t nsg = 8; + const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; // require power of 2 //{ diff --git a/ggml-metal.metal b/ggml-metal.metal index 7709865c94d94..63a5a175d446e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2575,21 +2575,20 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[Q][D8]; + half4 mq[Q][D4]; for (short j = 0; j < Q; ++j) { - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[j][i] = sq4[j*T4 + i]; } } // pointer to the mask - //device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); - device const half * mp = (device const half *) (mask + iq1*nb31); + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); - //half mscale(scale); + half mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2599,79 +2598,45 @@ kernel void kernel_flash_attn_ext_vec_f16( break; } - // Q*K^T - //{ - // for (short cc = 0; cc < C/4; ++cc) { - // half4 mqk[Q]; - // for (short j = 0; j < Q; ++j) { - // mqk[j] = 0.0h; - // } - - // device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - // for (short i = tiisg; i < D4; i += NW) { - // half4x4 mk; - // mk[0] = pk4[i + 0*(nb11/8)]; - // mk[1] = pk4[i + 1*(nb11/8)]; - // mk[2] = pk4[i + 2*(nb11/8)]; - // mk[3] = pk4[i + 3*(nb11/8)]; - - // for (short j = 0; j < Q; ++j) { - // mqk[j] += mq[j][i] * mk; - // } - // } - - // // reduce the results from the threads in 
the simdgroup - // simdgroup_barrier(mem_flags::mem_none); - - // for (short i = NW/2; i > 0; i /= 2) { - // if (tiisg < i) { - // for (short j = 0; j < Q; ++j) { - // mqk[j] += simd_shuffle_down(mqk[j], i); - // } - // } - - // simdgroup_barrier(mem_flags::mem_none); - // } - - // // mqk = mqk*scale + mask - // if (tiisg == 0) { - // for (short j = 0; j < Q; ++j) { - // half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; - // mqk[j] = mqk[j]*mscale + mm; - - // ss4[j*T4 + cc] = mqk[j]; - // } - // } - // } - //} - // Q*K^T { - for (short cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk[Q]; - for (short j = 0; j < Q; ++j) { - mqk[j] = make_filled_simdgroup_matrix(0.h); - } + for (short cc = 0; cc < C/4; ++cc) { + half4 mqk[Q] = { [0 ... Q-1] = 0.0h }; - device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; for (short j = 0; j < Q; ++j) { - simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + mqk[j] += mq[j][i] * mk; } } - // mqk = mqk*scale + mask + // reduce the results from the threads in the simdgroup for (short j = 0; j < Q; ++j) { - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); + mqk[j] += simd_shuffle_down(mqk[j], 16); + mqk[j] += simd_shuffle_down(mqk[j], 8); + mqk[j] += simd_shuffle_down(mqk[j], 4); + mqk[j] += simd_shuffle_down(mqk[j], 2); + mqk[j] += simd_shuffle_down(mqk[j], 1); + } - simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + // mqk = mqk*scale + mask + if (tiisg == 0) { + for (short j = 0; j < Q; ++j) { + half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc]; + mqk[j] = mqk[j]*mscale + mm; + + ss4[j*T4 + cc] = mqk[j]; + } } } } @@ -2701,26 +2666,26 @@ kernel void kernel_flash_attn_ext_vec_f16( ss[tiisg*T + C + tiisg] = ms[tiisg]; } - //threadgroup_barrier(mem_flags::mem_threadgroup); - // O = diag(ms)*O for (short j = 0; j < Q; ++j) { - //simdgroup_half8x8 mm; - //simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); half mm(ss[j*T + C + j]); - for (short i = tiisg; i < D4; i += NW) { - //simdgroup_multiply(lo[j][i], mm, lo[j][i]); +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; lo[j][i/NW] = lo[j][i/NW]*mm; } } // O = O + (Q*K^T)*V { +#pragma unroll for (short cc = 0; cc < C; ++cc) { device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23)); - for (short i = tiisg; i < D4; i += NW) { +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; for (short j = 0; j < Q; ++j) { lo[j][i/NW] += pv4[i]*ss[j*T + cc]; } @@ -2738,15 +2703,16 @@ kernel void kernel_flash_attn_ext_vec_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); - // store results to shared memory for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < D4; i += NW) { - sr4[i] = lo[j][i/NW]; + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[j][ii/NW]; } } + 
threadgroup_barrier(mem_flags::mem_threadgroup); + // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { @@ -2805,10 +2771,6 @@ kernel void kernel_flash_attn_ext_vec_f16( } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>; From b57af0c9dd7e92c2767b165a035fb434c5465ca3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 17:47:01 +0300 Subject: [PATCH 8/8] metal : initial FA vec kernel --- ggml-metal.m | 2 +- ggml-metal.metal | 197 ++++++++++++++++++----------------------------- 2 files changed, 78 insertions(+), 121 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 204ccea1b2795..2680cf21cc42c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2603,7 +2603,7 @@ static enum ggml_status ggml_metal_graph_compute( // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); int64_t nsg = 1; while (nsg <= nsgt) { diff --git a/ggml-metal.metal b/ggml-metal.metal index 63a5a175d446e..ca0f57a965f39 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2459,7 +2459,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f #define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. 
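Note: the HALF_MAX_HALF sentinel above exists because PATCH 5 dropped the "m == -INFINITY ? 0.0h : ..." guards and now evaluates exp(m - M) unconditionally; with M initialized to -INFINITY, the very first update would compute (-inf) - (-inf), which is NaN. A short float demonstration of the hazard (the kernel uses half, where the same arithmetic rule applies):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float inf      = INFINITY;
        const float sentinel = -65504.0f/2; // -HALF_MAX_HALF

        printf("exp(-inf - (-inf)): %f\n", std::exp(-inf - (-inf)));         // nan
        printf("exp(sent - sent):   %f\n", std::exp(sentinel - sentinel));   // 1.000000
        return 0;
    }

A fully masked-out score s = -INFINITY still works with the finite sentinel: M stays at the sentinel and exp(s - M) underflows cleanly to 0.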
---
 ggml-metal.m | 2 +-
 ggml-metal.metal | 197 ++++++++++++++++++-----------------------
 2 files changed, 78 insertions(+), 121 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 204ccea1b2795..2680cf21cc42c 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -2603,7 +2603,7 @@ static enum ggml_status ggml_metal_graph_compute(

 // simdgroups per threadgroup (a.k.a. warps)
 // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
- const int64_t nsgt = MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+ const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));

 int64_t nsg = 1;
 while (nsg <= nsgt) {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 63a5a175d446e..ca0f57a965f39 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2459,7 +2459,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f

 #define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.

-template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+template<int64_t D, int64_t C> // head size, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
 device const char * q,
 device const char * k,
@@ -2499,12 +2499,12 @@ kernel void kernel_flash_attn_ext_vec_f16(
 const short iq3 = tgpig[2];
 const short iq2 = tgpig[1];
- const short iq1 = tgpig[0]*Q;
+ const short iq1 = tgpig[0];

 const short D4 = D/4;
 const short D8 = D/8;
 const short NW = N_SIMDWIDTH;
- const short SH = (C + Q); // shared memory per simdgroup in (half)
+ const short SH = (C + 1); // shared memory per simdgroup in (half)

 const short T = D + nsg*SH; // shared memory size per query in (half)
 const short T4 = T/4; // shared memory size per query in (half4)
@@ -2513,43 +2513,37 @@ kernel void kernel_flash_attn_ext_vec_f16(
 threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
 threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
- threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results

 // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- half4 lo[Q][D4/NW];
+ half4 lo[D4/NW];

 // load heads from Q to shared memory
- for (short j = sgitg; j < Q; j += nsg) {
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));

- for (short i = tiisg; i < D4; i += NW) {
- if (iq1 + j < ne01) {
- sq4[j*T4 + i] = (half4) q4[i];
- } else {
- sq4[j*T4 + i] = 0.0h;
- }
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 < ne01) {
+ sq4[i] = (half4) q4[i];
+ } else {
+ sq4[i] = 0.0h;
 }
 }

 // zero out lo
- for (short j = 0; j < Q; ++j) {
- for (short i = tiisg; i < D4; i += NW) {
- lo[j][i/NW] = 0.0h;
- }
+ for (short i = tiisg; i < D4; i += NW) {
+ lo[i/NW] = 0.0h;
 }

 // zero out shared memory SH
- for (short j = 0; j < Q; ++j) {
- for (short i = tiisg; i < SH/4; i += NW) {
- ss4[j*T4 + i] = 0.0h;
- }
+ for (short i = tiisg; i < SH/4; i += NW) {
+ ss4[i] = 0.0h;
 }

 threadgroup_barrier(mem_flags::mem_threadgroup);

 {
- half S[Q] = { [0 ... Q-1] = 0.0h };
- half M[Q] = { [0 ... Q-1] = -HALF_MAX_HALF };
+ half S = { 0.0h };
+ half M = { -HALF_MAX_HALF };

 // assume K and V are same shape
 const short ne22 = ne12;
@@ -2575,21 +2569,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
 const short iv3 = iq3 / rv3;

 // load the queries from shared memory into local memory
- half4 mq[Q][D4];
+ half4 mq[D4];

- for (short j = 0; j < Q; ++j) {
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- mq[j][i] = sq4[j*T4 + i];
- }
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ mq[i] = sq4[i];
 }

 // pointer to the mask
 device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);

- // prepare diagonal scale matrix
- half mscale(scale);
-
 // loop over the KV cache
 // each simdgroup handles blocks of Q rows and C columns
 for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2600,8 +2589,9 @@ kernel void kernel_flash_attn_ext_vec_f16(

 // Q*K^T
 {
+#pragma unroll
 for (short cc = 0; cc < C/4; ++cc) {
- half4 mqk[Q] = { [0 ... Q-1] = 0.0h };
+ half4 mqk = { 0.0h };

 device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));

@@ -2615,100 +2605,81 @@ kernel void kernel_flash_attn_ext_vec_f16(
 mk[2] = pk4[i + 2*(nb11/8)];
 mk[3] = pk4[i + 3*(nb11/8)];

- for (short j = 0; j < Q; ++j) {
- mqk[j] += mq[j][i] * mk;
- }
+ mqk += mq[i] * mk;
 }

 // reduce the results from the threads in the simdgroup
- for (short j = 0; j < Q; ++j) {
- mqk[j] += simd_shuffle_down(mqk[j], 16);
- mqk[j] += simd_shuffle_down(mqk[j], 8);
- mqk[j] += simd_shuffle_down(mqk[j], 4);
- mqk[j] += simd_shuffle_down(mqk[j], 2);
- mqk[j] += simd_shuffle_down(mqk[j], 1);
- }
+ mqk += simd_shuffle_down(mqk, 16);
+ mqk += simd_shuffle_down(mqk, 8);
+ mqk += simd_shuffle_down(mqk, 4);
+ mqk += simd_shuffle_down(mqk, 2);
+ mqk += simd_shuffle_down(mqk, 1);

 // mqk = mqk*scale + mask
 if (tiisg == 0) {
- for (short j = 0; j < Q; ++j) {
- half4 mm = mp4[(j*(nb31/sizeof(half)) + ic)/4 + cc];
- mqk[j] = mqk[j]*mscale + mm;
+ half4 mm = mp4[ic/4 + cc];
+ mqk = mqk*scale + mm;

- ss4[j*T4 + cc] = mqk[j];
- }
+ ss4[cc] = mqk;
 }
 }
 }
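 // (one scalar softmax state per simdgroup now: lane p of the simdgroup holds
 // score p of the current chunk of C scores, M tracks the running maximum,
 // the previous contributions are rescaled by ms = exp(m - M), and
 // vs = exp(s - M) is accumulated into the row sum S)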
 // online softmax
- half ms[Q];
-
- for (short j = 0; j < Q; ++j) {
+ {
 const short p = tiisg;

- const half m = M[j];
- const half s = ss[j*T + p];
+ const half m = M;
+ const half s = ss[p];

- M[j] = simd_max(max(M[j], s));
+ M = simd_max(max(M, s));

- ms[j] = exp(m - M[j]);
- const half vs = exp(s - M[j]);
+ const half ms = exp(m - M);
+ const half vs = exp(s - M);

- S[j] = S[j]*ms[j] + simd_sum(vs);
+ S = S*ms + simd_sum(vs);

 // the P matrix from the paper (Q rows, C columns)
- ss[j*T + p] = vs;
- }
-
- // create a QxQ diagonal matrix for rescaling the output
- if (tiisg < Q) {
- ss[tiisg*T + C + tiisg] = ms[tiisg];
- }
-
- // O = diag(ms)*O
- for (short j = 0; j < Q; ++j) {
- half mm(ss[j*T + C + j]);
+ ss[p] = vs;

+ // O = diag(ms)*O
 #pragma unroll
 for (short ii = 0; ii < D4; ii += NW) {
 const short i = ii + tiisg;
- lo[j][i/NW] = lo[j][i/NW]*mm;
+ lo[i/NW] *= ms;
 }
 }

 // O = O + (Q*K^T)*V
 {
 #pragma unroll
- for (short cc = 0; cc < C; ++cc) {
- device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));
+ for (short cc = 0; cc < C/4; ++cc) {
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));

 #pragma unroll
 for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- for (short j = 0; j < Q; ++j) {
- lo[j][i/NW] += pv4[i]*ss[j*T + cc];
- }
+ const short i = ii + tiisg;
+ lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+ lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+ lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+ lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
 }
 }
 }
+ }

 // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- for (short j = 0; j < Q; ++j) {
- if (tiisg == 0) {
- ss[j*T + 0] = S[j];
- ss[j*T + 1] = M[j];
- }
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
 }
 }

 // store results to shared memory
- for (short j = 0; j < Q; ++j) {
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- sr4[i] = lo[j][ii/NW];
- }
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = lo[ii/NW];
 }

 threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -2716,41 +2687,28 @@ kernel void kernel_flash_attn_ext_vec_f16(
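 // (log2(nsg) merge steps: each active simdgroup combines its (S, M) with the
 // state of simdgroup sgitg + r as M = max(M0, M1), S = S0*exp(M0 - M) +
 // S1*exp(M1 - M), and blends the two partial outputs in sr4 with the same
 // exp factors)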
 // parallel reduce
 for (short r = nsg/2; r > 0; r >>= 1) {
 if (sgitg < r) {
- if (tiisg == 0) {
- for (short j = 0; j < Q; ++j) {
- const half S0 = ss[j*T + 0];
- const half S1 = ss[j*T + r*SH + 0];
+ const half S0 = ss[ 0];
+ const half S1 = ss[r*SH + 0];

- const half M0 = ss[j*T + 1];
- const half M1 = ss[j*T + r*SH + 1];
+ const half M0 = ss[ 1];
+ const half M1 = ss[r*SH + 1];

- const half M = max(M0, M1);
+ const half M = max(M0, M1);

- const half ms0 = exp(M0 - M);
- const half ms1 = exp(M1 - M);
+ const half ms0 = exp(M0 - M);
+ const half ms1 = exp(M1 - M);

- const half S = S0*ms0 + S1*ms1;
-
- ss[j*T + 0] = S;
- ss[j*T + 1] = M;
+ const half S = S0*ms0 + S1*ms1;

- ss[j*T + C + j ] = ms0;
- ss[j*T + C + j + r*SH] = ms1;
- }
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
 }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);

- if (sgitg < r) {
- for (short j = 0; j < Q; ++j) {
- const half ms0 = ss[j*T + C + j];
- const half ms1 = ss[j*T + C + j + r*SH];
-
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- for (short i = tiisg; i < D4; i += NW) {
- sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
- }
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
 }
 }

@@ -2761,18 +2719,17 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }

 // final rescale with 1/S and store to global memory
 if (sgitg == 0) {
- for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
- const half S = ss[j*T + 0];
+ const half S = ss[0];

- for (short i = tiisg; i < D4; i += NW) {
- dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sr4[i]/S;
- }
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
 }
 }
 }

-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;

 kernel void kernel_cpy_f16_f16(
 device const half * src0,