55#include " core/providers/cpu/tensor/gather.h"
66#include " core/common/common.h"
77#include " core/platform/threadpool.h"
8+ #include " core/providers/op_kernel_type_control.h"
9+ #include " core/providers/op_kernel_type_control_utils.h"
810
911namespace onnxruntime {
1012
13+ namespace op_kernel_type_control {
14+ ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS (
15+ kCpuExecutionProvider , kOnnxDomain , Gather, Input, 1 , int32_t , int64_t );
16+ }
17+
18+ namespace {
19+ using EnabledIndexTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
20+ kCpuExecutionProvider , kOnnxDomain , Gather, Input, 1 );
21+
22+ const auto index_type_constraints =
23+ BuildKernelDefConstraintsFunctorFromTypeList<EnabledIndexTypes>{}();
24+ } // namespace
25+
1126ONNX_CPU_OPERATOR_VERSIONED_KERNEL (
1227 Gather,
1328 1 ,
1429 10 ,
1530 KernelDefBuilder ()
1631 .TypeConstraint(" T" , DataTypeImpl::AllTensorTypes())
17- .TypeConstraint(" Tind" , std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t >(),
18- DataTypeImpl::GetTensorType<int64_t >()}),
32+ .TypeConstraint(" Tind" , index_type_constraints),
1933 Gather);
2034
2135ONNX_CPU_OPERATOR_VERSIONED_KERNEL (
@@ -24,17 +38,15 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
2438 12 ,
2539 KernelDefBuilder ()
2640 .TypeConstraint(" T" , DataTypeImpl::AllTensorTypes())
27- .TypeConstraint(" Tind" , std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t >(),
28- DataTypeImpl::GetTensorType<int64_t >()}),
41+ .TypeConstraint(" Tind" , index_type_constraints),
2942 Gather);
3043
3144ONNX_CPU_OPERATOR_KERNEL (
3245 Gather,
3346 13 ,
3447 KernelDefBuilder ()
3548 .TypeConstraint(" T" , DataTypeImpl::AllTensorTypes())
36- .TypeConstraint(" Tind" , std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t >(),
37- DataTypeImpl::GetTensorType<int64_t >()}),
49+ .TypeConstraint(" Tind" , index_type_constraints),
3850 Gather);
3951
4052Status GatherBase::PrepareForCompute (OpKernelContext* context, Prepare& p) const {
@@ -132,16 +144,18 @@ Status Gather::Compute(OpKernelContext* context) const {
132144
133145 concurrency::ThreadPool* tp = context->GetOperatorThreadPool ();
134146
135- if (p.indices_tensor ->IsDataType <int32_t >()) {
147+ if (utils::HasTypeWithSameSize<EnabledIndexTypes, int32_t >() &&
148+ p.indices_tensor ->IsDataType <int32_t >()) {
136149 return GatherCopyData<int32_t >(p.indices_tensor , src_base, dst_base, is_string_type, element_bytes,
137150 block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis , tp);
138151 }
139- if (p.indices_tensor ->IsDataType <int64_t >()) {
152+ if (utils::HasTypeWithSameSize<EnabledIndexTypes, int64_t >() &&
153+ p.indices_tensor ->IsDataType <int64_t >()) {
140154 return GatherCopyData<int64_t >(p.indices_tensor , src_base, dst_base, is_string_type, element_bytes,
141155 block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis , tp);
142156 }
143157
144- return ORT_MAKE_STATUS (ONNXRUNTIME, NOT_IMPLEMENTED, " Type for Tind not supported yet in Gather ." );
158+ return ORT_MAKE_STATUS (ONNXRUNTIME, NOT_IMPLEMENTED, " Gather Tind type not supported in this build ." );
145159}
146160
147161} // namespace onnxruntime
0 commit comments