Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions csrc/fused/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ void quant_per_block_int8_cuda(
}

auto input_dtype = input.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
Expand All @@ -492,7 +493,7 @@ void quant_per_block_int8_cuda(

dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread);

QuantInt8Kernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread, true, false, c_type><<<grid, block>>>(
QuantInt8Kernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread, true, false, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
nullptr,
output.data_ptr<int8_t>(),
Expand Down Expand Up @@ -560,6 +561,7 @@ void quant_per_block_int8_cuda(
}

auto input_dtype = input.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
Expand All @@ -574,7 +576,7 @@ void quant_per_block_int8_cuda(

dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread);

QuantInt8Kernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread, false, false, c_type><<<grid, block>>>(
QuantInt8Kernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread, false, false, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
nullptr,
output.data_ptr<int8_t>(),
Expand Down Expand Up @@ -647,6 +649,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda(

auto input_dtype = input.scalar_type();
auto mean_dtype = mean.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type");

Expand All @@ -664,7 +667,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda(

dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread);

QuantInt8Kernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread, false, true, c_type><<<grid, block>>>(
QuantInt8Kernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread, false, true, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
reinterpret_cast<c_type*>(mean.data_ptr()),
output.data_ptr<int8_t>(),
Expand Down Expand Up @@ -734,6 +737,7 @@ void quant_per_warp_int8_cuda(
}

auto input_dtype = input.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
Expand All @@ -749,7 +753,7 @@ void quant_per_warp_int8_cuda(

dim3 block(WARP_BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread);

QuantInt8Kernel<HEAD_DIM, WARP_BLOCK_SIZE, num_pack_per_thread, false, false, c_type><<<grid, block>>>(
QuantInt8Kernel<HEAD_DIM, WARP_BLOCK_SIZE, num_pack_per_thread, false, false, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
nullptr,
output.data_ptr<int8_t>(),
Expand Down Expand Up @@ -817,6 +821,7 @@ void sub_mean_cuda(

auto input_dtype = input.scalar_type();
auto mean_dtype = mean.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type");

Expand All @@ -834,7 +839,7 @@ void sub_mean_cuda(

dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread);

SubMeanKernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread><<<grid, block>>>(
SubMeanKernel<HEAD_DIM, BLOCK_SIZE, num_pack_per_thread><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
reinterpret_cast<c_type*>(mean.data_ptr()),
reinterpret_cast<half*>(output.data_ptr()),
Expand Down Expand Up @@ -900,6 +905,7 @@ void transpose_pad_permute_cuda(

auto input_dtype = input.scalar_type();
auto output_dtype = output.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type");

Expand All @@ -911,7 +917,7 @@ void transpose_pad_permute_cuda(

dim3 block(CTA_SIZE * (HEAD_DIM / 8));

TransposePadPermuteKernel<HEAD_DIM, CTA_SIZE, true, c_type><<<grid, block>>>(
TransposePadPermuteKernel<HEAD_DIM, CTA_SIZE, true, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
reinterpret_cast<c_type*>(output.data_ptr()),
num_tokens,
Expand Down Expand Up @@ -982,9 +988,10 @@ void scale_fuse_quant_cuda(
dim3 block(CTA_SIZE);

auto input_dtype = input.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
MeanScaleKernel<64, false, c_type><<<grid, block>>>(
MeanScaleKernel<64, false, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
reinterpret_cast<int8_t*>(output.data_ptr()),
nullptr,
Expand Down Expand Up @@ -1065,9 +1072,10 @@ void mean_scale_fuse_quant_cuda(
dim3 block(CTA_SIZE);

auto input_dtype = input.scalar_type();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
MeanScaleKernel<64, true, c_type><<<grid, block>>>(
MeanScaleKernel<64, true, c_type><<<grid, block, 0, stream>>>(
reinterpret_cast<c_type*>(input.data_ptr()),
reinterpret_cast<int8_t*>(output.data_ptr()),
reinterpret_cast<float*>(mean.data_ptr()),
Expand Down
13 changes: 9 additions & 4 deletions csrc/qattn/qk_int_sv_f16_cuda_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "../utils.cuh"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "../cp_async.cuh"
Expand Down Expand Up @@ -718,6 +719,7 @@ torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o;
int stride_h_q, stride_h_k, stride_h_v, stride_h_o;
Expand Down Expand Up @@ -819,7 +821,7 @@ torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<half*>(value.data_ptr()),
Expand Down Expand Up @@ -892,6 +894,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o;
int stride_h_q, stride_h_k, stride_h_v, stride_h_o;
Expand Down Expand Up @@ -994,7 +997,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<half*>(value.data_ptr()),
Expand Down Expand Up @@ -1067,6 +1070,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o;
int stride_h_q, stride_h_k, stride_h_v, stride_h_o;
Expand Down Expand Up @@ -1169,7 +1173,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<half*>(value.data_ptr()),
Expand Down Expand Up @@ -1246,6 +1250,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o;
int stride_h_q, stride_h_k, stride_h_v, stride_h_o;
Expand Down Expand Up @@ -1353,7 +1358,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<half*>(value.data_ptr()),
Expand Down
16 changes: 11 additions & 5 deletions csrc/qattn/qk_int_sv_f8_cuda_sm89.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "../utils.cuh"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "../cp_async.cuh"
Expand Down Expand Up @@ -733,6 +734,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -836,7 +838,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<int8_t*>(value.data_ptr()),
Expand Down Expand Up @@ -911,6 +913,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -1014,7 +1017,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<int8_t*>(value.data_ptr()),
Expand Down Expand Up @@ -1099,6 +1102,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tenso
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -1205,7 +1209,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tenso
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<int8_t*>(value.data_ptr()),
Expand Down Expand Up @@ -1285,6 +1289,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query,
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -1391,7 +1396,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query,
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<int8_t*>(value.data_ptr()),
Expand Down Expand Up @@ -1471,6 +1476,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor q
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -1577,7 +1583,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor q
dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K));

kernel_func<<<grid, block, smem_max>>>(
kernel_func<<<grid, block, smem_max, stream>>>(
query.data_ptr<int8_t>(),
key.data_ptr<int8_t>(),
reinterpret_cast<int8_t*>(value.data_ptr()),
Expand Down
9 changes: 6 additions & 3 deletions csrc/qattn/qk_int_sv_f8_cuda_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "../wgmma.cuh"
Expand Down Expand Up @@ -614,6 +615,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -717,7 +719,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize);

dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
kernel<<<grid, NUM_THREADS, sMemSize>>>(
kernel<<<grid, NUM_THREADS, sMemSize, stream>>>(
tma_map_Q,
tma_map_K,
tma_map_V,
Expand Down Expand Up @@ -790,6 +792,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
int stride_bz_v = value.stride(0);
int stride_bz_o = output.stride(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads;
int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o;

Expand Down Expand Up @@ -895,7 +898,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize);

dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size);
kernel<<<grid, NUM_THREADS, sMemSize>>>(
kernel<<<grid, NUM_THREADS, sMemSize, stream>>>(
tma_map_Q,
tma_map_K,
tma_map_V,
Expand All @@ -913,4 +916,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
});

return lse;
}
}