1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -937,6 +937,7 @@ Do not modify directly.*
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **Tind** = tensor(int32), tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
18 changes: 18 additions & 0 deletions include/onnxruntime/core/framework/op_kernel.h
@@ -305,6 +305,24 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
}

#define ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name) \
provider##_##name##_##domain##_ver##ver##_##type1##_##type2##_##type3

#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \
class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \
template <> \
KernelCreateInfo \
BuildKernelCreateInfo<ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name)>() { \
return KernelCreateInfo( \
builder.SetName(#name) \
.SetDomain(domain) \
.SinceVersion(ver) \
.Provider(provider) \
.Build(), \
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { \
out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
}
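// For illustration (a sketch, not part of this header): instantiating the
// class-name macro above with arguments used later in this PR, e.g.
//   ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1,
//                                               UInt4x2, float, int32_t, GatherBlockQuantized)
// token-pastes to the single identifier
//   kCudaExecutionProvider_GatherBlockQuantized_kMSDomain_ver1_UInt4x2_float_int32_t
// which ONNX_OPERATOR_THREE_TYPED_KERNEL_EX then forward-declares and for which
// it specializes BuildKernelCreateInfo<>.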

#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
provider##_##name##_##domain##_ver##startver##_##endver##_##type

41 changes: 41 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -15,6 +15,10 @@ using namespace onnxruntime::common;
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name)
#define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name)
#define CUDA_MS_OP_TWO_TYPED_CLASS_NAME(ver, type1, type2, name) \
ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, name)
#define CUDA_MS_OP_THREE_TYPED_CLASS_NAME(ver, type1, type2, type3, name) \
ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, type3, name)

#define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name)
@@ -186,6 +190,25 @@ class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention);

class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized);

#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
#endif
@@ -408,6 +431,24 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GemmFloat8)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized)>,

#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
152 changes: 152 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc
@@ -0,0 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/quantization/gather_block_quantized.h"
#include "contrib_ops/cuda/quantization/gather_block_quantized.cuh"

namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

[cpplint warning, line 11 of gather_block_quantized.cc, reported by reviewdog via GitHub Actions / Optional Lint C++] Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
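A minimal way to satisfy this warning, assuming only a few names from onnxruntime::cuda are actually used in this file, would be per-name using-declarations, for example:

using onnxruntime::cuda::CudaKernel;  // instead of pulling in the whole namespace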

#define REGISTER_GATHERBLOCKQUANTIZED(T1, T2, Tind) \
ONNX_OPERATOR_THREE_TYPED_KERNEL_EX( \
GatherBlockQuantized, \
kMSDomain, 1, \
T1, T2, Tind, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()) \
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<Tind>()), \
GatherBlockQuantized<T1, T2, Tind>);

REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int64_t);

REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int64_t);

REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int64_t);
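// The registrations above cover all 3 x 3 x 2 = 18 combinations:
// T1 in {uint8_t, UInt4x2, Int4x2} x T2 in {float, MLFloat16, BFloat16} x Tind in {int32_t, int64_t},
// matching the class declarations and BuildKernelCreateInfo entries in cuda_contrib_kernels.cc.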

template <typename T1, typename T2, typename Tind>
GatherBlockQuantized<T1, T2, Tind>::GatherBlockQuantized(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(info.GetAttr("bits", &bits_).IsOK());

block_size_ = info.GetAttrOrDefault<int64_t>("block_size", 0);
gather_axis_ = info.GetAttrOrDefault<int64_t>("gather_axis", 0);
quantize_axis_ = info.GetAttrOrDefault<int64_t>("quantize_axis", 0);

// If block_size_ is set, it must be at least 16 and a power of 2.
// (block_size_ & (block_size_ - 1)) == 0 checks that exactly one bit is set,
// e.g. 16 (0b10000) passes while 24 (0b11000) fails.
ORT_ENFORCE(block_size_ == 0 || (block_size_ >= 16 && ((block_size_ & (block_size_ - 1)) == 0)));
}

template <typename T1, typename T2, typename Tind>
Status GatherBlockQuantized<T1, T2, Tind>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* data = ctx->Input<Tensor>(0);
const Tensor* indices = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);

auto data_shape = data->Shape().GetDims();
auto data_rank = data->Shape().NumDimensions();

auto indices_shape = indices->Shape().GetDims();
auto indices_rank = indices->Shape().NumDimensions();

ORT_ENFORCE(quantize_axis_ == data_rank - 1);

TensorShapeVector output_shape;
output_shape.reserve(data_rank - 1 + indices_rank);

// Dimension after gather axis
int64_t after_gather_dim = 1;

// Dimension of indices
int64_t ind_dim = 1;

// 1) dims before gather_axis
for (int64_t i = 0; i < gather_axis_; ++i) {
output_shape.push_back(data_shape[i]);
}

// 2) all of indices.shape
for (auto dim : indices_shape) {
output_shape.push_back(dim);
ind_dim *= dim;
}

// 3) dims after gather_axis
for (int64_t i = gather_axis_ + 1; i < data_rank; ++i) {
output_shape.push_back(data_shape[i]);
after_gather_dim *= data_shape[i];
}

// Special int4-in-uint8 packing tweak: uint8 data may pack several sub-byte
// elements per byte (components = 8 / bits, e.g. 2 for 4-bit), so expand the
// last output dim accordingly.
if constexpr (std::is_same_v<T1, uint8_t>) {
uint32_t components = 8 / static_cast<int>(bits_);
if (components > 1) {
output_shape.back() *= components;
}
}

Tensor* output = ctx->Output(0, TensorShape(output_shape));

int64_t N = 1;
for (auto dim : output_shape) {
N *= dim;
}

const auto* data_ptr = data->Data<T1>();
const auto* indices_ptr = indices->Data<Tind>();
const T1* zero_points_ptr = nullptr;
if (zero_points != nullptr) {
zero_points_ptr = zero_points->Data<T1>();
}

GatherBlockQuantizedParam param;
param.stream = Stream(ctx);
param.after_gather_dim = after_gather_dim;
param.gather_axis_dim = data_shape[gather_axis_];
param.ind_dim = ind_dim;
param.bits = bits_;
param.block_size = block_size_;
param.gather_axis = gather_axis_;
param.N = N;

const auto dequantized_type = scales->GetElementType();
if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
const auto* scales_ptr = static_cast<const float*>(scales->DataRaw());
auto* output_ptr = static_cast<float*>(output->MutableDataRaw());
LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param);
} else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
const auto* scales_ptr = static_cast<const half*>(scales->DataRaw());
auto* output_ptr = static_cast<half*>(output->MutableDataRaw());
LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param);
} else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
const auto* scales_ptr = static_cast<const BFloat16*>(scales->DataRaw());
auto* output_ptr = static_cast<BFloat16*>(output->MutableDataRaw());
LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param);
}

return Status::OK();
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
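For context, the launched kernel is declared in gather_block_quantized.cuh, which is outside this hunk. Block dequantization conventionally recovers each element as (q - zero_point) * scale, with one scale (and optional zero point) per block_size elements along the quantize axis. Below is a minimal sketch of that per-element math under those assumptions; the names and indexing are hypothetical, not the PR's implementation:

// Hypothetical sketch only; assumes 8-bit elements, quantize axis == last dim,
// and a zero default zero point. All names below are illustrative.
float DequantizeOne(const uint8_t* data, const float* scales,
                    const uint8_t* zero_points,  // may be nullptr
                    int64_t row, int64_t col, int64_t row_width, int64_t block_size) {
  // One scale/zero point per block of block_size consecutive elements in a row.
  const int64_t block = row * (row_width / block_size) + col / block_size;
  const float zp = zero_points ? static_cast<float>(zero_points[block]) : 0.0f;
  return (static_cast<float>(data[row * row_width + col]) - zp) * scales[block];
}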