diff --git a/config.py b/config.py index f071b3a03..c77e9b133 100644 --- a/config.py +++ b/config.py @@ -63,12 +63,19 @@ 'h100': 'kernels/torch_scaled/scaled_matmul.cu' } }, + 'group_gemm': { + 'source_files': { + 'h100': 'kernels/group_gemm/group_gemm.cu' + } + }, + } ### WHICH KERNELS DO WE WANT TO BUILD? # (oftentimes during development work you don't need to redefine them all.) # kernels = ['attn', 'mamba2', 'hedgehog', 'fftconv', 'fused_rotary', 'based', 'fused_layernorm'] -kernels = ['fp8_gemm'] +# kernels = ['fp8_gemm'] +kernels = ['group_gemm'] ### WHICH GPU TARGET DO WE WANT TO BUILD FOR? target = 'h100' diff --git a/include/common/util.cuh b/include/common/util.cuh index b7d2a25d4..10105aefc 100644 --- a/include/common/util.cuh +++ b/include/common/util.cuh @@ -285,6 +285,21 @@ struct shared_allocator { */ using tma_allocator = shared_allocator<1024>; using tma_swizzle_allocator = tma_allocator; // swizzled TMA modes require up to 1024 byte alignments :/ + +/* Get CTA ID within a cluster */ +__device__ static inline int3 clusterIdx() { + int3 cluster_idx; + asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(cluster_idx.x)); + asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(cluster_idx.y)); + asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(cluster_idx.z)); + return cluster_idx; +} +__device__ static inline int cluster_ctarank() { + uint32_t ctarank; + asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(ctarank)); + return ctarank; +} + #endif } // namespace kittens \ No newline at end of file diff --git a/kernels/group_gemm/gemm_test.py b/kernels/group_gemm/gemm_test.py new file mode 100644 index 000000000..74482a59d --- /dev/null +++ b/kernels/group_gemm/gemm_test.py @@ -0,0 +1,173 @@ +import torch + +import random +import torch +from typing import List, Tuple + +import thunderkittens +import deep_gemm +from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor +from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout + +torch.manual_seed(42) + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (128 - (n % 128)) % 128 + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int, bf16_input=False) -> \ + Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + alignment = 
get_m_alignment_for_contiguous_layout() + # group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + group_ms = [expected_m_per_group for _ in range(num_groups)] + m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) + + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, group_m in enumerate(group_ms): + actual_end = start + group_m + aligned_end = start + ceil_div(group_m, alignment) * alignment + m_indices[start:actual_end] = i + m_indices[actual_end:aligned_end] = i + ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() + start = aligned_end + ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) + + assert m % 4 == 0, f'TMA alignment error: {m}' + if bf16_input: + return m, x, y, m_indices, out, ref_out + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + return m, x_fp8, y_fp8, m_indices, out, ref_out + + +def test_m_grouped_gemm_contiguous_deepgemm() -> None: + print('Testing grouped contiguous GEMM for deepgemm:') + + for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), (8, 4096, 2048, 7168), + (32, 256, 7168, 4096), (32, 256, 2048, 7168)): + # NOTES: we should mask the unfilled part before calculating difference + m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) + out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + valid_m = (m_indices != -1).sum().item() + print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked_deepgemm() -> None: + print('Testing grouped masked GEMM:') + + for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)): + for k, n in ((7168, 4096), (2048, 7168), ): + # Test correctness + for i in range(10): + x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) + for j in range(num_groups): + diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) + assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, 
expected_m_per_group) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + +def test_m_grouped_gemm_contiguous_tk() -> None: + print('Testing grouped contiguous GEMM for tk:') + + for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), (8, 4096, 2048, 7168), + (32, 256, 7168, 4096), (32, 256, 2048, 7168)): + # NOTES: we should mask the unfilled part before calculating difference + m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) + out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out) + + ht_out = torch.empty_like(out) + thunderkittens.group_gemm( + x_fp8[0], # x_fp8 is a tuple (x, x_inv_s) + y_fp8[0], # y_fp8 is a tuple (y, y_inv_s) + x_fp8[1].transpose(0, 1).contiguous(), # x_inv_s + y_fp8[1], + m_indices, + ht_out) + ht_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), ht_out) + diff = calc_diff(ht_out, ref_out) + assert diff < 0.001, f'TK has big difference with float: {m=}, {k=}, {n=}, {diff:.5f}' + assert torch.allclose(out, ht_out.bfloat16(), atol=1e-3, rtol=1e-2), f'TK has big difference with deepgemm: {m=}, {k=}, {n=}' + + # tk scale shape need to b (k, m) + x_fp8 = (x_fp8[0], x_fp8[1].transpose(0, 1).contiguous()) + def test_func(): + thunderkittens.group_gemm( + x_fp8[0], # x_fp8 is a tuple (x, x_inv_s) + y_fp8[0], # y_fp8 is a tuple (y, y_inv_s) + x_fp8[1], # x_inv_s + y_fp8[1], + m_indices, + ht_out) + + t = bench_kineto(test_func, 'matmul_template', suppress_kineto_output=True) + valid_m = (m_indices != -1).sum().item() + print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + + print() + + +if __name__ == '__main__': + test_m_grouped_gemm_contiguous_deepgemm() + test_m_grouped_gemm_contiguous_tk() diff --git a/kernels/group_gemm/group_gemm.cu b/kernels/group_gemm/group_gemm.cu new file mode 100644 index 000000000..9e7f2971a --- /dev/null +++ b/kernels/group_gemm/group_gemm.cu @@ -0,0 +1,264 @@ +#include "kittens.cuh" +#include "prototype.cuh" +#include "scheduler.cuh" +#include "utils.cuh" + +#include + +#ifdef TORCH_COMPILE +#define TK_COMPILE_GROUP_GEMM +#endif + +using namespace kittens; +using namespace kittens::prototype; +using namespace kittens::prototype::lcf; + +using scale_dtype = float; + + +template +__device__ static inline void mul_add(T &dst, const T &src, const T &other, const V &row_values) { + row_map(dst, src, other, row_values); +} + +template +struct matmul_layout { + // tiles for the quantized inputs + using a_tile = st_fl8_e4m3<64, 128>; + using b_tile = st_fl8_e4m3<128, 128>; + using c_tile = st; + using a_layout = gl; + using b_layout = gl; + using c_layout = gl; + + using index_layout = gl; + // tiles for the dequantized inputs + using a_vec = sv_fl<64>; // scale_a + using scale_a_layout = gl; + using scale_b_layout = gl; + + template using 
accum_tile = rt; + + struct globals { + a_layout A; b_layout B; c_layout C; index_layout index; + scale_a_layout scale_a; scale_b_layout scale_b; + }; + + struct input_block { + a_tile a[M_BLOCK]; b_tile b; + a_vec scale_a_sv[M_BLOCK]; + }; + struct finish_block { + c_tile c[M_BLOCK]; + }; + struct scratch_block { + }; + struct common_state { + uint32_t block_m_idx, block_n_idx; + uint32_t group_idx; + scale_dtype scale_b; // scale_b is a single value for a block tile 128x128 + bool is_tma_multicast_valid; + }; + struct consumer_state { + accum_tile accum;// Changed to single tall accumulator + }; +}; + + +template +struct matmul_template { + static_assert(_M_BLOCK <= 2, "only support _M_BLOCK<=2"); + static constexpr int M_BLOCK = _M_BLOCK, SUPER_M = _SUPER_M, CLUSTER_BLOCKS = 2; + using layout = matmul_layout; + static constexpr int NUM_CONSUMER_WARPS=M_BLOCK*4, INPUT_PIPE_STAGES=4, PRODUCER_BARRIER_ARRIVALS=1; + // Helper functions + template __host__ static inline dim3 grid(int M, int N, int K) { + return dim3(PERISISTENT_GRID ? 132 : M*N/(M_BLOCK*layout::c_tile::num_elements)); + } + // ThunderKittens template functions + __device__ static inline void common_setup(common_setup_args args, bool is_prepared = false) { + if (is_prepared) { + return; + } + + auto my_scheduler = deep_gemm::Scheduler(M_BLOCK*layout::c_tile::rows), + static_cast(layout::c_tile::cols), 2, false>( + static_cast(args.globals.C.rows()), + static_cast(args.globals.C.cols()), + static_cast(args.globals.B.depth()), + args.globals.index.raw_ptr); + bool is_valid = my_scheduler.get_next_block( + args.common.block_m_idx, args.common.block_n_idx, args.task_iter + ); + if (is_valid) { + args.num_iters = args.globals.A.cols() / layout::a_tile::cols; + int id = warpgroup::groupid() == NUM_CONSUMER_WARPS/4 ? 
0 : warpgroup::groupid(); + args.common.is_tma_multicast_valid = my_scheduler.is_tma_multicast_valid(args.common.block_m_idx); + args.common.block_m_idx = args.common.block_m_idx * M_BLOCK + id; + args.common.group_idx = args.globals.index.raw_ptr[args.common.block_m_idx * layout::c_tile::rows]; + } else { + args.num_iters = -1; // No more work to do + return; + } + } + + struct producer { + __device__ static void setup(producer_setup_args args) { + warpgroup::decrease_registers<40>(); // decrease registers for producers + } + __device__ static void load(producer_load_cluster_args args) { + if(warpgroup::warpid() == 0) { + tma::expect(args.inputs_arrived, args.input.a, args.input.scale_a_sv); + #pragma unroll + for(int i = 0; i < M_BLOCK; i++) { + tma::load_async(args.input.a[i], args.globals.A, + {args.common.block_m_idx+i, args.iter}, args.inputs_arrived); + tma::load_async(args.input.scale_a_sv[i], args.globals.scale_a, {args.iter, args.common.block_m_idx+i}, args.inputs_arrived); + } + + if (args.common.is_tma_multicast_valid) { + if (cluster_ctarank() == 0) { + tma::cluster::expect(args.inputs_cluster_arrived, 0, args.input.b); + tma::cluster::expect(args.inputs_cluster_arrived, 1, args.input.b); + tma::cluster::load_async(args.input.b, args.globals.B, + {args.common.group_idx, args.common.block_n_idx, args.iter}, args.inputs_cluster_arrived, 0b0011); + } + } else { + tma::expect(args.inputs_cluster_arrived, args.input.b); + tma::load_async(args.input.b, args.globals.B, + {args.common.group_idx, args.common.block_n_idx, args.iter}, args.inputs_cluster_arrived); + } + } + } + }; + + + struct consumer { + __device__ static void setup(consumer_setup_args args, bool is_prepared = true, int iter = 0) { + if (is_prepared) { + warpgroup::increase_registers<232>(); // increase registers for consumers + zero(args.state.accum); + } else { + args.common.scale_b = args.globals.scale_b[{ + args.common.group_idx, + args.common.block_n_idx, + iter + }]; + } + } + __device__ static void compute(consumer_compute_cluster_args args) { + rt_fl<16, layout::c_tile::cols> accum_tmp; + warpgroup::mm_ABt( + accum_tmp, + args.input.a[warpgroup::groupid()], + args.input.b + ); + col_vec> scale_a_rv; + warpgroup::load(scale_a_rv, args.input.scale_a_sv[warpgroup::groupid()]); + mul(scale_a_rv, scale_a_rv, args.common.scale_b); + warpgroup::mma_async_wait(); + if(laneid() == 0) { + arrive(args.inputs_finished); + tma::cluster::arrive(args.inputs_used, 0, 1); + } + mul_add(args.state.accum, accum_tmp, args.state.accum, scale_a_rv); + + } + __device__ static void finish(consumer_finish_args args) { + warpgroup::store(args.finish.c[warpgroup::groupid()], args.state.accum); + warpgroup::sync(warpgroup::groupid()+4); + if(warpgroup::warpid() == 0) { + tma::store_async(args.globals.C, args.finish.c[warpgroup::groupid()], + {args.common.block_m_idx, args.common.block_n_idx}); + tma::store_async_read_wait(); + } + if(laneid() == 0) arrive(args.finish_finished); + } + }; +}; +template +void inner_run( + fp8e4m3 *d_A, fp8e4m3 *d_B, bf16 *d_C, int* index, + scale_dtype *d_scale_a, scale_dtype *d_scale_b, + size_t M, size_t N, size_t K, size_t groups, + dim3 grid, dim3 block +) { + using a_layout = typename mmt::layout::a_layout; + using b_layout = typename mmt::layout::b_layout; + using c_layout = typename mmt::layout::c_layout; + using index_layout = typename mmt::layout::index_layout; + using globals = typename mmt::layout::globals; + a_layout Ag{d_A, nullptr, nullptr, M, K}; + b_layout Bg{d_B, nullptr, groups, N, K}; + 
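    // Note on the global-layout constructors in this function: the braced arguments follow the
    // ThunderKittens gl<> convention of {ptr, batch, depth, rows, cols}, with nullptr marking
    // dimensions fixed at compile time (assumed from the library, not restated in this patch).
    // B and scale_b carry the expert/group count in their depth dimension, which is why the
    // kernel indexes them with a three-component coordinate {group_idx, block_n_idx, iter}.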
c_layout Cg{d_C, nullptr, nullptr, M, N}; + index_layout index_g{index, nullptr, nullptr, nullptr, M}; + + // scales + using scale_a_layout = typename mmt::layout::scale_a_layout; + using scale_b_layout = typename mmt::layout::scale_b_layout; + scale_a_layout scale_a{d_scale_a, nullptr, nullptr, K/128, M}; + scale_b_layout scale_b{d_scale_b, nullptr, groups, N/128, K/128}; + + globals G{Ag, Bg, Cg, index_g, scale_a, scale_b}; + prototype::lcf::cluster_kernel<<>>(G); +} + + +#ifdef TK_COMPILE_GROUP_GEMM +#include +#include "pyutils/torch_helpers.cuh" + +// A: M x K +// B: GROUP x N x K +// scale_a: K/128 x M +// scale_b: GROUP x N/128 x K/128 +// index: M x 1, M[i] is the index of the group that A[i] belongs to B[M[i]], C[i]= A[i] @ B[M[i]] +// Returns: C: M x N +torch::Tensor& group_gemm( + const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& scale_a, + const torch::Tensor& scale_b, const torch::Tensor& index, torch::Tensor& C) { + CHECK_INPUT(A); + CHECK_INPUT(B); + CHECK_INPUT(scale_a); + CHECK_INPUT(scale_b); + + auto M = A.size(0); + auto N = B.size(1); + auto K = A.size(1); + auto groups = B.size(0); + TORCH_CHECK(scale_a.size(1) == M, "scale_a must have the row of A.size(0)"); + TORCH_CHECK(scale_a.size(0) == K/128, "scale_a must have the row of A.size(1)/128"); + + TORCH_CHECK(scale_b.size(1) == N/128, "scale_b must have the row of B.size(0)/128"); + TORCH_CHECK(scale_b.size(2) == K/128, "scale_b must have the row of B.size(1)/128"); + TORCH_CHECK(B.size(2) == K, "B must have the same number of columns as A"); + TORCH_CHECK(M % 128 == 0, "M must be divisible by 128"); + TORCH_CHECK(N % 128 == 0, "N must be divisible by 128"); + TORCH_CHECK(K % 128 == 0, "K must be divisible by 128"); + TORCH_CHECK(A.dtype() == c10::ScalarType::Float8_e4m3fn, "A must have the same dtype as Float8_e4m3fn"); + TORCH_CHECK(B.dtype() == c10::ScalarType::Float8_e4m3fn, "B must have the same dtype as Float8_e4m3fn"); + TORCH_CHECK(scale_a.dtype() == c10::ScalarType::Float, "scale_a must have the same dtype as A"); + TORCH_CHECK(scale_b.dtype() == c10::ScalarType::Float, "scale_b must have the same dtype as B"); + TORCH_CHECK(index.dtype() == c10::ScalarType::Int, "index must have the same dtype as Int"); + TORCH_CHECK(C.dtype() == c10::ScalarType::BFloat16, "C must have the same dtype as BFloat16"); + + c10::Float8_e4m3fn *A_fp8 = A.data_ptr(); + c10::Float8_e4m3fn *B_fp8 = B.data_ptr(); + + fp8e4m3 *d_A = reinterpret_cast(A_fp8); + fp8e4m3 *d_B = reinterpret_cast(B_fp8); + int *index_ptr = index.data_ptr(); + scale_dtype *d_scale_a = scale_a.data_ptr(); + scale_dtype *d_scale_b = scale_b.data_ptr(); + bf16 *d_C = reinterpret_cast(C.data_ptr()); + using mnt = matmul_template<2, 8>; + dim3 grid(mnt::grid(M, N, K)); + dim3 block(kittens::prototype::detail::NUM_THREADS_v); + unsigned long mem_size = MAX_SHARED_MEMORY - 1024; + cudaFuncSetAttribute(prototype::lcf::cluster_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size); + inner_run(d_A, d_B, d_C, index_ptr, d_scale_a, d_scale_b, M, N, K, groups, grid, block); + CHECK_CUDA_ERROR(cudaGetLastError()); + return C; +} +#else +#endif diff --git a/kernels/group_gemm/scheduler.cuh b/kernels/group_gemm/scheduler.cuh new file mode 100644 index 000000000..b43450904 --- /dev/null +++ b/kernels/group_gemm/scheduler.cuh @@ -0,0 +1,167 @@ +// Copy from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947/deep_gemm/include/deep_gemm/scheduler.cuh + +#pragma once + +#include "utils.cuh" + +namespace 
deep_gemm { + +enum class GemmType { + Normal, + GroupedContiguous, + GroupedMasked +}; + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template +struct Scheduler { + // int current_iter = -1; + // uint32_t SHAPE_N; + uint32_t kNumNBlocks; + uint32_t kNumGroups; + + uint32_t num_aligned_m_blocks; + + // For normal GEMM + // Maybe not used in the masked grouped GEMM + uint32_t num_blocks; + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + + // Only used for masked layout + uint32_t curr_group_idx, curr_cumsum; + + __device__ __forceinline__ explicit Scheduler(uint32_t shape_m, uint32_t shape_n, uint32_t group_count = 1, + int* grouped_layout = nullptr) { + kNumNBlocks = ceil_div(shape_n, BLOCK_N); + kNumGroups = group_count; + num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M); + if constexpr (kGemmType == GemmType::Normal) { + num_blocks = num_aligned_m_blocks * kNumNBlocks; + } else if (kGemmType == GemmType::GroupedContiguous) { + num_blocks = num_aligned_m_blocks * kNumNBlocks; + this->grouped_layout = grouped_layout; + } else if (kGemmType == GemmType::GroupedMasked) { + curr_group_idx = curr_cumsum = 0; + this->grouped_layout = grouped_layout; + } + } + + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal) { + return true; + } else if constexpr (kGemmType == GemmType::GroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::GroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx); + } + } + + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsTMAMulticastOnA) { + return true; + } else { + auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx, + uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks; + auto secondary_num_blocks = kIsTMAMulticastOnA ? 
num_m_blocks : kNumNBlocks; + auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } + + // Convert to final M/N block indices + if constexpr (kIsTMAMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx=0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::GroupedContiguous) { + auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)); + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::GroupedMasked) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx, int current_iter) { + const auto next_block_idx = current_iter * gridDim.x + blockIdx.x; + + if constexpr (kGemmType == GemmType::GroupedMasked) { + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + // Within the current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) + break; + + // Move to check the next group + curr_group_idx ++, curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + } else { + if (next_block_idx >= num_blocks) + return false; + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass) + num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); + } + return true; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/kernels/group_gemm/utils.cuh b/kernels/group_gemm/utils.cuh new file mode 100644 index 000000000..4af097400 --- /dev/null +++ b/kernels/group_gemm/utils.cuh @@ -0,0 +1,36 @@ +// Copy from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947/deep_gemm/include/deep_gemm/utils.cuh + +#pragma once + +#ifdef __CLION_IDE__ + +__host__ __device__ 
__forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) +#endif + +template +__device__ __host__ constexpr T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} diff --git a/prototype/lcf/lcf_cluster.cuh b/prototype/lcf/lcf_cluster.cuh new file mode 100644 index 000000000..7cda0e554 --- /dev/null +++ b/prototype/lcf/lcf_cluster.cuh @@ -0,0 +1,181 @@ +#pragma once + +#include "../include/kittens.cuh" +#include "../common/common.cuh" +#include "templates.cuh" + +namespace kittens { +namespace prototype { +namespace lcf { + +template concept cluster_kernel_template = requires { + typename lcft::layout; + typename lcft::producer; + typename lcft::consumer; + lcft::common_setup; + lcft::producer::setup; + lcft::producer::load; + lcft::consumer::setup; + lcft::consumer::compute; + lcft::consumer::finish; +} && kittens_layout; + +template // load-compute-store-finish template +__global__ __launch_bounds__(detail::NUM_THREADS_v, detail::NUM_BLOCKS_v) +__cluster_dims__(detail::CLUSTER_BLOCKS_v) +void cluster_kernel(const __grid_constant__ typename lcft::layout::globals globals) { + static_assert(cluster_kernel_template, "lcf kernel template parameter does not satisfy concept requirements"); + using L = typename lcft::layout; + using CKL = complete_kittens_layout; // complete the layout by filling in the optional types with empty + using common_state = typename CKL::common_state_t; + using producer_state = typename CKL::producer_state_t; + using consumer_state = typename CKL::consumer_state_t; + using input_block = typename CKL::input_block_t; + using scratch_block = typename CKL::scratch_block_t; + using finish_block = typename CKL::finish_block_t; + using input_alloc_block = typename CKL::input_alloc_block_t; + using scratch_alloc_block = typename CKL::scratch_alloc_block_t; + constexpr int MAX_SHARED_MEMORY = detail::MAX_SHARED_MEMORY_v; + constexpr int INPUT_PIPE_STAGES = detail::INPUT_PIPE_STAGES_v; + static_assert(INPUT_PIPE_STAGES >= 1 && INPUT_PIPE_STAGES <= 16, "Invalid number of input pipe stages"); + static_assert( + INPUT_PIPE_STAGES*sizeof(input_alloc_block) + sizeof(scratch_alloc_block) + <= MAX_SHARED_MEMORY-1024, "Shared memory usage exceeds limits" + ); + constexpr int NUM_CONSUMER_WARPS = detail::NUM_CONSUMER_WARPS_v; + constexpr int NUM_PRODUCER_WARPS = detail::NUM_PRODUCER_WARPS_v; + + using everyone = group>; + + extern __shared__ int __shm[]; + shared_allocator alloc(&__shm[0]); // allocate shared memory + scratch_alloc_block (&scratch_smem) = alloc.allocate(); + input_alloc_block (&input_smem) [INPUT_PIPE_STAGES] = alloc.allocate(); + + // figure out where we're going to put the finish block + constexpr int FINISH_BLOCK_OFFSET = (MAX_SHARED_MEMORY-1024)/detail::NUM_BLOCKS_v - sizeof(finish_block); + static_assert(FINISH_BLOCK_OFFSET >= 0, "Finish block is too large for shared memory."); + constexpr int NON_FINISH_BLOCK_SPACE = FINISH_BLOCK_OFFSET - 1024 - sizeof(scratch_alloc_block); // including the losses from alignment + constexpr int SAFE_STAGES_BETWEEN_BLOCKS = 
NON_FINISH_BLOCK_SPACE/sizeof(input_alloc_block)((((uint64_t)&__shm[0] + FINISH_BLOCK_OFFSET)/1024)*1024); // alignment + + if constexpr (detail::DEBUG_v) { + if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + printf("DEBUG REPORT FOR PRODUCER TEMPLATE KERNEL:\n"); + printf(" BLOCK INFORMATION\n"); + printf(" gridDim.x: %d\n", gridDim.x); + printf(" gridDim.y: %d\n", gridDim.y); + printf(" gridDim.z: %d\n", gridDim.z); + printf(" blockDim.x: %d\n", blockDim.x); + printf(" blockDim.y: %d\n", blockDim.y); + printf(" blockDim.z: %d\n", blockDim.z); + printf(" num_blocks per SM: %d\n", detail::NUM_BLOCKS_v); + printf(" num_threads per SM: %d\n", detail::NUM_THREADS_v); + printf(" num_warps per SM: %d\n", detail::NUM_WARPS_v); + printf(" num_consumer_warpgroups: %d\n", detail::NUM_CONSUMER_WARPGROUPS_v); + printf(" num_consumer_warps: %d\n", detail::NUM_CONSUMER_WARPS_v); + printf(" num_producer_warps: %d\n", detail::NUM_PRODUCER_WARPS_v); + printf(" PIPELINE INFORMATION\n"); + printf(" input_pipe_stages: %d\n", INPUT_PIPE_STAGES); + printf(" safe_stages_between_blocks: %d\n", SAFE_STAGES_BETWEEN_BLOCKS); + printf(" SHARED MEMORY INFORMATION\n"); + printf(" input_smem block size: %llu\n", sizeof(input_block)); + printf(" input_smem block size (aligned): %llu\n", sizeof(input_alloc_block)); + printf(" input_smem: %p\n", (void*)&input_smem); + printf(" input_smem size: %llu\n", INPUT_PIPE_STAGES*sizeof(input_alloc_block)); + printf(" scratch_smem block size: %llu\n", sizeof(scratch_block)); + printf(" scratch_smem block size (aligned): %llu\n", sizeof(scratch_alloc_block)); + printf(" scratch_smem: %p\n", (void*)&scratch_smem); + printf(" finish_smem: %p\n", (void*)finish_smem); + printf(" finish_smem size: %llu\n", sizeof(finish_block)); + printf(" dynamic shared memory usage: %llu\n", sizeof(scratch_alloc_block) + uint64_t(&scratch_smem) - uint64_t(&__shm[0])); + } + everyone::sync(15); + } + + // Initialize semaphores. This is constant for all two-stage producer-consumer kernels. + __shared__ kittens::semaphore inputs_arrived[INPUT_PIPE_STAGES], inputs_finished[INPUT_PIPE_STAGES], inputs_cluster_arrived[INPUT_PIPE_STAGES], inputs_used[INPUT_PIPE_STAGES]; + __shared__ kittens::semaphore finish_finished; + uint32_t semaphore_bitfield = 0xFFFF0000; // ***_finished phase bits start as 1s, ***_arrived phase bits start as 0s + common_state common; + int num_iters = -1; // number of iterations for the current task + int task_iter = -1; // which task are we on? + common_setup_args unif{common, task_iter, num_iters, globals, *scratch_smem}; + lcft::common_setup(unif, true); + + if(warpid() >= NUM_CONSUMER_WARPS) { // code path for producer warps + using producers = group; + if (warpid() == NUM_CONSUMER_WARPS) { // a single warp (in fact a single thread) does these. 
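            // Four semaphore arrays per pipeline stage: inputs_arrived / inputs_finished form the
            // usual per-CTA handshake for the A tiles and their scales; inputs_cluster_arrived
            // signals that the (possibly multicast) B tile has landed; inputs_used is the
            // cluster-wide barrier, armed with twice the consumer arrivals and waited on only by
            // cluster CTA rank 0, that keeps the shared B stage from being overwritten before
            // both CTAs have consumed it.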
+ for(int i = 0; i < INPUT_PIPE_STAGES; i++) { + init_semaphore(inputs_arrived[i], detail::PRODUCER_BARRIER_ARRIVALS_v, 0); // needs to wait on each producer warp + init_semaphore(inputs_cluster_arrived[i], detail::PRODUCER_BARRIER_ARRIVALS_v, 0); // needs to wait on each producer warp + init_semaphore(inputs_finished[i], detail::CONSUMER_BARRIER_ARRIVALS_v, 0); // needs to wait on one thread from each consumer warp + init_semaphore(inputs_used[i], 2*detail::CONSUMER_BARRIER_ARRIVALS_v, 0); // needs to wait on one thread from each consumer warp + } + init_semaphore(finish_finished, detail::CONSUMER_BARRIER_ARRIVALS_v, 0); // consumer warps must say they are done with the finish block + } + // all warps must arrive here, confirming semaphore initialization is visible to all threads. + tma::cluster::sync(); + for(task_iter = 0; true; task_iter++) { + num_iters = -1; + common_setup_args unif{common, task_iter, num_iters, globals, *scratch_smem}; + lcft::common_setup(unif); + if(num_iters <= 0) return; // no work to do + int input_ring = 0; // tracking which input block is being loaded + int load_iter; + producer_state p_state; + lcft::producer::setup({p_state, unif}); + for(load_iter = 0; load_iter < SAFE_STAGES_BETWEEN_BLOCKS && load_iter(semaphore_bitfield, input_ring)); + if (cluster_ctarank() == 0) + wait(inputs_used[input_ring], get_phasebit<1>(semaphore_bitfield, input_ring)); + update_phasebit<1>(semaphore_bitfield, input_ring); + lcft::producer::load({p_state, *input_smem[input_ring], inputs_arrived[input_ring], inputs_cluster_arrived[input_ring], load_iter, unif}); + input_ring=ring_advance(input_ring); + } + wait(finish_finished, (task_iter%2)^1); // wait for consumer to finish their finish stage before we can do the rest. + for(; load_iter(semaphore_bitfield, input_ring)); + if (cluster_ctarank() == 0) { + wait(inputs_used[input_ring], get_phasebit<1>(semaphore_bitfield, input_ring)); + } + update_phasebit<1>(semaphore_bitfield, input_ring); + lcft::producer::load({p_state, *input_smem[input_ring], inputs_arrived[input_ring], inputs_cluster_arrived[input_ring], load_iter, unif}); + input_ring=ring_advance(input_ring); + } + producers::sync(13); // producer warps must finish before consumer warps can proceed + } // task iter loop + } // producer warpgroup + else { // code path for consumer warps + using consumers = group; + tma::cluster::sync(); + for(task_iter = 0; true; task_iter++) { + num_iters = -1; + common_setup_args unif{common, task_iter, num_iters, globals, *scratch_smem}; + lcft::common_setup(unif); + if(num_iters <= 0) return; // no work to do + int input_ring = 0; // tracking which input block is being loaded + consumer_state c_state; + lcft::consumer::setup({c_state, unif}); // setup consumer state, true means we are not prepared, so we need to do the setup +#ifdef CONSUMER_UNROLL + #pragma unroll CONSUMER_UNROLL_VALUE +#endif + for(int it = 0; it < num_iters; it++) { + lcft::consumer::setup({c_state, unif}, false, it); // setup consumer state for the current iteration + wait(inputs_arrived[input_ring], get_phasebit<0>(semaphore_bitfield, input_ring)); // wait for memory to arrive, phase changes at half the rate of the ring + wait(inputs_cluster_arrived[input_ring], get_phasebit<0>(semaphore_bitfield, input_ring)); // wait for memory to arrive, phase changes at half the rate of the ring + update_phasebit<0>(semaphore_bitfield, input_ring); + lcft::consumer::compute({c_state, *input_smem[input_ring], inputs_finished[input_ring], inputs_used[input_ring], it, unif}); + 
input_ring=ring_advance(input_ring); + } // work loop + consumers::sync(14); // cannot overwrite finish block until all consumer warps are done. + lcft::consumer::finish({c_state, *finish_smem, finish_finished, unif}); + consumers::sync(14); // cannot overwrite finish block until all consumer warps are done. + } // task iter loop + } // consumer warpgroup + tma::cluster::sync(); +} + +} // namespace lcf +} // namespace prototype +} // namespace kittens \ No newline at end of file diff --git a/prototype/lcf/templates.cuh b/prototype/lcf/templates.cuh index 61a396944..5e07c888a 100644 --- a/prototype/lcf/templates.cuh +++ b/prototype/lcf/templates.cuh @@ -62,6 +62,25 @@ template struct producer_load_args : uniform_args { ) : uniform_args(_args), input(_input), state(_state), inputs_arrived(_inputs_arrived), iter(_iter) {} }; + +// Producer load cluster args +template struct producer_load_cluster_args : uniform_args { + using CKL = complete_kittens_layout; + typename CKL::producer_state_t & state; + typename CKL::input_block_t & input; + kittens::semaphore & inputs_arrived; + kittens::semaphore & inputs_cluster_arrived; + int iter; + __device__ producer_load_cluster_args( + typename CKL::producer_state_t& _state, + typename CKL::input_block_t& _input, + semaphore& _inputs_arrived, + semaphore& _inputs_cluster_arrived, + int _iter, + uniform_args &_args + ) : uniform_args(_args), input(_input), state(_state), inputs_arrived(_inputs_arrived), inputs_cluster_arrived(_inputs_cluster_arrived), iter(_iter) {} +}; + // Consumer init args template struct consumer_setup_args : uniform_args { using CKL = complete_kittens_layout; @@ -88,6 +107,24 @@ template struct consumer_compute_args : uniform_args { ) : uniform_args(_args), input(_input), state(_state), inputs_finished(_inputs_finished), iter(_iter) {} }; +// Consumer compute cluster args +template struct consumer_compute_cluster_args : uniform_args { + using CKL = complete_kittens_layout; + typename CKL::consumer_state_t & state; + typename CKL::input_block_t & input; + kittens::semaphore & inputs_finished; + kittens::semaphore & inputs_used; + int iter; + __device__ consumer_compute_cluster_args( + typename CKL::consumer_state_t& _state, + typename CKL::input_block_t& _input, + semaphore& _inputs_finished, + semaphore& _inputs_used, + int _iter, + uniform_args &_args + ) : uniform_args(_args), input(_input), state(_state), inputs_finished(_inputs_finished), inputs_used(_inputs_used), iter(_iter) {} +}; + // Consumer finish args template struct consumer_finish_args : uniform_args { using CKL = complete_kittens_layout; diff --git a/prototype/prototype.cuh b/prototype/prototype.cuh index 47ac67dc6..bb4b50fa1 100644 --- a/prototype/prototype.cuh +++ b/prototype/prototype.cuh @@ -9,5 +9,6 @@ #include "common/common.cuh" #include "lcf/lcf.cuh" +#include "lcf/lcf_cluster.cuh" #include "lcsf/lcsf.cuh" #include "interpreter/interpreter.cuh" \ No newline at end of file diff --git a/thunderkittens.cpp b/thunderkittens.cpp index 9517e5938..525465a6a 100644 --- a/thunderkittens.cpp +++ b/thunderkittens.cpp @@ -142,6 +142,17 @@ extern torch::Tensor scaled_matmul( ); #endif +#ifdef TK_COMPILE_GROUP_GEMM +extern torch::Tensor& group_gemm( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scale_a, + const torch::Tensor& scale_b, + const torch::Tensor& index, + torch::Tensor& c +); +#endif + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "ThunderKittens Kernels"; // optional module docstring @@ -192,4 +203,8 @@ 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("scaled_matmul", scaled_matmul, "Scaled Matmul TK. Takes tensors (a, b, scale_a, scale_b). a, b are fp8e4m3, scale_a, scale_b are float. Returns (M, N) in float."); #endif +#ifdef TK_COMPILE_GROUP_GEMM + m.def("group_gemm", group_gemm, "Group GEMM TK. Takes tensors (a, b, scale_a, scale_b, index). a, b are fp8e4m3, scale_a (K/128 x M) and scale_b (groups x N/128 x K/128) are float, index is int32. Returns (M, N) in bf16."); +#endif + } \ No newline at end of file
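For reference, a minimal end-to-end call of the new thunderkittens.group_gemm binding, distilled from kernels/group_gemm/gemm_test.py and the shape/dtype checks in group_gemm.cu. The sizes, the all-ones scales, and the evenly split m_indices below are illustrative assumptions rather than values taken from the test; real inputs would come from per_token_cast_to_fp8 / per_block_cast_to_fp8 as in the test, and a PyTorch build with float8_e4m3fn support is assumed (the test already requires it).

import torch
import thunderkittens

# Illustrative sizes: the kernel checks require M, N, K divisible by 128, and each
# group's rows must be aligned to the 128-row block tile (here 256 rows per group).
num_groups, m, k, n = 4, 1024, 2048, 7168

a = torch.randn(m, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
b = torch.randn(num_groups, n, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# Per-token scales for A, laid out as (K/128, M): note the transpose relative to
# per_token_cast_to_fp8, which returns (M, K/128) (see the .transpose(0, 1).contiguous()
# call in test_m_grouped_gemm_contiguous_tk). Per-block scales for B: (groups, N/128, K/128).
scale_a = torch.ones(k // 128, m, device='cuda', dtype=torch.float32)
scale_b = torch.ones(num_groups, n // 128, k // 128, device='cuda', dtype=torch.float32)

# Contiguous grouped layout: rows of A are sorted by group and index[i] names the group,
# so that C[i] = A[i] @ B[index[i]].T
m_indices = torch.arange(num_groups, device='cuda', dtype=torch.int32).repeat_interleave(m // num_groups)

c = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)
thunderkittens.group_gemm(a, b, scale_a, scale_b, m_indices, c)  # fills c in place and returns it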