Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix CUDA_MS_OP_THREE_TYPED_CLASS_NAME
  • Loading branch information
Xiaoyan Hu committed Aug 1, 2025
commit 941dd734cfdeb2239be54cb690a114661603a02a
26 changes: 20 additions & 6 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ using namespace onnxruntime::common;
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name)
#define CUDA_MS_OP_TWO_TYPED_CLASS_NAME(ver, type1, type2, name) \
ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, name)
#define CUDA_MS_OP_THREE_TYPED_CLASS_NAME(ver, type1, type2, type3, name) \
ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, type3, name)

#define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name)
Expand Down Expand Up @@ -188,12 +190,24 @@ class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention);

class CUDA_MS_OP_TWO_TYPED_CLASS_NAME(1, uint8_t, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_TWO_TYPED_CLASS_NAME(1, uint8_t, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_TWO_TYPED_CLASS_NAME(1, UInt4x2, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_TWO_TYPED_CLASS_NAME(1, UInt4x2, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_TWO_TYPED_CLASS_NAME(1, Int4x2, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_TWO_TYPED_CLASS_NAME(1, Int4x2, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized);
class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized);

#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
Expand Down