diff --git a/cpp/kernels/fmha_v2/fmha_test.py b/cpp/kernels/fmha_v2/fmha_test.py
index 21486b00ea6..bd743d829a0 100644
--- a/cpp/kernels/fmha_v2/fmha_test.py
+++ b/cpp/kernels/fmha_v2/fmha_test.py
@@ -155,50 +155,41 @@ def test_trtllm_sage_attention_fmha(d, s):
 @pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                          ids=["bf16", "e4m3", "e4m3-bf16"])
 @pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
-@pytest.mark.parametrize(
-    'input_layout', ["", "-paged-kv", "-contiguous-q-kv", "-separate-q-k-v"],
-    ids=["packed-qkv", "paged-kv", "q-contiguous-kv", "separate-q-k-v"])
-def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
+def test_trtllm_context_mla_attention_fmha(dtype, s):
+    sm_version = getSMVersion()
+    if sm_version < 90:
+        pytest.skip("MLA kernels are only tested on sm90 and above currently.")
+
     # use higher error tolerance for bf16 and s = 4096.
     epsilon = ''
     if dtype == "-bf16" and s == 4096:
         epsilon += ' -epsilon 0.03'
 
-    sm_version = getSMVersion()
-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
-        pytest.skip("FP8 MLAs only supported on sm89 currently.")
+    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
+        pytest.skip("FP8 MLAs are only supported on sm120 currently.")
 
-    # Context phase kernels.
+    # Context phase kernels, always use separate-q-k-v layout.
     subprocess.run(
-        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-            -force-non-warp-specialization -causal-mask {epsilon}",
+        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
+        f"-causal-mask {epsilon} -separate-q-k-v",
         shell=True,
         check=True)
 
-    if sm_version == 90:
-        # Now only hopper-style supports separate-q-k-v
+    # For chunked prefill, we need to enable -save-softmax (dtype: bf16, layout: separate-q-k-v).
+    # Currently fp8 kernel doesn't support saving softmax.
+    if dtype == "-bf16":
+        # padding mask
         subprocess.run(
-            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-                -causal-mask {epsilon} {input_layout}",
+            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
+            f"{epsilon} -separate-q-k-v -save-softmax",
+            shell=True,
+            check=True)
+        # causal mask
+        subprocess.run(
+            f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
+            f"-causal-mask {epsilon} -separate-q-k-v -save-softmax",
             shell=True,
             check=True)
-
-        # For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v).
-        if dtype == "-bf16" and input_layout in [
-                "-paged-kv", "-separate-q-k-v"
-        ]:
-            # padding mask
-            subprocess.run(
-                f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-                    {epsilon} {input_layout} -save-softmax",
-                shell=True,
-                check=True)
-            # causal mask
-            subprocess.run(
-                f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-                    -causal-mask {epsilon} {input_layout} -save-softmax",
-                shell=True,
-                check=True)
 
 
 @pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
@@ -210,14 +201,17 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
                             "num-grouped-heads-64", "num-grouped-heads-128"
                         ])
 def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
+    sm_version = getSMVersion()
+    if sm_version < 90:
+        pytest.skip("MLA kernels are only tested on sm90 and above currently.")
+
     # use higher error tolerance for bf16 and s = 4096.
     epsilon = ''
     if dtype == "-bf16" and s == 4096:
         epsilon += ' -epsilon 0.03'
 
-    sm_version = getSMVersion()
-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
-        pytest.skip("FP8 MLAs only supported on sm89 currently.")
+    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
+        pytest.skip("FP8 MLAs are only supported on sm120 currently.")
 
     # Generation phase kernels.
     subprocess.run(
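The three fmha.exe invocations in the rewritten context-MLA test above share every flag except the mask and -save-softmax switches. As a minimal sketch (not part of the patch), the command construction could be factored into one helper; run_context_mla_case and its defaults are hypothetical names introduced only for illustration:

# Hypothetical helper (illustration only), factoring out the shared fmha.exe
# flags used by test_trtllm_context_mla_attention_fmha above.
import subprocess


def run_context_mla_case(s, dtype, epsilon='', extra_flags=''):
    # Base command mirrors the test: MLA context phase, d=192 / dv=128,
    # always using the separate-q-k-v input layout.
    cmd = (f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 "
           f"-d 192 -dv 128 {dtype} {epsilon} -separate-q-k-v {extra_flags}")
    subprocess.run(cmd, shell=True, check=True)


# Usage mirroring the three cases exercised by the test:
#   run_context_mla_case(4096, "-bf16", " -epsilon 0.03", "-causal-mask")
#   run_context_mla_case(4096, "-bf16", " -epsilon 0.03", "-save-softmax")
#   run_context_mla_case(4096, "-bf16", " -epsilon 0.03", "-causal-mask -save-softmax")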
diff --git a/cpp/kernels/fmha_v2/setup.py b/cpp/kernels/fmha_v2/setup.py
index 8434d4225df..220b7898a98 100644
--- a/cpp/kernels/fmha_v2/setup.py
+++ b/cpp/kernels/fmha_v2/setup.py
@@ -2075,6 +2075,8 @@ def get_kernel_code(kspec, kname, lname):
         kernel_traits += '_paged_kv_cache'
     elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
         kernel_traits += '_contiguous_kv_cache'
+    elif kspec.input_layout == InputLayout.SEPARATE_Q_K_V:
+        kernel_traits += '_q_k_v'
 
     flags = 0
     if kspec.ldgsts_q:
@@ -3183,7 +3185,7 @@ def get_cubin_header(kernel_traits, specs_names):
         attention_mask_type_value = attention_mask_type.value
 
         # Attention input layout:
-        # packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2).
+        # packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2), separate_q_k_v (3).
         attention_input_layout = InputLayout.PACKED_QKV
         if '_q_kv' in kname:
             attention_input_layout = InputLayout.CONTIGUOUS_Q_KV
@@ -3652,12 +3654,9 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
                 if alibi and enable_attn_logit_softcapping:
                     continue
                 # for normal attention, we only need contiguous kv as input layout when returning softmax.
-                skip_combination = return_softmax and (input_layout
-                                                       != InputLayout.CONTIGUOUS_Q_KV)
-                # for context mla, we need paged kv or separate qkv as input layout when returning softmax.
-                skip_mla_combination = return_softmax and (
-                    input_layout != InputLayout.Q_PAGED_KV
-                    and input_layout != InputLayout.SEPARATE_Q_K_V)
+                skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
+                # for context mla, we need separate qkv as input layout when returning softmax.
+                skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
                 if not skip_combination:
                     # only specify
                     specs.append(
@@ -4702,9 +4701,16 @@ def enumerate_hmma_paged_kv_flash_kernels(specs, sm=80, dtype='fp16'):
 
 
 def enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16', head_size_v=0):
-    for (input_layout, enable_attn_logit_softcapping) in \
-        product([InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, InputLayout.Q_PAGED_KV], \
-        [False, True]):
+    input_layouts = [
+        InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
+        InputLayout.Q_PAGED_KV
+    ]
+    # Deepseek MLA (context 192/128 separate-q-k-v)
+    if head_size_v == 128:
+        input_layouts.append(InputLayout.SEPARATE_Q_K_V)
+    for (input_layout,
+         enable_attn_logit_softcapping) in product(input_layouts,
+                                                   [False, True]):
         enumerate_hmma_flash_kernels_base(specs, sm, dtype, input_layout,
                                           enable_attn_logit_softcapping,
                                           head_size_v)
@@ -5080,7 +5086,7 @@ def enumerate_qmma_flash_kernels(specs,
     ]
     input_layouts = [
         InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
-        InputLayout.Q_PAGED_KV
+        InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V
     ]
     for (head_size_params, (q_loop_step, kv_loop_step), tiled), input_layout in \
             product(params_q_kv_step, input_layouts):
@@ -5094,6 +5100,9 @@
         # skip if head_size is not in head_sizes
         if head_sizes is not None and head_size not in head_sizes:
             continue
+        # skip if head_size_v is not 128 for separate-q-k-v
+        if input_layout == InputLayout.SEPARATE_Q_K_V and head_size_v != 128:
+            continue
         specs.append(
             kernel_spec(sm=sm,
                         sm_mma=89,
@@ -6354,28 +6363,30 @@ def enumerate_kernels():
                 and kspec.version == 2
                 and kspec.cross_mha == False
                 and kspec.flash_attention == False)
-            # Deepseek MLA (192/128 packed + 576/512 paged)
-            or (kspec.sm in [80, 86, 89, 90, 100, 120]
+            # Deepseek MLA (generation 576/512 paged)
+            or (kspec.sm in [90, 100, 120]
                 and kspec.dtype in ['bf16', 'e4m3_fp32']
-                and (((kspec.head_size, kspec.head_size_v) == (192, 128) and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV])
-                     or ((kspec.head_size, kspec.head_size_v) == (576, 512) and kspec.input_layout == InputLayout.Q_PAGED_KV))
+                and kspec.head_size == 576
+                and kspec.head_size_v == 512
+                and kspec.input_layout == InputLayout.Q_PAGED_KV
                 and kspec.sage_block_sizes is None
                 and kspec.version == 2
                 and kspec.cross_mha == False
                 and kspec.flash_attention == True
                 and kspec.warp_specialization == False
                 and kspec.tiled == True)
-            # Deepseek MLA (hopper-style context 192/128)
-            or (kspec.sm == 90
-                and kspec.dtype == 'bf16'
+            # Deepseek MLA (context 192/128 separate-q-k-v)
+            or (kspec.sm in [90, 100, 120]
+                and kspec.dtype in ['bf16', 'e4m3_fp32']
                 and kspec.head_size == 192
                 and kspec.head_size_v == 128
+                and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
                 and kspec.sage_block_sizes is None
                 and kspec.version == 2
                 and kspec.cross_mha == False
                 and kspec.flash_attention == True
-                and kspec.warp_specialization == True
-                and kspec.alibi == False
+                and ((kspec.warp_specialization == True and kspec.alibi == False)  # sm90
+                     or (kspec.warp_specialization == False and kspec.tiled == True))  # non-sm90
                 and kspec.enable_attn_logit_softcapping == False)
             # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
             or (kspec.sm == 90
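The enumerate_kernels() filter above keeps two Deepseek MLA families: 576/512 paged-KV generation kernels, and 192/128 separate-q-k-v context kernels that are warp-specialized on sm90 and tiled non-warp-specialized on sm100/sm120. A standalone sketch of the context-kernel predicate follows; MlaSpec and the string layout value are hypothetical stand-ins for the repo's kernel_spec entries and InputLayout enum:

# Illustration only: a self-contained predicate mirroring the
# "Deepseek MLA (context 192/128 separate-q-k-v)" branch added above.
from dataclasses import dataclass


@dataclass
class MlaSpec:
    sm: int
    dtype: str
    head_size: int
    head_size_v: int
    input_layout: str  # stand-in for InputLayout.SEPARATE_Q_K_V etc.
    warp_specialization: bool
    tiled: bool
    alibi: bool


def is_mla_context_spec(kspec: MlaSpec) -> bool:
    # Warp-specialized kernels are kept on sm90, tiled non-warp-specialized
    # kernels on sm100/sm120; both require the 192/128 separate-q-k-v layout.
    return (kspec.sm in (90, 100, 120)
            and kspec.dtype in ('bf16', 'e4m3_fp32')
            and (kspec.head_size, kspec.head_size_v) == (192, 128)
            and kspec.input_layout == 'separate_q_k_v'
            and ((kspec.warp_specialization and not kspec.alibi)
                 or (not kspec.warp_specialization and kspec.tiled)))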
diff --git a/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h b/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
index 8f54c52b0ba..172131a22fd 100644
--- a/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
+++ b/cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
@@ -418,7 +418,7 @@ struct Gmem_tile_qkv
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// We expect the Q layout to be [B, S, H, D] with variable sequence length support.
+// We expect the Q/K/V layout to be [B, S, H, D] with variable sequence length support.
 template <
     // The instruction traits.
     typename Traits,
@@ -440,7 +440,7 @@ template <
     int NUM_MATS = 1,
     // Is sliding window attention used ?
    bool SLIDING_WINDOW_ATTENTION = false>
-struct Gmem_tile_q
+struct Gmem_tile_q_k_v
 {
 
     // The size of each LDG.
@@ -523,22 +523,38 @@ struct Gmem_tile_q
         USE_LDGSTS = USE_LDGSTS_
     };
 
-    // Ctor (keep qkv_offset for compatibility)
+    // Ctor
+    // qkv_offset: 0 for Q, 1 for K, 2 for V
     template <typename Block_info>
-    inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
+    inline __device__ Gmem_tile_q_k_v(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
         Block_info const& binfo, int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
-        : Gmem_tile_q(params, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes)
     {
-    }
-    // Ctor.
-    template <typename Block_info>
-    inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, Block_info const& binfo,
-        int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
-        : params_q_stride_in_bytes_(params.q_stride_in_bytes)
-        , actual_seqlen_(binfo.actual_q_seqlen)
-        , q_ptr_(reinterpret_cast<char*>(params.q_ptr))
-    {
+        int seq_offset = 0;
+        if (qkv_offset == 0)
+        {
+            // Q tensor
+            params_q_k_v_stride_in_bytes_ = params.q_stride_in_bytes;
+            q_k_v_ptr_ = reinterpret_cast<char*>(params.q_ptr);
+            actual_seqlen_ = binfo.actual_q_seqlen;
+            seq_offset = binfo.sum_s;
+        }
+        else if (qkv_offset == 1)
+        {
+            // K tensor
+            params_q_k_v_stride_in_bytes_ = params.k_stride_in_bytes;
+            q_k_v_ptr_ = reinterpret_cast<char*>(params.k_ptr);
+            actual_seqlen_ = binfo.actual_kv_seqlen;
+            seq_offset = binfo.sum_s_kv;
+        }
+        else if (qkv_offset == 2)
+        {
+            // V tensor
+            params_q_k_v_stride_in_bytes_ = params.v_stride_in_bytes;
+            q_k_v_ptr_ = reinterpret_cast<char*>(params.v_ptr);
+            actual_seqlen_ = binfo.actual_kv_seqlen;
+            seq_offset = binfo.sum_s_kv;
+        }
 
         // Compute the position in the sequence (within the CTA for the moment).
         int row = tidx / THREADS_PER_ROW;
@@ -550,17 +566,20 @@ struct Gmem_tile_q
         // Do not load/store if the thread is in the padded area
         col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG;
 
-        // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
-        // We won't consider past_q_length when loading from gmem_q.
-        int64_t row_offset = (int64_t) (row + cta_row_offset) * params_q_stride_in_bytes_;
-        // Add the block index. (sum_s * h + hidx).
-        int64_t idx = binfo.bidx;
+        // The row offset in the batched GEMM, including the sequence offset.
+        int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_;
+        // Add the head index.
+        int64_t idx = binfo.bidh;
 
         // Assemble the final pointer.
-        q_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
+        q_k_v_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
 
         // Take the CTA offset to modify the sequence length.
         actual_seqlen_ -= cta_row_offset;
+
+        // Set the initial seqlen and pointer in case of re-iterating.
+        actual_seqlen_init_ = actual_seqlen_;
+        q_k_v_ptr_init_ = q_k_v_ptr_;
     }
 
     // Store data to shared memory.
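The constructor above picks the stride, base pointer, sequence length, and cumulative sequence offset per tensor (qkv_offset 0/1/2 for Q/K/V), then forms the final address as (row + cta_row_offset + seq_offset) * stride + bidh * VALID_BYTES_PER_ROW + col_in_bytes. Below is a host-side Python model of that addressing math, assuming VALID_BYTES_PER_ROW is the valid head dimension in bytes; the concrete H/D/dtype numbers are illustrative only:

# Host-side model (illustration only) of the addressing math in the
# Gmem_tile_q_k_v constructor above.
def q_k_v_tile_base_offset(row, cta_row_offset, seq_offset, stride_in_bytes,
                           bidh, valid_bytes_per_row, col_in_bytes):
    # Same formula as the constructor: a row offset over the flattened [B*S]
    # rows, then the head index and the column offset within the head dim.
    row_offset = (row + cta_row_offset + seq_offset) * stride_in_bytes
    return row_offset + bidh * valid_bytes_per_row + col_in_bytes


# Example: a K tile over a bf16 [B, S, H, D] tensor with H=8, D=192, so the
# per-token stride is 8 * 192 * 2 bytes and the valid row width is 192 * 2 bytes.
print(q_k_v_tile_base_offset(row=3, cta_row_offset=0, seq_offset=1024,
                             stride_in_bytes=8 * 192 * 2, bidh=5,
                             valid_bytes_per_row=192 * 2, col_in_bytes=16))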
@@ -590,7 +609,7 @@ struct Gmem_tile_q
 #pragma unroll
         for (int ii = 0; ii < LDGS; ++ii)
         {
-            ptrs[ii] = q_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_stride_in_bytes_;
+            ptrs[ii] = q_k_v_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_k_v_stride_in_bytes_;
         }
 
         // Trigger LDGSTS or the LDGs.
@@ -598,10 +617,24 @@
         Ldgsts_helper::load(this, smem_tile, ptrs, preds);
     }
 
+    // Move the pointer to the next row location.
+    inline __device__ void move(int const steps = 1)
+    {
+        q_k_v_ptr_ += (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * steps;
+        actual_seqlen_ -= (int) ROWS * steps;
+    }
+
+    // Move the pointer to the next row location by the offset (not step).
+    inline __device__ void move_by_offset(int const offset)
+    {
+        q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) offset * params_q_k_v_stride_in_bytes_;
+        actual_seqlen_ = actual_seqlen_init_ - (int) offset;
+    }
+
     // Move the pointer to the next column location
     inline __device__ void move_col()
     {
-        q_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
+        q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
         // Update col_in_bytes_ to ensure load predicates work
         col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG;
     }
@@ -609,15 +642,29 @@
     // Rewind the pointer back to previous column location
     inline __device__ void rewind_col(int const steps)
     {
-        q_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
+        q_k_v_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
         // Update col_in_bytes_ to ensure load predicates work
         col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps;
     }
 
-    // The stride between rows for the QKV matrice.
-    int64_t params_q_stride_in_bytes_;
+    // Move the pointer to the specified step.
+    inline __device__ void move_to(int const step)
+    {
+        q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * step;
+        actual_seqlen_ = actual_seqlen_init_ - (int) ROWS * step;
+    }
+
+    inline __device__ void reset()
+    {
+        q_k_v_ptr_ = q_k_v_ptr_init_;
+        actual_seqlen_ = actual_seqlen_init_;
+    }
+
+    // The stride between rows for the Q/K/V matrix.
+    int64_t params_q_k_v_stride_in_bytes_;
     // The pointer.
-    char* q_ptr_;
+    char* q_k_v_ptr_;
+    char* q_k_v_ptr_init_;
     // The register to store predicates.
     uint32_t preds_[PRED_REGS];
     // The fetch registers.
@@ -627,6 +674,7 @@
     int64_t col_in_bytes_;
     // The sequence length.
     int actual_seqlen_;
+    int actual_seqlen_init_;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/cpp/kernels/fmha_v2/src/fmha/kernel_traits.h b/cpp/kernels/fmha_v2/src/fmha/kernel_traits.h
index 872036ca3d2..065ce36869c 100644
--- a/cpp/kernels/fmha_v2/src/fmha/kernel_traits.h
+++ b/cpp/kernels/fmha_v2/src/fmha/kernel_traits.h
@@ -988,6 +988,40 @@ using Kernel_traits_v2 = Kernel_traits_