From 5d8dddff8943a2aa4ea700e4ad9ae7cac2a94a96 Mon Sep 17 00:00:00 2001 From: Soham Govande Date: Wed, 5 Mar 2025 10:08:39 -0800 Subject: [PATCH 1/6] Performance on MMA_ABt +60 TFLOPS by using larger wgmma instructions --- kernels/matmul/H100_mma_ABt/matmul.cu | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index 8d8293e5a..cf2377a75 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -68,13 +68,14 @@ struct matmul_template { zero(args.state.accum[n]); } __device__ static void compute(consumer_compute_args args) { - for(int n = 0; n < N_BLOCK; n++) { - warpgroup::mma_ABt( - args.state.accum[n], - args.input.a[warpgroup::groupid()], - args.input.b[n] - ); - } + using wide_rt = rt_fl<16, 64*N_BLOCK>; + using tall_st = st_bf<64*N_BLOCK, 64>; + // dispatch the largest possible tensor core instruction to maximize TFLOPS (64x16x256 on M_BLOCK=2, N_BLOCK=4) + warpgroup::mma_ABt( + reinterpret_cast(args.state.accum), + args.input.a[warpgroup::groupid()], + reinterpret_cast(args.input.b) + ); warpgroup::mma_async_wait(); if(laneid() == 0) arrive(args.inputs_finished); } @@ -128,7 +129,7 @@ void inner_run(bf16 *d_A, bf16 *d_B, bf16 *d_C, size_t M, size_t N, size_t K, di using globals = typename mmt::layout::globals; // printf("M: %d, N: %d, K: %d\n", M, N, K); global_layout Ag{d_A, nullptr, nullptr, M, K}; - global_layout Bg{d_B, nullptr, nullptr, K, N}; + global_layout Bg{d_B, nullptr, nullptr, N, K}; global_layout Cg{d_C, nullptr, nullptr, M, N}; globals G{Ag, Bg, Cg}; prototype::lcf::kernel<<>>(G); From ace5ba257dea252b977556690c1ec29c28156225 Mon Sep 17 00:00:00 2001 From: Soham Govande Date: Wed, 5 Mar 2025 10:10:41 -0800 Subject: [PATCH 2/6] MMA_ABt works on non-uniform sizes --- kernels/matmul/H100_mma_ABt/matmul.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index cf2377a75..6747ddced 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -281,8 +281,7 @@ int run_benchmark(size_t M, size_t N, size_t K) { } int main() { - int N; - N = 4096; - run_benchmark>(N, N, N); + int M = 2048, N = 4096, K = 8192; + run_benchmark>(M, N, K); return 0; } \ No newline at end of file From f2b081a9f3b467c860fe8b4e66ff77ec381b6286 Mon Sep 17 00:00:00 2001 From: Soham Govande Date: Sun, 15 Jun 2025 10:09:15 -0700 Subject: [PATCH 3/6] . --- kernels/matmul/H100_mma_ABt/matmul.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index 35f0059ee..8dbb94706 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -281,7 +281,8 @@ int run_benchmark(size_t M, size_t N, size_t K) { } int main() { - int M = 2048, N = 4096, K = 8192; - run_benchmark>(M, N, K); + int N; + N = 4096; + run_benchmark>(N, N, N); return 0; } \ No newline at end of file From 3fc1a350c2f3dd636cacf2a713901a7b0c846519 Mon Sep 17 00:00:00 2001 From: Soham Govande Date: Sun, 15 Jun 2025 10:09:37 -0700 Subject: [PATCH 4/6] . --- kernels/matmul/H100_mma_ABt/matmul.cu | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index 8dbb94706..f7f05d5ea 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -68,14 +68,13 @@ struct matmul_template { zero(args.state.accum[n]); } __device__ static void compute(consumer_compute_args args) { - using wide_rt = rt_fl<16, 64*N_BLOCK>; - using tall_st = st_bf<64*N_BLOCK, 64>; - // dispatch the largest possible tensor core instruction to maximize TFLOPS (64x16x256 on M_BLOCK=2, N_BLOCK=4) - warpgroup::mma_ABt( - reinterpret_cast(args.state.accum), - args.input.a[warpgroup::groupid()], - reinterpret_cast(args.input.b) - ); + for(int n = 0; n < N_BLOCK; n++) { + warpgroup::mma_ABt( + args.state.accum[n], + args.input.a[warpgroup::groupid()], + args.input.b[n] + ); + } warpgroup::mma_async_wait(); if(laneid() == 0) arrive(args.inputs_finished); } @@ -129,7 +128,7 @@ void inner_run(bf16 *d_A, bf16 *d_B, bf16 *d_C, size_t M, size_t N, size_t K, di using globals = typename mmt::layout::globals; // printf("M: %d, N: %d, K: %d\n", M, N, K); global_layout Ag{d_A, nullptr, nullptr, M, K}; - global_layout Bg{d_B, nullptr, nullptr, N, K}; + global_layout Bg{d_B, nullptr, nullptr, K, N}; global_layout Cg{d_C, nullptr, nullptr, M, N}; globals G{Ag, Bg, Cg}; prototype::lcf::kernel<<>>(G); From d8c6b60367ea58b670532fdbdbdc7291937a1b31 Mon Sep 17 00:00:00 2001 From: Soham Govande Date: Sun, 15 Jun 2025 10:10:04 -0700 Subject: [PATCH 5/6] . --- kernels/matmul/H100_mma_ABt/matmul.cu | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index f7f05d5ea..8dbb94706 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -68,13 +68,14 @@ struct matmul_template { zero(args.state.accum[n]); } __device__ static void compute(consumer_compute_args args) { - for(int n = 0; n < N_BLOCK; n++) { - warpgroup::mma_ABt( - args.state.accum[n], - args.input.a[warpgroup::groupid()], - args.input.b[n] - ); - } + using wide_rt = rt_fl<16, 64*N_BLOCK>; + using tall_st = st_bf<64*N_BLOCK, 64>; + // dispatch the largest possible tensor core instruction to maximize TFLOPS (64x16x256 on M_BLOCK=2, N_BLOCK=4) + warpgroup::mma_ABt( + reinterpret_cast(args.state.accum), + args.input.a[warpgroup::groupid()], + reinterpret_cast(args.input.b) + ); warpgroup::mma_async_wait(); if(laneid() == 0) arrive(args.inputs_finished); } @@ -128,7 +129,7 @@ void inner_run(bf16 *d_A, bf16 *d_B, bf16 *d_C, size_t M, size_t N, size_t K, di using globals = typename mmt::layout::globals; // printf("M: %d, N: %d, K: %d\n", M, N, K); global_layout Ag{d_A, nullptr, nullptr, M, K}; - global_layout Bg{d_B, nullptr, nullptr, K, N}; + global_layout Bg{d_B, nullptr, nullptr, N, K}; global_layout Cg{d_C, nullptr, nullptr, M, N}; globals G{Ag, Bg, Cg}; prototype::lcf::kernel<<>>(G); From 138f687e95ff82f03b51cb9ace782f6c87142f63 Mon Sep 17 00:00:00 2001 From: Soham Govande Date: Sun, 15 Jun 2025 10:10:15 -0700 Subject: [PATCH 6/6] . --- kernels/matmul/H100_mma_ABt/matmul.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index 8dbb94706..35f0059ee 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -281,8 +281,7 @@ int run_benchmark(size_t M, size_t N, size_t K) { } int main() { - int N; - N = 4096; - run_benchmark>(N, N, N); + int M = 2048, N = 4096, K = 8192; + run_benchmark>(M, N, K); return 0; } \ No newline at end of file