62 changes: 28 additions & 34 deletions cpp/kernels/fmha_v2/fmha_test.py
@@ -155,50 +155,41 @@ def test_trtllm_sage_attention_fmha(d, s):
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize(
'input_layout', ["", "-paged-kv", "-contiguous-q-kv", "-separate-q-k-v"],
ids=["packed-qkv", "paged-kv", "q-contiguous-kv", "separate-q-k-v"])
def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
def test_trtllm_context_mla_attention_fmha(dtype, s):
sm_version = getSMVersion()
if sm_version < 90:
pytest.skip("MLA kernels are only tested on sm90 and above currently.")

# use higher error tolerance for bf16 and s = 4096.
epsilon = ''
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'

sm_version = getSMVersion()
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
pytest.skip("FP8 MLAs only supported on sm89 currently.")
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
pytest.skip("FP8 MLAs are only supported on sm120 currently.")

# Context phase kernels.
# Context phase kernels, always use separate-q-k-v layout.
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-force-non-warp-specialization -causal-mask {epsilon}",
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"-causal-mask {epsilon} -separate-q-k-v",
shell=True,
check=True)

if sm_version == 90:
# Now only hopper-style supports separate-q-k-v
# For chunked prefill, we need to enable -save-softmax (dtype: bf16, layout: separate-q-k-v).
# Currently fp8 kernel doesn't support saving softmax.
if dtype == "-bf16":
# padding mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-causal-mask {epsilon} {input_layout}",
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"{epsilon} -separate-q-k-v -save-softmax",
shell=True,
check=True)
# causal mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"-causal-mask {epsilon} -separate-q-k-v -save-softmax",
shell=True,
check=True)

# For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v).
if dtype == "-bf16" and input_layout in [
"-paged-kv", "-separate-q-k-v"
]:
# padding mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
{epsilon} {input_layout} -save-softmax",
shell=True,
check=True)
# causal mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-causal-mask {epsilon} {input_layout} -save-softmax",
shell=True,
check=True)


@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
@@ -210,14 +201,17 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
"num-grouped-heads-64", "num-grouped-heads-128"
])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
sm_version = getSMVersion()
if sm_version < 90:
pytest.skip("MLA kernels are only tested on sm90 and above currently.")

# use higher error tolerance for bf16 and s = 4096.
epsilon = ''
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'

sm_version = getSMVersion()
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
pytest.skip("FP8 MLAs only supported on sm89 currently.")
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
pytest.skip("FP8 MLAs are only supported on sm120 currently.")

# Generation phase kernels.
subprocess.run(
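
For reference, the command pattern exercised by the context-MLA test above can be summarized in a small helper. This is a sketch only: run_context_mla and its keyword defaults are illustrative and not part of the change; the binary and flags (bin/fmha.exe, -separate-q-k-v, -save-softmax, -epsilon) come directly from the test.

# Sketch of the context-MLA invocation pattern; run_context_mla is a hypothetical helper.
import subprocess

def run_context_mla(dtype: str, s: int, causal: bool = False, save_softmax: bool = False) -> None:
    # bf16 at s=4096 uses a looser tolerance, mirroring the test above.
    epsilon = " -epsilon 0.03" if dtype == "-bf16" and s == 4096 else ""
    cmd = (
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype}"
        f"{' -causal-mask' if causal else ''}{epsilon} -separate-q-k-v"
        f"{' -save-softmax' if save_softmax else ''}"
    )
    subprocess.run(cmd, shell=True, check=True)

# Chunked-prefill variants (bf16 only, since the fp8 kernels do not save softmax):
# run_context_mla("-bf16", 4096, save_softmax=True)               # padding mask
# run_context_mla("-bf16", 4096, causal=True, save_softmax=True)  # causal mask
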
51 changes: 31 additions & 20 deletions cpp/kernels/fmha_v2/setup.py
@@ -2075,6 +2075,8 @@ def get_kernel_code(kspec, kname, lname):
kernel_traits += '_paged_kv_cache'
elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
kernel_traits += '_contiguous_kv_cache'
elif kspec.input_layout == InputLayout.SEPARATE_Q_K_V:
kernel_traits += '_q_k_v'

flags = 0
if kspec.ldgsts_q:
@@ -3183,7 +3185,7 @@ def get_cubin_header(kernel_traits, specs_names):
attention_mask_type_value = attention_mask_type.value

# Attention input layout:
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2).
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2), separate_q_k_v (3).
attention_input_layout = InputLayout.PACKED_QKV
if '_q_kv' in kname:
attention_input_layout = InputLayout.CONTIGUOUS_Q_KV
@@ -3652,12 +3654,9 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
if alibi and enable_attn_logit_softcapping:
continue
# for normal attention, we only need contiguous kv as input layout when returning softmax.
skip_combination = return_softmax and (input_layout
!= InputLayout.CONTIGUOUS_Q_KV)
# for context mla, we need paged kv or separate qkv as input layout when returning softmax.
skip_mla_combination = return_softmax and (
input_layout != InputLayout.Q_PAGED_KV
and input_layout != InputLayout.SEPARATE_Q_K_V)
skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
# for context mla, we need separate qkv as input layout when returning softmax.
skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
if not skip_combination:
# only specify
specs.append(
@@ -4702,9 +4701,16 @@ def enumerate_hmma_paged_kv_flash_kernels(specs, sm=80, dtype='fp16'):


def enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16', head_size_v=0):
for (input_layout, enable_attn_logit_softcapping) in \
product([InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, InputLayout.Q_PAGED_KV], \
[False, True]):
input_layouts = [
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
InputLayout.Q_PAGED_KV
]
# Deepseek MLA (context 192/128 separate-q-k-v)
if head_size_v == 128:
input_layouts.append(InputLayout.SEPARATE_Q_K_V)
for (input_layout,
enable_attn_logit_softcapping) in product(input_layouts,
[False, True]):
enumerate_hmma_flash_kernels_base(specs, sm, dtype, input_layout,
enable_attn_logit_softcapping,
head_size_v)
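
As a worked example of the enumeration above (names other than the layout labels are illustrative): with head_size_v == 128 the layout list gains SEPARATE_Q_K_V, so the product with the softcapping flag yields 8 combinations instead of 6.

# Sketch only: enumerate_layout_combinations is a hypothetical stand-in for the loop above.
from itertools import product

def enumerate_layout_combinations(head_size_v: int):
    layouts = ["PACKED_QKV", "CONTIGUOUS_Q_KV", "Q_PAGED_KV"]
    if head_size_v == 128:
        # Deepseek MLA context kernels (d=192, dv=128) also need the separate-q-k-v layout.
        layouts.append("SEPARATE_Q_K_V")
    return list(product(layouts, [False, True]))

assert len(enumerate_layout_combinations(0)) == 6
assert len(enumerate_layout_combinations(128)) == 8
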
@@ -5080,7 +5086,7 @@ def enumerate_qmma_flash_kernels(specs,
]
input_layouts = [
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
InputLayout.Q_PAGED_KV
InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V
]
for (head_size_params, (q_loop_step, kv_loop_step), tiled), input_layout in \
product(params_q_kv_step, input_layouts):
@@ -5094,6 +5100,9 @@
# skip if head_size is not in head_sizes
if head_sizes is not None and head_size not in head_sizes:
continue
# skip if head_size_v is not 128 for separate-q-k-v
if input_layout == InputLayout.SEPARATE_Q_K_V and head_size_v != 128:
continue
specs.append(
kernel_spec(sm=sm,
sm_mma=89,
@@ -6354,28 +6363,30 @@ def enumerate_kernels():
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == False)
# Deepseek MLA (192/128 packed + 576/512 paged)
or (kspec.sm in [80, 86, 89, 90, 100, 120]
# Deepseek MLA (generation 576/512 paged)
or (kspec.sm in [90, 100, 120]
and kspec.dtype in ['bf16', 'e4m3_fp32']
and (((kspec.head_size, kspec.head_size_v) == (192, 128) and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV])
or ((kspec.head_size, kspec.head_size_v) == (576, 512) and kspec.input_layout == InputLayout.Q_PAGED_KV))
and kspec.head_size == 576
and kspec.head_size_v == 512
and kspec.input_layout == InputLayout.Q_PAGED_KV
and kspec.sage_block_sizes is None
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.warp_specialization == False
and kspec.tiled == True)
# Deepseek MLA (hopper-style context 192/128)
or (kspec.sm == 90
and kspec.dtype == 'bf16'
# Deepseek MLA (context 192/128 separate-q-k-v)
or (kspec.sm in [90, 100, 120]
and kspec.dtype in ['bf16', 'e4m3_fp32']
and kspec.head_size == 192
and kspec.head_size_v == 128
and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
and kspec.sage_block_sizes is None
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.warp_specialization == True
and kspec.alibi == False
and ((kspec.warp_specialization == True and kspec.alibi == False) # sm90
or (kspec.warp_specialization == False and kspec.tiled == True)) # non-sm90
and kspec.enable_attn_logit_softcapping == False)
# SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
or (kspec.sm == 90
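
To make the Deepseek MLA selection in enumerate_kernels() above easier to scan, here is a sketch that restates the two new clauses as standalone predicates. The function names are illustrative; InputLayout is the enum from setup.py, and always-true conditions from the original (version == 2, cross_mha == False, sage_block_sizes is None, enable_attn_logit_softcapping == False) are folded into comments.

# Sketch only: is_mla_generation_kernel / is_mla_context_kernel are illustrative names.
def is_mla_generation_kernel(k) -> bool:
    # Generation phase: 576/512 head sizes, paged-KV input, tiled non-warp-specialized flash kernels.
    return (k.sm in (90, 100, 120)
            and k.dtype in ("bf16", "e4m3_fp32")
            and (k.head_size, k.head_size_v) == (576, 512)
            and k.input_layout == InputLayout.Q_PAGED_KV
            and k.flash_attention and k.tiled
            and not k.warp_specialization)

def is_mla_context_kernel(k) -> bool:
    # Context phase: 192/128 head sizes with the separate-q-k-v layout.
    # sm90 takes the warp-specialized path (no alibi); sm100/sm120 take the tiled non-warp-specialized path.
    return (k.sm in (90, 100, 120)
            and k.dtype in ("bf16", "e4m3_fp32")
            and (k.head_size, k.head_size_v) == (192, 128)
            and k.input_layout == InputLayout.SEPARATE_Q_K_V
            and k.flash_attention
            and ((k.warp_specialization and not k.alibi)
                 or (not k.warp_specialization and k.tiled)))
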
100 changes: 74 additions & 26 deletions cpp/kernels/fmha_v2/src/fmha/gmem_tile_qkv_packed.h
@@ -418,7 +418,7 @@ struct Gmem_tile_qkv

////////////////////////////////////////////////////////////////////////////////////////////////////

// We expect the Q layout to be [B, S, H, D] with variable sequence length support.
// We expect the Q/K/V layout to be [B, S, H, D] with variable sequence length support.
template <
// The instruction traits.
typename Traits,
@@ -440,7 +440,7 @@ template <
int NUM_MATS = 1,
// Is sliding window attention used ?
bool SLIDING_WINDOW_ATTENTION = false>
struct Gmem_tile_q
struct Gmem_tile_q_k_v
{

// The size of each LDG.
@@ -523,22 +523,38 @@ struct Gmem_tile_q
USE_LDGSTS = USE_LDGSTS_
};

// Ctor (keep qkv_offset for compatibility)
// Ctor
// qkv_offset: 0 for Q, 1 for K, 2 for V
template <typename Block_info>
inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
inline __device__ Gmem_tile_q_k_v(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
Block_info const& binfo, int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
: Gmem_tile_q(params, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes)
{
}

// Ctor.
template <typename Block_info>
inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, Block_info const& binfo,
int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
: params_q_stride_in_bytes_(params.q_stride_in_bytes)
, actual_seqlen_(binfo.actual_q_seqlen)
, q_ptr_(reinterpret_cast<char*>(params.q_ptr))
{
int seq_offset = 0;
if (qkv_offset == 0)
{
// Q tensor
params_q_k_v_stride_in_bytes_ = params.q_stride_in_bytes;
q_k_v_ptr_ = reinterpret_cast<char*>(params.q_ptr);
actual_seqlen_ = binfo.actual_q_seqlen;
seq_offset = binfo.sum_s;
}
else if (qkv_offset == 1)
{
// K tensor
params_q_k_v_stride_in_bytes_ = params.k_stride_in_bytes;
q_k_v_ptr_ = reinterpret_cast<char*>(params.k_ptr);
actual_seqlen_ = binfo.actual_kv_seqlen;
seq_offset = binfo.sum_s_kv;
}
else if (qkv_offset == 2)
{
// V tensor
params_q_k_v_stride_in_bytes_ = params.v_stride_in_bytes;
q_k_v_ptr_ = reinterpret_cast<char*>(params.v_ptr);
actual_seqlen_ = binfo.actual_kv_seqlen;
seq_offset = binfo.sum_s_kv;
}

// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
@@ -550,17 +566,20 @@ struct Gmem_tile_q
// Do not load/store if the thread is in the padded area
col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG;

// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// We won't consider past_q_length when loading from gmem_q.
int64_t row_offset = (int64_t) (row + cta_row_offset) * params_q_stride_in_bytes_;
// Add the block index. (sum_s * h + hidx).
int64_t idx = binfo.bidx;
// The row offset in the batched GEMM, including the sequence offset.
int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_;
// Add the head index.
int64_t idx = binfo.bidh;

// Assemble the final pointer.
q_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
q_k_v_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;

// Take the CTA offset to modify the sequence length.
actual_seqlen_ -= cta_row_offset;

// Record the initial seqlen and pointer so later moves (move_by_offset/move_to/reset) can restart from them.
actual_seqlen_init_ = actual_seqlen_;
q_k_v_ptr_init_ = q_k_v_ptr_;
}

// Store data to shared memory.
@@ -590,34 +609,62 @@ struct Gmem_tile_q
#pragma unroll
for (int ii = 0; ii < LDGS; ++ii)
{
ptrs[ii] = q_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_stride_in_bytes_;
ptrs[ii] = q_k_v_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_k_v_stride_in_bytes_;
}

// Trigger LDGSTS or the LDGs.
// The predicates protect against out-of-bound access in rows and cols
Ldgsts_helper<USE_LDGSTS>::load(this, smem_tile, ptrs, preds);
}

// Move the pointer to the next row location.
inline __device__ void move(int const steps = 1)
{
q_k_v_ptr_ += (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * steps;
actual_seqlen_ -= (int) ROWS * steps;
}

// Move the pointer to the next row location by the offset (not step).
inline __device__ void move_by_offset(int const offset)
{
q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) offset * params_q_k_v_stride_in_bytes_;
actual_seqlen_ = actual_seqlen_init_ - (int) offset;
}

// Move the pointer to the next column location
inline __device__ void move_col()
{
q_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
// Update col_in_bytes_ to ensure load predicates work
col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG;
}

// Rewind the pointer back to previous column location
inline __device__ void rewind_col(int const steps)
{
q_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
q_k_v_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
// Update col_in_bytes_ to ensure load predicates work
col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps;
}

// The stride between rows for the QKV matrice.
int64_t params_q_stride_in_bytes_;
// Move the pointer to the specified step.
inline __device__ void move_to(int const step)
{
q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * step;
actual_seqlen_ = actual_seqlen_init_ - (int) ROWS * step;
}

inline __device__ void reset()
{
q_k_v_ptr_ = q_k_v_ptr_init_;
actual_seqlen_ = actual_seqlen_init_;
}

// The stride between rows for the Q/K/V matrices.
int64_t params_q_k_v_stride_in_bytes_;
// The pointer.
char* q_ptr_;
char* q_k_v_ptr_;
char* q_k_v_ptr_init_;
// The register to store predicates.
uint32_t preds_[PRED_REGS];
// The fetch registers.
@@ -627,6 +674,7 @@ struct Gmem_tile_q
int64_t col_in_bytes_;
// The sequence length.
int actual_seqlen_;
int actual_seqlen_init_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
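
To clarify the addressing in the new Gmem_tile_q_k_v constructor, here is a small Python model of the base-pointer computation (byte offsets only). It mirrors the C++ above; the function itself and the dict-style parameters are illustrative, not part of the PR.

# Python model of Gmem_tile_q_k_v's base-offset math (illustrative).
def q_k_v_base_offset(qkv_offset, params, binfo, row, cta_row_offset,
                      col_in_bytes, valid_bytes_per_row):
    # Pick the tensor's row stride and the cumulative sequence offset for this batch entry.
    if qkv_offset == 0:    # Q
        stride, seq_offset = params["q_stride_in_bytes"], binfo["sum_s"]
    elif qkv_offset == 1:  # K
        stride, seq_offset = params["k_stride_in_bytes"], binfo["sum_s_kv"]
    else:                  # V
        stride, seq_offset = params["v_stride_in_bytes"], binfo["sum_s_kv"]
    # The row offset now folds in the per-sequence offset, and the head index (bidh)
    # replaces the packed block index (bidx) used by the old Gmem_tile_q.
    row_offset = (row + cta_row_offset + seq_offset) * stride
    return row_offset + binfo["bidh"] * valid_bytes_per_row + col_in_bytes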