diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 26b701fea6fbb..4f7dd8c11e655 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -937,6 +937,7 @@ Do not modify directly.* |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)| diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 375f0a4dc8dd2..e59a803d97629 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -305,6 +305,24 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)(); static_cast([](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \ } +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name) \ + provider##_##name##_##domain##_ver##ver##_##type1##_##type2##_##type3 + +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \ + class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name) \ + .SetDomain(domain) \ + .SinceVersion(ver) \ + .Provider(provider) \ + .Build(), \ + static_cast([](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \ + } + #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \ provider##_##name##_##domain##_ver##startver##_##endver##_##type diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index a50ee907c302b..36d6fc378d45e 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -15,6 +15,10 @@ using namespace onnxruntime::common; ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name) #define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \ ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name) +#define CUDA_MS_OP_TWO_TYPED_CLASS_NAME(ver, type1, type2, name) \ + ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, name) +#define CUDA_MS_OP_THREE_TYPED_CLASS_NAME(ver, type1, type2, type3, name) \ + ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, type3, name) #define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \ ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name) @@ -186,6 +190,25 @@ class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized); + #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif @@ -408,6 +431,24 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc new file mode 100644 index 0000000000000..bad44b260b7b2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/quantization/gather_block_quantized.h" +#include "contrib_ops/cuda/quantization/gather_block_quantized.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +#define REGISTER_GATHERBLOCKQUANTIZED(T1, T2, Tind) \ + ONNX_OPERATOR_THREE_TYPED_KERNEL_EX( \ + GatherBlockQuantized, \ + kMSDomain, 1, \ + T1, T2, Tind, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + GatherBlockQuantized); + +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int64_t); + +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int64_t); + +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int64_t); + +template +GatherBlockQuantized::GatherBlockQuantized(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(info.GetAttr("bits", &bits_).IsOK()); + + block_size_ = info.GetAttrOrDefault("block_size", 0); + gather_axis_ = info.GetAttrOrDefault("gather_axis", 0); + quantize_axis_ = info.GetAttrOrDefault("quantize_axis", 0); + + // If block size is set, it has to be no smaller than 16 and must be power of 2 + // block_size_ & (block_size_ - 1) == 0 checks if block_size_ only has 1 bit set + ORT_ENFORCE(block_size_ == 0 || (block_size_ >= 16 && ((block_size_ & (block_size_ - 1)) == 0))); +} + +template +Status GatherBlockQuantized::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* data = ctx->Input(0); + const Tensor* indices = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + auto data_shape = data->Shape().GetDims(); + auto data_rank = data->Shape().NumDimensions(); + + auto indices_shape = indices->Shape().GetDims(); + auto indices_rank = indices->Shape().NumDimensions(); + + ORT_ENFORCE(quantize_axis_ == data_rank - 1); + + TensorShapeVector output_shape; + output_shape.reserve(data_rank - 1 + indices_rank); + + // Dimension after gather axis + int64_t after_gather_dim = 1; + + // Dimension of indices + int64_t ind_dim = 1; + + // 1) dims before gather_axis + for (int64_t i = 0; i < gather_axis_; ++i) { + output_shape.push_back(data_shape[i]); + } + + // 2) all of indices.shape + for (auto dim : indices_shape) { + output_shape.push_back(dim); + ind_dim *= dim; + } + + // 3) dims after gather_axis + for (int64_t i = gather_axis_ + 1; i < data_rank; ++i) { + output_shape.push_back(data_shape[i]); + after_gather_dim *= data_shape[i]; + } + + // Special int4‐in‐uint8 packing tweak: expand the last dim by components + if constexpr (std::is_same_v) { + uint32_t components = 8 / static_cast(bits_); + if (components > 1) { + output_shape.back() *= components; + } + } + + Tensor* output = ctx->Output(0, TensorShape(output_shape)); + + int64_t N = 1; + for (auto dim : output_shape) { + N *= dim; + } + + const auto* data_ptr = data->Data(); + const auto* indices_ptr = indices->Data(); + const T1* zero_points_ptr = nullptr; + if (zero_points != nullptr) { + zero_points_ptr = zero_points->Data(); + } + + GatherBlockQuantizedParam param; + param.stream = Stream(ctx); + param.after_gather_dim = after_gather_dim; + param.gather_axis_dim = data_shape[gather_axis_]; + param.ind_dim = ind_dim; + param.bits = bits_; + param.block_size = block_size_; + param.gather_axis = gather_axis_; + param.N = N; + + const auto dequantized_type = scales->GetElementType(); + if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu new file mode 100644 index 0000000000000..39286c63e9a08 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "gather_block_quantized.cuh" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__device__ inline int64_t get_val(const T1* data, int64_t idx, int64_t bits, bool sign) { + const uint32_t mask = (1U << bits) - 1; + const int64_t elems_per_byte = 8 / bits; + const int64_t byte_idx = idx / elems_per_byte; + const int64_t bit_offset = (idx % elems_per_byte) * bits; + const uint8_t byte = reinterpret_cast(data)[byte_idx]; + int64_t val = (byte >> bit_offset) & mask; + + // Sign-extend based on bit width + if (sign) { + if (val & (1 << (bits - 1))) { + val |= -1LL << bits; + } + } + + return val; +} + +template +__global__ void GatherBlockQuantizedKernel( + const T1* data, // packed 4-bit codes, one code per element + const Tind* indices, + const T2* scales, // one float scale per block + const T1* zero_points, // packed 4-bit zero-points, one per block + T2* output, + int64_t after_gather_dim, + int64_t gather_axis_dim, + int64_t ind_dim, + int64_t bits, + int64_t block_size, + int64_t gather_axis, + int64_t N, + bool sign) { + int64_t out_idx = blockDim.x * blockIdx.x + threadIdx.x; + if (out_idx >= N) return; + + // compute which input element this thread corresponds to: + int64_t idx_before = out_idx / (after_gather_dim * ind_dim); + int64_t idx_after = out_idx % after_gather_dim; + int64_t idx = (out_idx % (after_gather_dim * ind_dim)) / after_gather_dim; + int64_t idx_at_g = indices[idx]; + int64_t in_idx = idx_before * gather_axis_dim * after_gather_dim + idx_at_g * after_gather_dim + idx_after; + + int64_t block_id = in_idx / block_size; + + // unpack zero_point for this block: + int64_t offset = 0; + if (zero_points) { + offset = get_val(zero_points, block_id, bits, sign); + } + + // unpack the raw quantized code for this element: + int64_t weight = get_val(data, in_idx, bits, sign); + + // apply dequantization: + output[out_idx] = static_cast(weight - offset) * scales[block_id]; +} + +template +void LaunchGatherBlockQuantizedKernel(const T1* data, + const Tind* indices, + const T2* scales, + const T1* zero_points, + T2* output, + GatherBlockQuantizedParam param) { + // Require quant_axis is last dim + int blocksPerGrid = (int)(ceil(static_cast(param.N) / GridDim::maxThreadsPerBlock)); + bool sign = std::is_same::value; + + GatherBlockQuantizedKernel<<>>(data, indices, scales, zero_points, output, + param.after_gather_dim, param.gather_axis_dim, param.ind_dim, param.bits, param.block_size, param.gather_axis, param.N, sign); +} + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam); + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam); + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh new file mode 100644 index 0000000000000..f5dea3b1f2d9d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +struct GatherBlockQuantizedParam { + cudaStream_t stream; + int64_t after_gather_dim; + int64_t gather_axis_dim; + int64_t ind_dim; + int64_t bits; + int64_t block_size; + int64_t gather_axis; + int64_t N; +}; + +template +void LaunchGatherBlockQuantizedKernel(const T1* data, + const Tind* indices, + const T2* scales, + const T1* zero_points, + T2* output, + GatherBlockQuantizedParam param); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h new file mode 100644 index 0000000000000..7718b6dd06765 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class GatherBlockQuantized final : public CudaKernel { + public: + GatherBlockQuantized(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t bits_; + int64_t block_size_; + int64_t gather_axis_; + int64_t quantize_axis_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc old mode 100755 new mode 100644 index 334be3e03b483..4b586e24c9bd3 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -10,6 +10,7 @@ #include "core/common/common.h" #include "core/framework/execution_provider.h" +#include "test/common/cuda_op_test_utils.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -102,6 +103,7 @@ void RunGatherBlockQuantized(const std::vector& data, const std::vector& output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, bool touch_on_device_data = false) { + (void)touch_on_device_data; CheckDataAndShape(data, data_shape, "data in RunGatherBlockQuantized"); CheckDataAndShape(indices, indices_shape, "indices in RunGatherBlockQuantized"); CheckDataAndShape(scales, scales_shape, "scales in RunGatherBlockQuantized"); @@ -127,12 +129,15 @@ void RunGatherBlockQuantized(const std::vector& data, test.AddOutput("output", output_shape, output); - if (touch_on_device_data) { - // test would need to see data on device - test.Run(expect_result, "", {kWebGpuExecutionProvider}, nullptr); + bool enable_cuda = HasCudaEnvironment(0); + std::vector> eps; + if (enable_cuda) { + eps.push_back(DefaultCudaExecutionProvider()); } else { - test.Run(expect_result, ""); + eps.push_back(DefaultCpuExecutionProvider()); } + + test.Run(expect_result, "", {}, nullptr, &eps); }; run_test(false); @@ -275,6 +280,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); @@ -289,6 +295,7 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); } +#endif template void Test_Fail_WithoutZeroPoints(int64_t gather_axis, @@ -317,6 +324,7 @@ void Test_Fail_WithoutZeroPoints(int64_t gather_axis, gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) { // Gather on axis other than 0 is not supported with uint8_t Test_Fail_WithoutZeroPoints(1, 2, 16); @@ -349,6 +357,7 @@ TEST(GatherBlockQuantizedOpTest, NotSupportedBits) { Test_Fail_WithZeroPoints(0, 2, 16, 6); Test_Fail_WithZeroPoints(0, 2, 16, 7); } +#endif template void Test_ShapeMismatch_WithZeroPoints() { @@ -377,11 +386,13 @@ void Test_ShapeMismatch_WithZeroPoints() { gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, ShapeMismatch) { Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); } +#endif template void Test_InvalidIndices_WithZeroPoints() { @@ -410,11 +421,13 @@ void Test_InvalidIndices_WithZeroPoints() { gather_axis, quantize_axis, block_size, bits, output, output_shape, false, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, InvalidIndices) { Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); } +#endif template void Test_GatherAxis0_WithZeroPoints(int bits = 4) { @@ -447,6 +460,7 @@ void Test_GatherAxis0_WithZeroPoints(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) { Test_GatherAxis0_WithZeroPoints(); Test_GatherAxis0_WithZeroPoints(); @@ -457,6 +471,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) { Test_GatherAxis0_WithZeroPoints(); Test_GatherAxis0_WithZeroPoints(); } +#endif template void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) { @@ -490,6 +505,7 @@ void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_4Bits) { Test_GatherAxis0_WithZeroPoints_Uint8(); Test_GatherAxis0_WithZeroPoints_Uint8(); @@ -499,6 +515,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_8Bits) { Test_GatherAxis0_WithZeroPoints_Uint8(8); Test_GatherAxis0_WithZeroPoints_Uint8(8); } +#endif template void Test_GatherAxis0_NoZeroPoints(int bits = 4) { @@ -533,6 +550,7 @@ void Test_GatherAxis0_NoZeroPoints(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) { Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); @@ -551,6 +569,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) { Test_GatherAxis0_NoZeroPoints(8); Test_GatherAxis0_NoZeroPoints(8); } +#endif template void Test_GatherAxis1_WithZeroPoints() { @@ -585,6 +604,7 @@ void Test_GatherAxis1_WithZeroPoints() { -2, -2, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis1) { Test_GatherAxis1_WithZeroPoints(); Test_GatherAxis1_WithZeroPoints(); @@ -595,6 +615,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis1) { Test_GatherAxis1_WithZeroPoints(); Test_GatherAxis1_WithZeroPoints(); } +#endif template void Test_GatherAxis2_WithZeroPoints() { @@ -629,6 +650,7 @@ void Test_GatherAxis2_WithZeroPoints() { -1, -3, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis2) { Test_GatherAxis2_WithZeroPoints(); Test_GatherAxis2_WithZeroPoints(); @@ -639,6 +661,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis2) { Test_GatherAxis2_WithZeroPoints(); Test_GatherAxis2_WithZeroPoints(); } +#endif } // namespace test } // namespace onnxruntime