diff --git a/kernels/matmul/H100_mma_ABt/matmul.cu b/kernels/matmul/H100_mma_ABt/matmul.cu index f7f05d5ea..35f0059ee 100644 --- a/kernels/matmul/H100_mma_ABt/matmul.cu +++ b/kernels/matmul/H100_mma_ABt/matmul.cu @@ -68,13 +68,14 @@ struct matmul_template { zero(args.state.accum[n]); } __device__ static void compute(consumer_compute_args args) { - for(int n = 0; n < N_BLOCK; n++) { - warpgroup::mma_ABt( - args.state.accum[n], - args.input.a[warpgroup::groupid()], - args.input.b[n] - ); - } + using wide_rt = rt_fl<16, 64*N_BLOCK>; + using tall_st = st_bf<64*N_BLOCK, 64>; + // dispatch the largest possible tensor core instruction to maximize TFLOPS (64x16x256 on M_BLOCK=2, N_BLOCK=4) + warpgroup::mma_ABt( + reinterpret_cast(args.state.accum), + args.input.a[warpgroup::groupid()], + reinterpret_cast(args.input.b) + ); warpgroup::mma_async_wait(); if(laneid() == 0) arrive(args.inputs_finished); } @@ -128,7 +129,7 @@ void inner_run(bf16 *d_A, bf16 *d_B, bf16 *d_C, size_t M, size_t N, size_t K, di using globals = typename mmt::layout::globals; // printf("M: %d, N: %d, K: %d\n", M, N, K); global_layout Ag{d_A, nullptr, nullptr, M, K}; - global_layout Bg{d_B, nullptr, nullptr, K, N}; + global_layout Bg{d_B, nullptr, nullptr, N, K}; global_layout Cg{d_C, nullptr, nullptr, M, N}; globals G{Ag, Bg, Cg}; prototype::lcf::kernel<<>>(G); @@ -280,8 +281,7 @@ int run_benchmark(size_t M, size_t N, size_t K) { } int main() { - int N; - N = 4096; - run_benchmark>(N, N, N); + int M = 2048, N = 4096, K = 8192; + run_benchmark>(M, N, K); return 0; } \ No newline at end of file