Changes from 1 commit
22 commits
acef6e8
test pass
zhou-yuxin Aug 21, 2025
27ed25a
pre-commit format
zhou-yuxin Aug 21, 2025
cdd456f
test pass
zhou-yuxin Aug 21, 2025
1fe8273
pre-commit format
zhou-yuxin Aug 21, 2025
4338e08
[None][chore] Update namelist in blossom-ci (#7015)
karljang Aug 20, 2025
d271157
[None][ci] move unittests to sub-directories (#6635)
Funatiq Aug 20, 2025
89747a8
[None][infra] Waive failed tests on main branch 8/20 (#7092)
EmmaQiaoCh Aug 20, 2025
8d445a4
[None][fix] Fix W4A8 MoE kernel issue (#7072)
yuhyao Aug 20, 2025
d36ba89
[TRTLLM-7348] [feat] Enable Cross-Attention to use XQA kernels for Wh…
DomBrown Aug 20, 2025
5920e6e
[None][chore] Only check the bindings lib for current build (#7026)
liji-nv Aug 20, 2025
54bc8fd
[None][ci] move some tests of b200 to post merge (#7093)
QiJune Aug 20, 2025
29aee2a
[https://nvbugs/5457489][fix] unwaive some tests (#6991)
byshiue Aug 21, 2025
890dda6
[TRTLLM-6771][feat] Support MMMU for multimodal models (#6828)
yechank-nvidia Aug 21, 2025
ab3153c
[None][fix] Fix llama4 multimodal by skipping request validation (#6957)
chang-l Aug 21, 2025
1b3709e
[None][infra] Upgrade UCX to v1.19.x and NIXL to 0.5.0 (#7024)
BatshevaBlack Aug 21, 2025
d84e1c7
[None][fix] update accelerate dependency to 1.7+ for AutoDeploy (#7077)
Fridah-nv Aug 21, 2025
c923ba7
[None][fix] Fix const modifier inconsistency in log function declarat…
Fan-Yunfan Aug 21, 2025
8233dda
[None][chore] waive failed cases on H100 (#7084)
xinhe-nv Aug 21, 2025
176f367
[fix]: use safeInitRowMax instead of fp32_lowest to avoid NaN (#7087)
lowsfer Aug 21, 2025
cc35ba2
[https://nvbugs/5443039][fix] Fix AutoDeploy pattern matcher for torc…
Fridah-nv Aug 21, 2025
9631242
[https://nvbugs/5437405][fix] qwen3 235b eagle3 ci (#7000)
byshiue Aug 21, 2025
810beb2
[None][doc] Update gpt-oss deployment guide to latest release image (…
farshadghodsian Aug 21, 2025
test pass
Signed-off-by: Yuxin <yuxinz@nvidia.com>
zhou-yuxin committed Aug 21, 2025
commit acef6e808a7bd795eb919dbd2ed71b9dfc2108d8
4 changes: 2 additions & 2 deletions cpp/kernels/fmha_v2/fmha_test.py
@@ -165,8 +165,8 @@ def test_trtllm_context_mla_attention_fmha(dtype, s):
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'

- if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
-     pytest.skip("FP8 MLAs are only supported on sm120 currently.")
+ if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version not in [90, 120]:
+     pytest.skip("FP8 MLAs are only supported on sm90 and sm120 currently.")

# Context phase kernels, always use separate-q-k-v layout.
subprocess.run(
51 changes: 47 additions & 4 deletions cpp/kernels/fmha_v2/setup.py
@@ -1914,8 +1914,9 @@ def enable_mutex(kspec):


def enable_tma_store(kspec):
+ output_dtype = kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype
# TMA copies data in the 16B granularity.
- return 'true' if (kspec.dtype in ['e4m3', 'e4m3_fp32']
+ return 'true' if (output_dtype in ['e4m3', 'e4m3_fp32']
and kspec.head_size % 16 == 0) else 'false'


@@ -3812,7 +3813,9 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
# use specialized kernels for cases without alibi scales.
# there is a numeric issues when applying the exp2f scale optimization and alibi scale at the same time.
combinations = product([False, True], \
- [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, InputLayout.Q_PAGED_KV], [False, True])
+ [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
+  InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
+ [False, True])
for (alibi, input_layout, enable_attn_logit_softcapping) in combinations:
# alibi and bmm1_tanh_scale shouldn't be used together.
if alibi and enable_attn_logit_softcapping:
@@ -3911,7 +3914,7 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
has_noloop=0,
noloop_step=64,
kv_loop_step=
- 128, # use 64 kv step size to avoid register spilling
+ 128, # use 128 kv step size to avoid register spilling
kv_tile_buffers=2, # only used by warp specialized kernels
unroll_threshold=1,
has_scale_max=False,
@@ -3926,6 +3929,46 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
sage_block_sizes=sage_block_sizes,
output_dtype=output_dtype))

# context MLA (192x128)
# we could use param 'output_dtype' of enumerate_qgmma_flash_warpspec_kernels(),
# but it will generate many unnecessary kernels and they are not easy to filter out.
for output_type in [None, 'bf16']:
specs.append(
kernel_spec(
sm=sm,
sm_mma=90,
dtype=dtype,
seq_len=0, # support any sequence length
head_size=192,
head_size_v=128,
warps_m=4, #4x1 warpgroups
warps_n=1,
version=2,
interleaved=False,
ldgsts_q=
False, # for Hopper kernels, ldgsts = False signals TMA usage.
ldgsts_k=False,
ldgsts_v=False,
share_smem_k_v=False,
loop_step=64,
q_tile_buffers=1, # only used by warp specialized kernels
has_noloop=0,
noloop_step=64,
kv_loop_step=128,
kv_tile_buffers=2, # only used by warp specialized kernels
unroll_threshold=1,
has_scale_max=False,
flash_attention=True,
warp_specialization=True,
alibi=alibi,
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=
False, # return softmax stats is not supported for fp8 now
scheduling_mode=scheduling_mode,
input_layout=input_layout,
sage_block_sizes=sage_block_sizes,
output_dtype=output_type))


def enumerate_igmma_kernels(specs, sm=90):
specs.append(
@@ -6377,7 +6420,7 @@ def enumerate_kernels():
and kspec.tiled == True)
# Deepseek MLA (context 192/128 separate-q-k-v)
or (kspec.sm in [90, 100, 120]
- and kspec.dtype in ['bf16', 'e4m3_fp32']
+ and kspec.dtype in ['bf16', 'e4m3', 'e4m3_fp32']
and kspec.head_size == 192
and kspec.head_size_v == 128
and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
44 changes: 29 additions & 15 deletions cpp/kernels/fmha_v2/src/fmha/hopper/gmem_tile_o_packed.h
@@ -1222,6 +1222,14 @@ struct Gmem_tile_o_qgmma_fp32_16bits
inline __device__ Gmem_tile_o_qgmma_fp32_16bits(
Params const& params, Block_info const& block_info, Shared&&, int tidx, int cta_row_offset = 0)
: params_o_stride_in_bytes_(params.o_stride_in_bytes)
+ , params_scale_bmm2_(
+ #ifdef GENERATE_CUBIN
+ // Specialized for trt-llm generated cubins only.
+ params.scale_bmm2_d ? *params.scale_bmm2_d : params.scale_bmm2
+ #else
+ params.scale_bmm2
+ #endif
+ )
, actual_seqlen_(block_info.actual_seqlen)
, o_ptr_(reinterpret_cast<char*>(params.o_ptr))
{
@@ -1251,21 +1259,25 @@ struct Gmem_tile_o_qgmma_fp32_16bits
inline __device__ void store(Accumulators const (&acc)[M][N])
{
int64_t const step_m = 8 * params_o_stride_in_bytes_;
- // we assume M = 1. some shortcuts.
- static_assert(M == 1);

- #define STORE_COLUMN(idx) \
- { \
- float _reg0 = acc[0][mma_ni].elt(((ci + 0) * ROWS_PER_THREAD + ri) * 2 + idx); \
- float _reg1 = acc[0][mma_ni].elt(((ci + 1) * ROWS_PER_THREAD + ri) * 2 + idx); \
- static_assert(std::is_same_v<Output_type, bf16_t> || std::is_same_v<Output_type, fp16_t>); \
- uint32_t _out = fmha::float2_to_16bit_2<Output_type>(_reg0, _reg1); \
- int64_t _offset = (int64_t) ri * step_m + (int64_t) (ci + mma_ni * COLS_PER_THREAD) * STEP_N; \
- fmha::stg(o_ptr_ + _offset + 4 * idx, _out); \
- }
- #define STORE_COLUMNS() \
- { \
- STORE_COLUMN(0) STORE_COLUMN(1) \
+ #ifdef UNIFIED_EPILOGUE_SCALE
+ constexpr bool Scale = false;
+ #else
+ constexpr bool Scale = true;
+ #endif
+ #define STORE_COLUMNS() \
+ { \
+ /* we assume M = 1. some shortcuts. */ \
+ static_assert(M == 1); \
+ uint4 _src = { \
+ .x = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2), \
+ .y = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2), \
+ .z = acc[0][mma_ni].reg(((ci + 0) * ROWS_PER_THREAD + ri) * 2 + 1), \
+ .w = acc[0][mma_ni].reg(((ci + 1) * ROWS_PER_THREAD + ri) * 2 + 1), \
+ }; \
+ uint2 _dst = Acc_packer<float, Output_type, Scale>::run(this, _src); \
+ int64_t _offset = \
+ (int64_t)ri * step_m + (int64_t)(ci + mma_ni * COLS_PER_THREAD) * STEP_N; \
+ fmha::stg(o_ptr_ + _offset, _dst); \
}

#pragma unroll
@@ -1303,6 +1315,8 @@ struct Gmem_tile_o_qgmma_fp32_16bits

// The stride between rows for the QKV matrice.
int64_t params_o_stride_in_bytes_;
+ // Scaling factor; this usually means QKV descale factor in actuality
+ uint32_t params_scale_bmm2_;
// The pointer.
char* o_ptr_;
// The row loaded by this thread.
4 changes: 2 additions & 2 deletions cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h
@@ -755,7 +755,7 @@ struct DMA
for (int kgroup_idx = 0; kgroup_idx < Kernel_traits::BMM2_K_GROUPS; kgroup_idx++)
{
#pragma unroll
- for (int dgroup_idx = 0; dgroup_idx < Kernel_traits::D_GROUPS; dgroup_idx++)
+ for (int dgroup_idx = 0; dgroup_idx < Kernel_traits::DV_GROUPS; dgroup_idx++)
{
// Src smem block is k first then d
uint32_t src_offset = (kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D_PER_GROUP
@@ -764,7 +764,7 @@

// Dst smem block is d first then k
uint32_t dst_offset = (dgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D_PER_GROUP
- + kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::D)
+ + kgroup_idx * Kernel_traits::BMM2_K_PER_GROUP * Kernel_traits::DV)
* Kernel_traits::ELEMENT_BYTES;

transposer.template transpose_<false>(smem_v_src + src_offset, smem_v_dst + dst_offset);
3 changes: 2 additions & 1 deletion cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h
@@ -589,7 +589,8 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
// Base class.
using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
- SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_>;
+ SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
+ RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>;

enum
{
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
@@ -260,6 +260,8 @@ extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
+ extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
+ extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
Comment on lines +263 to +264

💡 Verification agent

🧩 Analysis chain

New externs for 192x128 S_q_k_v (and BF16 output) look consistent; please verify out-of-line definitions exist.

Signatures follow existing conventions and appear correct. To avoid ODR/link issues, confirm these have exactly one non-extern definition and that parameter names match between declaration and definition (per guidelines).

Run this script from repo root to verify definitions and parameter-name consistency:

Optionally, if these are SM90-only implementations, consider guarding the declarations/uses with the same compile-time condition used for other SM90-only symbols for consistency.
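A minimal sketch of that guard around the new declarations (the EXCLUDE_SM_90 macro name is an assumption for illustration; reuse whatever guard the header already applies to other SM90-only symbols):

// Hypothetical guard; macro name assumed, signatures taken verbatim from the diff above.
#ifndef EXCLUDE_SM_90
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(
    Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(
    Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
#endif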


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Declarations (should match these two)
decls=(
  'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90'
  'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90'
)

for sym in "${decls[@]}"; do
  echo "== Checking: $sym =="
  # Find declaration sites
  rg -nP --type=cpp --type=cu $'extern\\s+void\\s+'"$sym"$'\\s*\\(\\s*Fused_multihead_attention_params_v2\\s*&\\s*params\\s*,\\s*const\\s*Launch_params\\s*&\\s*launch_params\\s*,\\s*cudaStream_t\\s*stream\\s*\\)' -C1

  # Find definitions (non-extern)
  rg -nP --type=cpp --type=cu $'(?<!extern)\\s+void\\s+'"$sym"$'\\s*\\(\\s*Fused_multihead_attention_params_v2\\s*&\\s*params\\s*,\\s*const\\s*Launch_params\\s*&\\s*launch_params\\s*,\\s*cudaStream_t\\s*stream\\s*\\)' -C3
done

Length of output: 669


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Ensure we search all relevant file extensions instead of relying on rg file types.
decls=(
  'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90'
  'run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90'
)

for sym in "${decls[@]}"; do
  echo "== Checking: $sym =="
  echo "-- extern declarations in headers --"
  rg -nP --glob '*.h' "extern\s+void\s+$sym\s*\([^)]*\)" -C1 || true

  echo "-- definitions in .cpp/.cu sources (non-extern) --"
  rg -nP --glob '*.cpp' --glob '*.cu' "(?<!extern)\s+void\s+$sym\s*\([^)]*\)" -C2 || true

  echo
done

Length of output: 3201


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Search for definitions/occurrences of new symbols and a nearby existing symbol for reference
symbols=(
  run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90
  run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90
  run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90
)

for sym in "${symbols[@]}"; do
  echo "=== Occurrences of $sym ==="
  rg -n "$sym" -C2 || echo "  (no matches found)"
  echo
done

Length of output: 12014


Missing out-of-line definitions for new SM90 externs

Our scan of all .cpp and .cu files did not locate non-extern definitions for the two new functions declared in
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h at lines 263–264:

  • run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90
  • run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90

Without matching definitions, the build will fail at link time. Please add implementations—e.g. in the existing fmha_cubin.cu (or a new .cpp/.cu)—ensuring:

  • Exactly one non-extern definition per symbol.
  • The parameter names in the definitions match the declarations (params, launch_params, stream).
  • Optionally, guard both declaration and definition with the same SM90 compile-time macro used elsewhere for consistency.

Files to update:
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h (lines 263–264)
  • The corresponding .cu or .cpp where the definitions belong.

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
around lines 263–264, two functions were declared as extern for SM90 but no
non-extern (definition) implementations exist:
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90 and
run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90;
add exactly one non-extern definition for each (preferably in
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cu),
ensure the function signatures use the same parameter names (params,
launch_params, stream), and wrap both declaration and definition with the same
SM90 compile-time macro guard used elsewhere so the linker finds the
implementations only when SM90 is enabled.
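As an illustration of the requested fix, a minimal sketch of matching out-of-line definitions follows. The guard macro name and the elided launch bodies are assumptions for illustration only, and the real definitions are presumably emitted by the fmha_v2 generator alongside the other run_fmha_v2_* launchers rather than written by hand.

// Sketch only (hypothetical): exactly one non-extern definition per symbol, parameter names
// matching the declarations (params, launch_params, stream), wrapped in the same SM90 guard.
#ifndef EXCLUDE_SM_90 // assumed guard macro; reuse the project's actual SM90 guard
void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(
    Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream)
{
    // Launch the SM90 TMA warp-specialized 192x128 separate-q-k-v kernel here
    // (body elided; normally generated with the other launchers).
}

void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(
    Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream)
{
    // Same kernel with BF16 output (body elided).
}
#endif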

extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@@ -1969,6 +1971,10 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 180480, 384, 64, 2, 2, false, true, true, true, false, false, true, false, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_causal_softcapping_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_alibi_tma_ws_sm90},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_alibi_tma_ws_sm90},