diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 26b701fea6fbb..4f7dd8c11e655 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -937,6 +937,7 @@ Do not modify directly.*
 |FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br> **TS** = tensor(float)|
 |GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br> **U** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h
index 375f0a4dc8dd2..e59a803d97629 100644
--- a/include/onnxruntime/core/framework/op_kernel.h
+++ b/include/onnxruntime/core/framework/op_kernel.h
@@ -305,6 +305,24 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
      static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
}
+#define ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name) \
+ provider##_##name##_##domain##_ver##ver##_##type1##_##type2##_##type3
+
+#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \
+ class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \
+ template <> \
+ KernelCreateInfo \
+  BuildKernelCreateInfo<ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name)>() { \
+ return KernelCreateInfo( \
+ builder.SetName(#name) \
+ .SetDomain(domain) \
+ .SinceVersion(ver) \
+ .Provider(provider) \
+ .Build(), \
+      static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { \
+ out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
+ }
+
#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
provider##_##name##_##domain##_ver##startver##_##endver##_##type
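For reference, a minimal sketch (illustrative, not part of the patch) of the identifier the new class-name macro produces; the arguments are taken from one of the CUDA GatherBlockQuantized registrations added later in this diff:

// ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1,
//                                             UInt4x2, MLFloat16, int64_t, GatherBlockQuantized)
// token-pastes provider, op name, domain, opset version and the three types into one class name:
class kCudaExecutionProvider_GatherBlockQuantized_kMSDomain_ver1_UInt4x2_MLFloat16_int64_t;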
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index a50ee907c302b..36d6fc378d45e 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -15,6 +15,10 @@ using namespace onnxruntime::common;
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name)
#define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name)
+#define CUDA_MS_OP_TWO_TYPED_CLASS_NAME(ver, type1, type2, name) \
+ ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, name)
+#define CUDA_MS_OP_THREE_TYPED_CLASS_NAME(ver, type1, type2, type3, name) \
+ ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, type3, name)
#define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name)
@@ -186,6 +190,25 @@ class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized);
+class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized);
+
#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
#endif
@@ -408,6 +431,24 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GemmFloat8)>,
      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention)>,
      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized)>,
#ifdef ENABLE_ATEN
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc
new file mode 100644
index 0000000000000..bad44b260b7b2
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc
@@ -0,0 +1,152 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/quantization/gather_block_quantized.h"
+#include "contrib_ops/cuda/quantization/gather_block_quantized.cuh"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+using namespace onnxruntime::cuda;
+
+#define REGISTER_GATHERBLOCKQUANTIZED(T1, T2, Tind) \
+ ONNX_OPERATOR_THREE_TYPED_KERNEL_EX( \
+ GatherBlockQuantized, \
+ kMSDomain, 1, \
+ T1, T2, Tind, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \
+ GatherBlockQuantized);
+
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int64_t);
+REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int64_t);
+REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int64_t);
+
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int64_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int64_t);
+REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int64_t);
+
+REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int64_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int64_t);
+REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int64_t);
+
+template <typename T1, typename T2, typename Tind>
+GatherBlockQuantized<T1, T2, Tind>::GatherBlockQuantized(const OpKernelInfo& info) : CudaKernel(info) {
+  ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK());
+
+  block_size_ = info.GetAttrOrDefault<int64_t>("block_size", 0);
+  gather_axis_ = info.GetAttrOrDefault<int64_t>("gather_axis", 0);
+  quantize_axis_ = info.GetAttrOrDefault<int64_t>("quantize_axis", 0);
+
+ // If block size is set, it has to be no smaller than 16 and must be power of 2
+ // block_size_ & (block_size_ - 1) == 0 checks if block_size_ only has 1 bit set
+ ORT_ENFORCE(block_size_ == 0 || (block_size_ >= 16 && ((block_size_ & (block_size_ - 1)) == 0)));
+}
+
+template <typename T1, typename T2, typename Tind>
+Status GatherBlockQuantized<T1, T2, Tind>::ComputeInternal(OpKernelContext* ctx) const {
+  const Tensor* data = ctx->Input<Tensor>(0);
+  const Tensor* indices = ctx->Input<Tensor>(1);
+  const Tensor* scales = ctx->Input<Tensor>(2);
+  const Tensor* zero_points = ctx->Input<Tensor>(3);
+
+ auto data_shape = data->Shape().GetDims();
+ auto data_rank = data->Shape().NumDimensions();
+
+ auto indices_shape = indices->Shape().GetDims();
+ auto indices_rank = indices->Shape().NumDimensions();
+
+ ORT_ENFORCE(quantize_axis_ == data_rank - 1);
+
+ TensorShapeVector output_shape;
+ output_shape.reserve(data_rank - 1 + indices_rank);
+
+ // Dimension after gather axis
+ int64_t after_gather_dim = 1;
+
+ // Dimension of indices
+ int64_t ind_dim = 1;
+
+ // 1) dims before gather_axis
+ for (int64_t i = 0; i < gather_axis_; ++i) {
+ output_shape.push_back(data_shape[i]);
+ }
+
+ // 2) all of indices.shape
+ for (auto dim : indices_shape) {
+ output_shape.push_back(dim);
+ ind_dim *= dim;
+ }
+
+ // 3) dims after gather_axis
+ for (int64_t i = gather_axis_ + 1; i < data_rank; ++i) {
+ output_shape.push_back(data_shape[i]);
+ after_gather_dim *= data_shape[i];
+ }
+
+  // Special int4-in-uint8 packing tweak: expand the last dim by the number of components per byte
+  if constexpr (std::is_same_v<T1, uint8_t>) {
+    uint32_t components = 8 / static_cast<uint32_t>(bits_);
+ if (components > 1) {
+ output_shape.back() *= components;
+ }
+ }
+
+ Tensor* output = ctx->Output(0, TensorShape(output_shape));
+
+ int64_t N = 1;
+ for (auto dim : output_shape) {
+ N *= dim;
+ }
+
+  const auto* data_ptr = data->Data<T1>();
+  const auto* indices_ptr = indices->Data<Tind>();
+  const T1* zero_points_ptr = nullptr;
+  if (zero_points != nullptr) {
+    zero_points_ptr = zero_points->Data<T1>();
+ }
+
+ GatherBlockQuantizedParam param;
+ param.stream = Stream(ctx);
+ param.after_gather_dim = after_gather_dim;
+ param.gather_axis_dim = data_shape[gather_axis_];
+ param.ind_dim = ind_dim;
+ param.bits = bits_;
+ param.block_size = block_size_;
+ param.gather_axis = gather_axis_;
+ param.N = N;
+
+ const auto dequantized_type = scales->GetElementType();
+ if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    const auto* scales_ptr = static_cast<const float*>(scales->DataRaw());
+    auto* output_ptr = static_cast<float*>(output->MutableDataRaw());
+ LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param);
+ } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+    const auto* scales_ptr = static_cast<const half*>(scales->DataRaw());
+    auto* output_ptr = static_cast<half*>(output->MutableDataRaw());
+ LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param);
+ } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
+    const auto* scales_ptr = static_cast<const BFloat16*>(scales->DataRaw());
+    auto* output_ptr = static_cast<BFloat16*>(output->MutableDataRaw());
+ LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param);
+ }
+
+ return Status::OK();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
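As a companion to ComputeInternal above, a small host-side sketch of the output-shape rule it implements; the free function and its name are illustrative, not part of the patch. The rule: dims before gather_axis, then all of indices.shape, then dims after gather_axis, with the last dim widened by 8/bits when the codes are packed into uint8 storage.

#include <cstdint>
#include <vector>

// Reference-only mirror of the shape logic in GatherBlockQuantized<T1, T2, Tind>::ComputeInternal.
std::vector<int64_t> GatherBlockQuantizedOutputShape(const std::vector<int64_t>& data_shape,
                                                     const std::vector<int64_t>& indices_shape,
                                                     int64_t gather_axis, int64_t bits,
                                                     bool data_is_uint8_packed) {
  std::vector<int64_t> out;
  out.insert(out.end(), data_shape.begin(), data_shape.begin() + gather_axis);    // 1) dims before gather_axis
  out.insert(out.end(), indices_shape.begin(), indices_shape.end());              // 2) all of indices.shape
  out.insert(out.end(), data_shape.begin() + gather_axis + 1, data_shape.end());  // 3) dims after gather_axis
  if (data_is_uint8_packed) {
    out.back() *= 8 / bits;  // e.g. two 4-bit codes per byte double the logical last dim
  }
  return out;
}

// Example: data {4, 8} stored as 4-bit codes in uint8, indices {2}, gather_axis 0 -> output {2, 16}.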
diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu
new file mode 100644
index 0000000000000..39286c63e9a08
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "gather_block_quantized.cuh"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T1>
+__device__ inline int64_t get_val(const T1* data, int64_t idx, int64_t bits, bool sign) {
+ const uint32_t mask = (1U << bits) - 1;
+ const int64_t elems_per_byte = 8 / bits;
+ const int64_t byte_idx = idx / elems_per_byte;
+ const int64_t bit_offset = (idx % elems_per_byte) * bits;
+  const uint8_t byte = reinterpret_cast<const uint8_t*>(data)[byte_idx];
+ int64_t val = (byte >> bit_offset) & mask;
+
+ // Sign-extend based on bit width
+ if (sign) {
+ if (val & (1 << (bits - 1))) {
+ val |= -1LL << bits;
+ }
+ }
+
+ return val;
+}
+
+template <typename T1, typename T2, typename Tind>
+__global__ void GatherBlockQuantizedKernel(
+ const T1* data, // packed 4-bit codes, one code per element
+ const Tind* indices,
+ const T2* scales, // one float scale per block
+ const T1* zero_points, // packed 4-bit zero-points, one per block
+ T2* output,
+ int64_t after_gather_dim,
+ int64_t gather_axis_dim,
+ int64_t ind_dim,
+ int64_t bits,
+ int64_t block_size,
+ int64_t gather_axis,
+ int64_t N,
+ bool sign) {
+ int64_t out_idx = blockDim.x * blockIdx.x + threadIdx.x;
+ if (out_idx >= N) return;
+
+ // compute which input element this thread corresponds to:
+ int64_t idx_before = out_idx / (after_gather_dim * ind_dim);
+ int64_t idx_after = out_idx % after_gather_dim;
+ int64_t idx = (out_idx % (after_gather_dim * ind_dim)) / after_gather_dim;
+ int64_t idx_at_g = indices[idx];
+ int64_t in_idx = idx_before * gather_axis_dim * after_gather_dim + idx_at_g * after_gather_dim + idx_after;
+
+ int64_t block_id = in_idx / block_size;
+
+ // unpack zero_point for this block:
+ int64_t offset = 0;
+ if (zero_points) {
+ offset = get_val(zero_points, block_id, bits, sign);
+ }
+
+ // unpack the raw quantized code for this element:
+ int64_t weight = get_val(data, in_idx, bits, sign);
+
+ // apply dequantization:
+  output[out_idx] = static_cast<T2>(weight - offset) * scales[block_id];
+}
+
+template <typename T1, typename T2, typename Tind>
+void LaunchGatherBlockQuantizedKernel(const T1* data,
+ const Tind* indices,
+ const T2* scales,
+ const T1* zero_points,
+ T2* output,
+ GatherBlockQuantizedParam param) {
+ // Require quant_axis is last dim
+  int blocksPerGrid = (int)(ceil(static_cast<float>(param.N) / GridDim::maxThreadsPerBlock));
+  bool sign = std::is_same<T1, Int4x2>::value;
+
+  GatherBlockQuantizedKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, param.stream>>>(data, indices, scales, zero_points, output,
+                                                                                              param.after_gather_dim, param.gather_axis_dim, param.ind_dim, param.bits, param.block_size, param.gather_axis, param.N, sign);
+}
+
+template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam);
+
+template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam);
+
+template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam);
+template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
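For readers checking the index arithmetic, a scalar CPU reference (illustrative, not part of the patch) of what a single thread of GatherBlockQuantizedKernel computes for one output element; it assumes codes and zero-points have already been unpacked to plain integers, so the bit manipulation done by get_val on packed 4-bit data is omitted.

#include <cstdint>

// Reference-only mirror of the per-element math in GatherBlockQuantizedKernel.
float DequantizeOneOutput(const int64_t* codes,        // unpacked quantized values, one per data element
                          const int64_t* zero_points,  // unpacked per-block zero-points, may be nullptr
                          const float* scales,         // one scale per quantization block
                          const int64_t* indices,
                          int64_t out_idx, int64_t after_gather_dim, int64_t gather_axis_dim,
                          int64_t ind_dim, int64_t block_size) {
  const int64_t idx_before = out_idx / (after_gather_dim * ind_dim);                // position before the gather axis
  const int64_t idx_after = out_idx % after_gather_dim;                             // position after the gather axis
  const int64_t idx = (out_idx % (after_gather_dim * ind_dim)) / after_gather_dim;  // which indices entry
  const int64_t in_idx = idx_before * gather_axis_dim * after_gather_dim +
                         indices[idx] * after_gather_dim + idx_after;               // element in the original data
  const int64_t block_id = in_idx / block_size;
  const int64_t offset = zero_points != nullptr ? zero_points[block_id] : 0;
  return static_cast<float>(codes[in_idx] - offset) * scales[block_id];
}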
diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh
new file mode 100644
index 0000000000000..f5dea3b1f2d9d
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+struct GatherBlockQuantizedParam {
+ cudaStream_t stream;
+ int64_t after_gather_dim;
+ int64_t gather_axis_dim;
+ int64_t ind_dim;
+ int64_t bits;
+ int64_t block_size;
+ int64_t gather_axis;
+ int64_t N;
+};
+
+template <typename T1, typename T2, typename Tind>
+void LaunchGatherBlockQuantizedKernel(const T1* data,
+ const Tind* indices,
+ const T2* scales,
+ const T1* zero_points,
+ T2* output,
+ GatherBlockQuantizedParam param);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h
new file mode 100644
index 0000000000000..7718b6dd06765
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/cuda_kernel.h"
+
+#include <cstdint>
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T1, typename T2, typename Tind>
+class GatherBlockQuantized final : public CudaKernel {
+ public:
+ GatherBlockQuantized(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int64_t bits_;
+ int64_t block_size_;
+ int64_t gather_axis_;
+ int64_t quantize_axis_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
old mode 100755
new mode 100644
index 334be3e03b483..4b586e24c9bd3
--- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
+++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
@@ -10,6 +10,7 @@
#include "core/common/common.h"
#include "core/framework/execution_provider.h"
+#include "test/common/cuda_op_test_utils.h"
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
@@ -102,6 +103,7 @@ void RunGatherBlockQuantized(const std::vector<T1>& data,
const std::vector<int64_t>& output_shape,
OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
bool touch_on_device_data = false) {
+ (void)touch_on_device_data;
CheckDataAndShape(data, data_shape, "data in RunGatherBlockQuantized");
CheckDataAndShape(indices, indices_shape, "indices in RunGatherBlockQuantized");
CheckDataAndShape(scales, scales_shape, "scales in RunGatherBlockQuantized");
@@ -127,12 +129,15 @@ void RunGatherBlockQuantized(const std::vector& data,
test.AddOutput("output", output_shape, output);
- if (touch_on_device_data) {
- // test would need to see data on device
- test.Run(expect_result, "", {kWebGpuExecutionProvider}, nullptr);
+ bool enable_cuda = HasCudaEnvironment(0);
+    std::vector<std::unique_ptr<IExecutionProvider>> eps;
+ if (enable_cuda) {
+ eps.push_back(DefaultCudaExecutionProvider());
} else {
- test.Run(expect_result, "");
+ eps.push_back(DefaultCpuExecutionProvider());
}
+
+ test.Run(expect_result, "", {}, nullptr, &eps);
};
run_test(false);
@@ -275,6 +280,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis,
gather_axis, quantize_axis, block_size, bits, output, output_shape, false);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) {
Test_Fail_WithZeroPoints(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
@@ -289,6 +295,7 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) {
Test_Fail_WithZeroPoints(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_Fail_WithoutZeroPoints(int64_t gather_axis,
@@ -317,6 +324,7 @@ void Test_Fail_WithoutZeroPoints(int64_t gather_axis,
gather_axis, quantize_axis, block_size, bits, output, output_shape, false);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) {
// Gather on axis other than 0 is not supported with uint8_t
Test_Fail_WithoutZeroPoints(1, 2, 16);
@@ -349,6 +357,7 @@ TEST(GatherBlockQuantizedOpTest, NotSupportedBits) {
Test_Fail_WithZeroPoints(0, 2, 16, 6);
Test_Fail_WithZeroPoints(0, 2, 16, 7);
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_ShapeMismatch_WithZeroPoints() {
@@ -377,11 +386,13 @@ void Test_ShapeMismatch_WithZeroPoints() {
gather_axis, quantize_axis, block_size, bits, output, output_shape, false);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, ShapeMismatch) {
Test_ShapeMismatch_WithZeroPoints();
Test_ShapeMismatch_WithZeroPoints();
Test_ShapeMismatch_WithZeroPoints();
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_InvalidIndices_WithZeroPoints() {
@@ -410,11 +421,13 @@ void Test_InvalidIndices_WithZeroPoints() {
gather_axis, quantize_axis, block_size, bits, output, output_shape, false, true);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, InvalidIndices) {
Test_InvalidIndices_WithZeroPoints();
Test_InvalidIndices_WithZeroPoints();
Test_InvalidIndices_WithZeroPoints();
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_GatherAxis0_WithZeroPoints(int bits = 4) {
@@ -447,6 +460,7 @@ void Test_GatherAxis0_WithZeroPoints(int bits = 4) {
-3, -1, block_size, bits, output, output_shape, true);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) {
Test_GatherAxis0_WithZeroPoints();
Test_GatherAxis0_WithZeroPoints();
@@ -457,6 +471,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) {
Test_GatherAxis0_WithZeroPoints();
Test_GatherAxis0_WithZeroPoints();
}
+#endif
template
void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) {
@@ -490,6 +505,7 @@ void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) {
-3, -1, block_size, bits, output, output_shape, true);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_4Bits) {
Test_GatherAxis0_WithZeroPoints_Uint8();
Test_GatherAxis0_WithZeroPoints_Uint8();
@@ -499,6 +515,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_8Bits) {
Test_GatherAxis0_WithZeroPoints_Uint8(8);
Test_GatherAxis0_WithZeroPoints_Uint8(8);
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_GatherAxis0_NoZeroPoints(int bits = 4) {
@@ -533,6 +550,7 @@ void Test_GatherAxis0_NoZeroPoints(int bits = 4) {
-3, -1, block_size, bits, output, output_shape, true);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) {
Test_GatherAxis0_NoZeroPoints();
Test_GatherAxis0_NoZeroPoints();
@@ -551,6 +569,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) {
Test_GatherAxis0_NoZeroPoints(8);
Test_GatherAxis0_NoZeroPoints(8);
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_GatherAxis1_WithZeroPoints() {
@@ -585,6 +604,7 @@ void Test_GatherAxis1_WithZeroPoints() {
-2, -2, block_size, bits, output, output_shape, true);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, GatherAxis1) {
Test_GatherAxis1_WithZeroPoints();
Test_GatherAxis1_WithZeroPoints();
@@ -595,6 +615,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis1) {
Test_GatherAxis1_WithZeroPoints();
Test_GatherAxis1_WithZeroPoints();
}
+#endif
template <typename T1, typename T2, typename Tind>
void Test_GatherAxis2_WithZeroPoints() {
@@ -629,6 +650,7 @@ void Test_GatherAxis2_WithZeroPoints() {
-1, -3, block_size, bits, output, output_shape, true);
}
+#ifndef USE_CUDA
TEST(GatherBlockQuantizedOpTest, GatherAxis2) {
Test_GatherAxis2_WithZeroPoints();
Test_GatherAxis2_WithZeroPoints();
@@ -639,6 +661,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis2) {
Test_GatherAxis2_WithZeroPoints();
Test_GatherAxis2_WithZeroPoints();
}
+#endif
} // namespace test
} // namespace onnxruntime