Skip to content

Commit 3b376da

Browse files
authored
Enable type reduction for Gather CPU kernel. (microsoft#6579)
* Enable type reduction in Gather.
1 parent c5d2538 commit 3b376da

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

onnxruntime/core/providers/cpu/tensor/gather.cc

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,31 @@
55
#include "core/providers/cpu/tensor/gather.h"
66
#include "core/common/common.h"
77
#include "core/platform/threadpool.h"
8+
#include "core/providers/op_kernel_type_control.h"
9+
#include "core/providers/op_kernel_type_control_utils.h"
810

911
namespace onnxruntime {
1012

13+
namespace op_kernel_type_control {
14+
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
15+
kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1, int32_t, int64_t);
16+
}
17+
18+
namespace {
19+
using EnabledIndexTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
20+
kCpuExecutionProvider, kOnnxDomain, Gather, Input, 1);
21+
22+
const auto index_type_constraints =
23+
BuildKernelDefConstraintsFunctorFromTypeList<EnabledIndexTypes>{}();
24+
} // namespace
25+
1126
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
1227
Gather,
1328
1,
1429
10,
1530
KernelDefBuilder()
1631
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
17-
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
18-
DataTypeImpl::GetTensorType<int64_t>()}),
32+
.TypeConstraint("Tind", index_type_constraints),
1933
Gather);
2034

2135
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
@@ -24,17 +38,15 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
2438
12,
2539
KernelDefBuilder()
2640
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
27-
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
28-
DataTypeImpl::GetTensorType<int64_t>()}),
41+
.TypeConstraint("Tind", index_type_constraints),
2942
Gather);
3043

3144
ONNX_CPU_OPERATOR_KERNEL(
3245
Gather,
3346
13,
3447
KernelDefBuilder()
3548
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
36-
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
37-
DataTypeImpl::GetTensorType<int64_t>()}),
49+
.TypeConstraint("Tind", index_type_constraints),
3850
Gather);
3951

4052
Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
@@ -132,16 +144,18 @@ Status Gather::Compute(OpKernelContext* context) const {
132144

133145
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
134146

135-
if (p.indices_tensor->IsDataType<int32_t>()) {
147+
if (utils::HasTypeWithSameSize<EnabledIndexTypes, int32_t>() &&
148+
p.indices_tensor->IsDataType<int32_t>()) {
136149
return GatherCopyData<int32_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
137150
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
138151
}
139-
if (p.indices_tensor->IsDataType<int64_t>()) {
152+
if (utils::HasTypeWithSameSize<EnabledIndexTypes, int64_t>() &&
153+
p.indices_tensor->IsDataType<int64_t>()) {
140154
return GatherCopyData<int64_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
141155
block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
142156
}
143157

144-
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
158+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Gather Tind type not supported in this build.");
145159
}
146160

147161
} // namespace onnxruntime

onnxruntime/core/providers/op_kernel_type_control.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
* - Enabled types are the types that are supported in the actual, compiled implementation. They are obtained from the
2121
* intersection of supported and allowed types.
2222
*
23-
* The types are associated with an Op kernel argument. It is also possible to specify a global list of allowed types.
23+
* The types are associated with an Op argument. It is also possible to specify a global list of allowed types.
2424
*
2525
* Use of these utilities is optional. They are useful for cases where one registered Op kernel handles multiple types.
2626
*
@@ -239,8 +239,8 @@ struct EnabledTypes {
239239
* namespace onnxruntime {
240240
* namespace op_kernel_type_control {
241241
* // specify supported types, i.e., the full set of types that can be enabled
242-
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES(
243-
* MyProvider, DomainContainingMyOp, MyOp, OpSet, Input, 0,
242+
* ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPES_ALL_OPSETS(
243+
* MyProvider, DomainContainingMyOp, MyOp, Input, 0,
244244
* int, float, double);
245245
* } // namespace op_kernel_type_control
246246
* } // namespace onnxruntime
@@ -249,7 +249,7 @@ struct EnabledTypes {
249249
*
250250
* // get enabled types
251251
* using MyOpFirstInputEnabledTypes =
252-
* ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(MyProvider, DomainContainingMyOp, MyOp, Input, 0);
252+
* ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(MyProvider, DomainContainingMyOp, MyOp, Input, 0);
253253
*
254254
* // ...
255255
*

onnxruntime/core/providers/op_kernel_type_control_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ using SizeOfT = boost::mp11::mp_size_t<sizeof(T)>;
2424

2525
/**
2626
* Check if the set of types contains a type with the same size as T.
27-
*
28-
* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t))
27+
*
28+
* @remarks e.g. will return true if T is int32_t and the list contains any 4 byte type (i.e. sizeof(int32_t))
2929
* such as int32_t, uint32_t or float.
3030
*/
3131
template <typename TypeSet, typename T>

0 commit comments

Comments
 (0)