From 236f71ea0506cce8de2ed99adaff29ed38718782 Mon Sep 17 00:00:00 2001 From: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:48:16 +0800 Subject: [PATCH 01/76] [None][chore] Add failed cases into waives.txt (#7801) Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 4ce04721f7c..7b730758453 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -346,3 +346,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_pref accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5522746) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5522746) test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] SKIP (https://nvbugs/5523925) +test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] SKIP (https://nvbugs/5509024) +test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-False] SKIP (https://nvbugs/5509024) +test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True] SKIP (https://nvbugs/5509024) +test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5523315) From 2ae08bd1b846acb17e6be667639f7fcf4427f86c Mon Sep 17 00:00:00 2001 From: dongfengy <99041270+dongfengy@users.noreply.github.com> Date: Thu, 18 Sep 2025 01:01:53 -0700 Subject: [PATCH 02/76] [https://nvbugs/5519530][fix] Fix gptoss 2-gpu test (#7819) Signed-off-by: Dongfeng Yu --- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index eb5f6e98b0a..3e5a84012a8 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -3270,6 +3270,8 @@ def test_w4_2gpus(self, kv_cache_dtype, moe_backend, tp_size, pp_size, model_name = "GPT-OSS/MXFP4" task = GSM8K(model_name) mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192) + mocker.patch.dict(GSM8K.EVALUATE_KWARGS, + {"scores_filter": "exact_match,flexible-extract"}) task.evaluate(llm, extra_evaluator_kwargs=self.extra_evaluator_kwargs) From a7ca0fff54e171f56a919265abc3238c0f506e63 Mon Sep 17 00:00:00 2001 From: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:26:20 +0800 Subject: [PATCH 03/76] [TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend (#7207) Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com> --- .github/CODEOWNERS | 2 + cpp/kernels/fmha_v2/setup.py | 4 +- .../cubin/fmha_cubin.h | 53 +- docs/source/models/supported-models.md | 1 + tensorrt_llm/_torch/models/__init__.py | 2 + .../_torch/models/modeling_nanov2vlm.py | 458 +++++++++ tensorrt_llm/_torch/models/modeling_radio.py | 903 ++++++++++++++++++ tensorrt_llm/commands/eval.py | 9 +- .../defs/accuracy/references/mmmu.yaml | 2 + .../defs/accuracy/test_llm_api_pytorch.py | 23 + tests/integration/defs/test_e2e.py | 28 + .../test_lists/qa/llm_function_core.txt | 1 + 12 files changed, 1474 insertions(+), 12 deletions(-) create mode 100644 tensorrt_llm/_torch/models/modeling_nanov2vlm.py create mode 100644 tensorrt_llm/_torch/models/modeling_radio.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 28f29c3b152..4c619e62b32 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -100,6 +100,8 @@ /tests/unittest/_torch/modeling/test_modeling_pixtral.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs ### TensorRT-LLM Pytorch - Models - Nemotron +/tensorrt_llm/_torch/models/modeling_nanov2vlm.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs +/tensorrt_llm/_torch/models/modeling_radio.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/models/modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs diff --git a/cpp/kernels/fmha_v2/setup.py b/cpp/kernels/fmha_v2/setup.py index 24a80b8d713..ced8c67764b 100644 --- a/cpp/kernels/fmha_v2/setup.py +++ b/cpp/kernels/fmha_v2/setup.py @@ -1982,8 +1982,8 @@ def selected_mask_types(kspec): custom_mask = '0' # encoder models (head_size = 32 / 64 / 128) need packed_qkv input layout + padding mask. elif kspec.input_layout == InputLayout.PACKED_QKV: - # NOTE: 72 is added for vision transformer - if kspec.head_size not in [32, 64, 72, 128]: + # NOTE: 72/80 are added for vision transformer + if kspec.head_size not in [32, 64, 72, 80, 128]: padding_mask = '0' # only cross attention (head_size = 32/64/128) needs contiguous_q_kv input layout + padding mask / custom_mask. elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV: diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h index 0c2f3aed72b..4d0731ce426 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h @@ -49,7 +49,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_ extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90_cu_cubin[]; -extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin[]; extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin[]; @@ -261,7 +260,7 @@ extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_tma_ws_sm90 extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); -extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); +extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); @@ -282,7 +281,8 @@ extern void run_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_104_alibi_tma_w extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); -extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sage_64_64_256_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); +extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); +extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_40_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); extern void run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_48_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream); @@ -1354,7 +1354,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_w extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90_cu_cubin_len; -extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len; extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len; @@ -1645,6 +1644,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sliding_or_chunked_causal_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_tma_ws_sm90}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_tma_ws_sm90_kernel", 164096, 384, 64, 0, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sliding_or_chunked_causal_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 0, false, true, true, false, false, false, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_tma_ws_sm90}, @@ -1772,6 +1772,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sliding_or_chunked_causal_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_tma_ws_sm90}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_tma_ws_sm90_kernel", 164096, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sliding_or_chunked_causal_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_tma_ws_sm90}, @@ -1899,6 +1900,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 64, 64, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_64_causal_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_64_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 64, 64, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_64_sliding_or_chunked_causal_tma_ws_sm90_kernel", 156032, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_64_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 64, 64, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_64_custom_mask_tma_ws_sm90_kernel", 164224, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_64_tma_ws_sm90}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_tma_ws_sm90_kernel", 196864, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sliding_or_chunked_causal_tma_ws_sm90_kernel", 180480, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_tma_ws_sm90}, @@ -1973,8 +1975,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90}, -{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90}, -{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_alibi_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_alibi_tma_ws_sm90}, @@ -1997,8 +1999,10 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 160, 160, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_alibi_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 192, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_alibi_tma_ws_sm90}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_alibi_tma_ws_sm90}, -{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 256, 80, 80, 64, 64, 256, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sage_64_64_256_output_bf16_tma_ws_sm90_kernel", 196864, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sage_64_64_256_output_bf16_tma_ws_sm90}, -{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 256, 128, 128, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_kernel", 196864, 384, 64, 0, 0, false, true, true, true, false, false, false, false, nullptr}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90}, +{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_softmax_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90_kernel", 73984, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_sliding_or_chunked_causal_tma_ws_sm90_kernel", 73984, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90}, @@ -2017,6 +2021,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sliding_or_chunked_causal_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_tma_ws_sm90}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_tma_ws_sm90_kernel", 164096, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sliding_or_chunked_causal_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_tma_ws_sm90}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_tma_ws_sm90}, @@ -2147,6 +2152,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sliding_or_chunked_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_custom_mask_sm90_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm90_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm90_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sliding_or_chunked_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_custom_mask_sm90_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, false, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm90_nl_tiled}, @@ -2190,6 +2196,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_causal_sm90_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sliding_or_chunked_causal_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_custom_mask_sm90_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm90_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm90_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_causal_sm90_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sliding_or_chunked_causal_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_custom_mask_sm90_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, false, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm90_nl}, @@ -2241,6 +2248,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm90_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sliding_or_chunked_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm90_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_custom_mask_sm90_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm90_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm90_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sliding_or_chunked_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_custom_mask_sm90_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm90_nl_tiled}, @@ -2284,6 +2292,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_causal_sm90_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm90_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sliding_or_chunked_causal_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm90_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_custom_mask_sm90_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm90_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm90_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_causal_sm90_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sliding_or_chunked_causal_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_custom_mask_sm90_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm90_nl}, @@ -2335,6 +2344,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sliding_or_chunked_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_custom_mask_sm90_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm90_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm90_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sliding_or_chunked_causal_sm90_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm90_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_custom_mask_sm90_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm90_nl_tiled}, @@ -2378,6 +2388,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_causal_sm90_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sliding_or_chunked_causal_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_custom_mask_sm90_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm90_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm90_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_causal_sm90_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm90_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm90_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_custom_mask_sm90_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm90_nl}, @@ -2453,6 +2464,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_causal_sm89_kernel_nl", 32768, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, nullptr}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, nullptr}, @@ -2530,6 +2542,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sliding_or_chunked_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_custom_mask_sm89_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm89_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_custom_mask_sm89_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm89_nl_tiled}, @@ -2573,6 +2586,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm89_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm89_nl}, @@ -2732,6 +2746,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sliding_or_chunked_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_custom_mask_sm89_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm89_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_custom_mask_sm89_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm89_nl_tiled}, @@ -2775,6 +2790,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm89_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm89_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm89_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm89_nl}, @@ -2934,6 +2950,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sliding_or_chunked_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_custom_mask_sm89_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm89_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm89_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm89_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_custom_mask_sm89_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm89_nl_tiled}, @@ -2977,6 +2994,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm89_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_causal_sm89_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm89_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm89_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_custom_mask_sm89_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm89_nl}, @@ -3139,6 +3157,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sliding_or_chunked_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_custom_mask_sm80_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm80_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sliding_or_chunked_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_custom_mask_sm80_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm80_nl_tiled}, @@ -3182,6 +3201,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_causal_sm80_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sliding_or_chunked_causal_sm80_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_custom_mask_sm80_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm80_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm80_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_causal_sm80_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sliding_or_chunked_causal_sm80_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_custom_mask_sm80_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm80_nl}, @@ -3341,6 +3361,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sliding_or_chunked_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_custom_mask_sm80_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm80_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sliding_or_chunked_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_custom_mask_sm80_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm80_nl_tiled}, @@ -3384,6 +3405,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_causal_sm80_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm80_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sliding_or_chunked_causal_sm80_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm80_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_custom_mask_sm80_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm80_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm80_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_causal_sm80_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sliding_or_chunked_causal_sm80_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_custom_mask_sm80_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm80_nl}, @@ -3543,6 +3565,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sliding_or_chunked_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_custom_mask_sm80_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm80_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm80_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sliding_or_chunked_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm80_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_custom_mask_sm80_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm80_nl_tiled}, @@ -3586,6 +3609,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_causal_sm80_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sliding_or_chunked_causal_sm80_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_custom_mask_sm80_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm80_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm80_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_causal_sm80_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm80_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm80_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_custom_mask_sm80_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm80_nl}, @@ -3748,6 +3772,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sliding_or_chunked_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_custom_mask_sm86_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm86_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sliding_or_chunked_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_custom_mask_sm86_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm86_nl_tiled}, @@ -3791,6 +3816,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_causal_sm86_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sliding_or_chunked_causal_sm86_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_custom_mask_sm86_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm86_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm86_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_causal_sm86_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sliding_or_chunked_causal_sm86_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_custom_mask_sm86_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm86_nl}, @@ -3950,6 +3976,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sliding_or_chunked_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_custom_mask_sm86_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm86_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sliding_or_chunked_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_custom_mask_sm86_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm86_nl_tiled}, @@ -3993,6 +4020,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_causal_sm86_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm86_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sliding_or_chunked_causal_sm86_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm86_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_custom_mask_sm86_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm86_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm86_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_causal_sm86_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sliding_or_chunked_causal_sm86_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_custom_mask_sm86_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm86_nl}, @@ -4152,6 +4180,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sliding_or_chunked_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_custom_mask_sm86_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm86_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm86_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sliding_or_chunked_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm86_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_custom_mask_sm86_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm86_nl_tiled}, @@ -4195,6 +4224,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_causal_sm86_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sliding_or_chunked_causal_sm86_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_custom_mask_sm86_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm86_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm86_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_causal_sm86_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm86_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm86_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_custom_mask_sm86_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm86_nl}, @@ -4363,6 +4393,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_72_custom_mask_sm120_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_72_sm120_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_128_S_qkv_80_custom_mask_sm120_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, false, true, true, false, true, run_fmha_v2_flash_attention_fp16_64_128_S_qkv_80_sm120_nl_tiled}, @@ -4406,6 +4437,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_72_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_72_sm120_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm120_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_64_32_S_qkv_80_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, false, true, false, false, true, run_fmha_v2_flash_attention_fp16_64_32_S_qkv_80_sm120_nl}, @@ -4565,6 +4597,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_72_custom_mask_sm120_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_72_sm120_nl_tiled}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_80_custom_mask_sm120_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_80_sm120_nl_tiled}, @@ -4608,6 +4641,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_72_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_72_sm120_nl}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm120_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_qkv_80_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_bf16_64_32_S_qkv_80_sm120_nl}, @@ -4790,6 +4824,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_causal_sm120_kernel_nl", 32768, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm120_nl}, +{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm120_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm120_nl}, @@ -4869,6 +4904,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_custom_mask_sm120_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_72_sm120_nl_tiled}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm120_nl_tiled}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_custom_mask_sm120_kernel_nl_tiled", 81920, 128, 64, 3, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_80_sm120_nl_tiled}, @@ -4912,6 +4948,7 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 72, 72, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_72_sm120_nl}, +{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm120_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_causal_sm120_kernel_nl", 32768, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sliding_or_chunked_causal_sm120_kernel_nl", 32768, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm120_nl}, { DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_custom_mask_sm120_kernel_nl", 32768, 128, 64, 3, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_80_sm120_nl}, diff --git a/docs/source/models/supported-models.md b/docs/source/models/supported-models.md index aac729a9545..c3979aa61e1 100644 --- a/docs/source/models/supported-models.md +++ b/docs/source/models/supported-models.md @@ -51,6 +51,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl | LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I | | Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I | | Mistral3ForConditionalGeneration | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I | +| NemotronH_Nano_VL_V2 | Yes | Yes | Yes | Yes | Yes | No | Yes | No | L + I + V | | Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + A | | Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V | | Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V | diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py index 4f7aa39330e..ae413741f0f 100644 --- a/tensorrt_llm/_torch/models/__init__.py +++ b/tensorrt_llm/_torch/models/__init__.py @@ -14,6 +14,7 @@ from .modeling_llava_next import LlavaNextModel from .modeling_mistral import Mistral3VLM, MistralForCausalLM from .modeling_mixtral import MixtralForCausalLM +from .modeling_nanov2vlm import NemotronH_Nano_VL_V2 from .modeling_nemotron import NemotronForCausalLM from .modeling_nemotron_h import NemotronHForCausalLM from .modeling_nemotron_nas import NemotronNASForCausalLM @@ -45,6 +46,7 @@ "Mistral3VLM", "MistralForCausalLM", "MixtralForCausalLM", + "NemotronH_Nano_VL_V2", "NemotronForCausalLM", "NemotronHForCausalLM", "NemotronNASForCausalLM", diff --git a/tensorrt_llm/_torch/models/modeling_nanov2vlm.py b/tensorrt_llm/_torch/models/modeling_nanov2vlm.py new file mode 100644 index 00000000000..c3fc27da4c0 --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_nanov2vlm.py @@ -0,0 +1,458 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import copy +import os +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import transformers +from PIL import Image + +from tensorrt_llm._torch.models.checkpoints import NemotronHHfWeightMapper +from tensorrt_llm.inputs.multimodal import MultimodalParams + +from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs, + InputProcessor, MultimodalPlaceholderMetadata, + MultimodalPlaceholderPlacement, TextPrompt, + register_input_processor) +from ...logger import logger +from ...sampling_params import SamplingParams +from ..attention_backend import AttentionMetadata +from ..model_config import ModelConfig +from .modeling_auto import AutoModelForCausalLM +from .modeling_multimodal_utils import find_input_mm_embeds, fuse_input_embeds +from .modeling_radio import RADIOVisionModel +from .modeling_utils import register_auto_model + + +# Make this a runtime lookup rather than a module-wide constant for easier unit testing. +def _is_disagg() -> bool: + return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1" + + +class SquaredReLU(nn.Module): + + def forward(self, x): + return torch.pow(torch.nn.functional.relu(x), 2) + + +# Source codes are from NemotronH_Nano_VL_V2 modeling.py. +class NanoV2VLVisionEncoder(transformers.PreTrainedModel): + + def __init__(self, + model_config: ModelConfig[transformers.PretrainedConfig]): + config = model_config.pretrained_config + super().__init__(config) + self.image_size = config.force_image_size + self.patch_size = config.patch_size + self.num_image_token = int((self.image_size // self.patch_size)**2 * + (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version # Pixel shuffle version. + + # Construct the vision projection. + self.vit_hidden_size = config.vit_hidden_size + self.vision_projection_hidden_size = config.projector_hidden_size + self.llm_hidden_size = config.llm_config.hidden_size + self.mlp1 = nn.Sequential( + nn.RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2, + eps=config.llm_config.rms_norm_eps, + dtype=config.torch_dtype), + nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2, + self.vision_projection_hidden_size, + bias=False, + dtype=config.torch_dtype), SquaredReLU(), + nn.Linear(self.vision_projection_hidden_size, + self.llm_hidden_size, + bias=False, + dtype=config.torch_dtype)) + + # Construct the vision encoder. + vision_model_config = copy.deepcopy(model_config) + vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config + self.vision_model = RADIOVisionModel(vision_model_config) + + def load_weights(self, weights): + # Load mlp1 weights. + mlp1_weights = { + k.replace('mlp1.', ''): v + for k, v in weights.items() if k.startswith('mlp1.') + } + self.mlp1.load_state_dict(mlp1_weights, strict=True) + + # Load vision encoder weights. + vision_encoder_weights = { + k.replace('vision_model.', ''): v + for k, v in weights.items() if k.startswith('vision_model.') + } + self.vision_model.load_weights(vision_encoder_weights) + + @torch.compile + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + logger.warning( + "In ps_version 'v1', the height and width have not been swapped back, " + 'which results in a transposed image.') + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + vit_embeds = self.vision_model(pixel_values) + # Down-sampling and projection. + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def forward(self, multimodal_params: List[MultimodalParams]): + mm_embedding = [] + # Batch data. + pixel_values = [ + multimodal_param.multimodal_data["pixel_values"] + for multimodal_param in multimodal_params + ] + batched_pixel_values = torch.cat(pixel_values, dim=0) + # -> [num_patches, channel, height, width] + patch_list = [ + multimodal_param.multimodal_data["num_patches"] + for multimodal_param in multimodal_params + ] + batched_num_patches = torch.cat(patch_list, dim=0).tolist() + # -> list of[num_patches1, num_patches2, ...] + batched_image_embeds = self.extract_feature(batched_pixel_values) + # -> [num_patches, num_image_token, hidden_size] + mm_embedding = torch.split(batched_image_embeds, + batched_num_patches, + dim=0) + mm_embedding = [ + m.reshape(-1, self.llm_hidden_size) for m in mm_embedding + ] + # -> list of [num_patches*num_image_token, hidden_size] + return mm_embedding + + +class NanoV2VLInputProcessor(BaseMultimodalInputProcessor, InputProcessor): + + def __init__(self, + model_path: str, + model_config: transformers.PretrainedConfig, + tokenizer: transformers.AutoTokenizer, + trust_remote_code: bool = True): + if not trust_remote_code: + raise ValueError("trust_remote_code must be True for NanoV2VL") + + self.model_config = model_config + self.image_size = model_config.force_image_size + self.patch_size = model_config.patch_size + self.downsample_ratio = model_config.downsample_ratio + self.img_context_token_id = model_config.img_context_token_id + self.num_image_token = int((self.image_size // self.patch_size)**2 * + (self.downsample_ratio**2)) + + self.device = 'cpu' + + self.tokenizer = tokenizer + self.use_fast = True + if self.tokenizer is None: + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=self.use_fast) + + self.processor = transformers.AutoImageProcessor.from_pretrained( + model_path, trust_remote_code=True, use_fast=self.use_fast) + + self.img_context_token = model_config.img_context_token + self.video_context_token = model_config.video_context_token + self.img_start_token = model_config.img_start_token + self.img_end_token = model_config.img_end_token + self.dtype = model_config.torch_dtype + + def get_vocab_size(self): + return self.model_config.llm_config.vocab_size + + def get_mm_token_ids(self): + return torch.tensor([self.img_context_token_id], dtype=torch.int32) + + def get_num_tokens_per_image( + self, + *, + image: Image.Image, + **kwargs, + ): + + def _get_internvl_target_ratios( + min_num: int, + max_num: int, + ) -> list[tuple[int, int]]: + target_ratios = {(i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num} + return sorted(target_ratios, key=lambda x: x[0] * x[1]) + + def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, + height, image_size): + best_factor = float('-inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + factor_based_on_area_n_ratio = min( + (ratio[0] * ratio[1] * image_size * image_size) / area, + 0.6) * min(target_aspect_ratio / aspect_ratio, + aspect_ratio / target_aspect_ratio) + if factor_based_on_area_n_ratio > best_factor: + best_factor = factor_based_on_area_n_ratio + best_ratio = ratio + return best_ratio + + def _calculate_targets( + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + ) -> int: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = _find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + return blocks + + image_height = image.height + image_width = image.width + target_ratios = _get_internvl_target_ratios( + 1, self.processor.max_num_tiles) + blocks = _calculate_targets(image_width, image_height, target_ratios, + self.image_size) + if self.processor.use_thumbnail and blocks != 1: + blocks += 1 + num_image_tokens = self.num_image_token * blocks + return num_image_tokens + + @torch.inference_mode() + def __call__( + self, inputs: TextPrompt, sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + text_prompt, mm_data = inputs.get("prompt"), inputs.get( + "multi_modal_data", {}) + images = mm_data.get("image", None) + videos = mm_data.get("video", None) + if images is not None and videos is not None: + raise ValueError( + "NanoV2VL does not support both images and videos in the same prompt yet." + ) + + if images is None and videos is None: + input_ids = self.tokenizer.encode(text_prompt, + add_special_tokens=False, + return_tensors="pt") + return input_ids[0].to(torch.int32).tolist(), {} + + if images is not None: + if isinstance(images[0], torch.Tensor): + # NanoV2VL can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255). + images = [ + Image.fromarray((image.permute(1, 2, 0) * 255).to( + torch.uint8).cpu().numpy()) for image in images + ] + # Processing for multimodal data. + processed_images = self.processor(images=images, + return_tensors='pt').to( + self.device) + # Insert enough special tokens for image embedding. + parts = text_prompt.split(self.img_context_token) + if len(parts) - 1 != len(processed_images['num_patches']): + raise ValueError( + f"Number of {self.img_context_token} tokens ({len(parts) - 1}) doesn't match num_patches_list length ({len(processed_images['num_patches'])})" + ) + processed_query = parts[0] + for num_patches, part in zip(processed_images['num_patches'], + parts[1:]): + feature_size = num_patches * self.num_image_token + image_repl = self.img_start_token + self.img_context_token * feature_size + self.img_end_token + processed_query += image_repl + part + elif videos is not None: + num_videos = len(videos) + num_patches_list = [] + pixel_values_list = [] + parts = text_prompt.split(self.video_context_token) + if len(parts) - 1 != num_videos: + raise ValueError( + f"Number of {self.video_context_token} tokens ({len(parts) - 1}) doesn't match number of videos ({num_videos})" + ) + # Process videos one by one to get correct processed_query. + processed_query = "" + for video_index, video in enumerate(videos): + if isinstance(video[0], torch.Tensor): + # NanoV2VL can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255). + images = [ + Image.fromarray((image.permute(1, 2, 0) * 255).to( + torch.uint8).cpu().numpy()) for image in video + ] + else: + images = video + # Processing for multimodal data. + processed_images = self.processor(images=images, + return_tensors='pt').to( + self.device) + num_patches_list.append(processed_images['num_patches']) + pixel_values_list.append(processed_images['pixel_values']) + + # Processing the text prompt. + processed_query += parts[video_index] + for num_patches in processed_images['num_patches']: + feature_size = num_patches * self.num_image_token + image_repl = self.img_start_token + self.img_context_token * feature_size + self.img_end_token + processed_query += image_repl + processed_query += parts[num_videos] + processed_images['num_patches'] = torch.tensor( + [sum(num_patches) for num_patches in num_patches_list]) + processed_images['pixel_values'] = torch.cat(pixel_values_list, + dim=0) + + input_ids = self.tokenizer.encode(processed_query, + add_special_tokens=False, + return_tensors="pt") + + # Will package inputs for language model forward in AGGREGATE mode. + multimodal_data = {} + multimodal_data['pixel_values'] = processed_images['pixel_values'].to( + self.dtype) + multimodal_data['num_patches'] = processed_images['num_patches'].sum( + dim=0, keepdim=True) + return input_ids[0].to(torch.int32).tolist(), { + "multimodal_data": multimodal_data, + } + + +@register_auto_model("NemotronH_Nano_VL_V2") +@register_input_processor( + NanoV2VLInputProcessor, + model_type="NemotronH_Nano_VL_V2", + placeholder_metadata=MultimodalPlaceholderMetadata( + placeholder_map={ + "image": "", + "video": "