Skip to content

Commit 513b1e4

Browse files
Corey Wharton authored and rmlarsen committed
Allow tensor as iou_threshold parameter to tf.image.non_max_suppression. (tensorflow#9887)
* Implement NonMaxSuppressionV2 op that allows tensor as iou_threshold parameter.
* Move local functions into anonymous namespace.
1 parent 7161c82 commit 513b1e4

File tree

5 files changed

+299
-57
lines changed

5 files changed

+299
-57
lines changed

tensorflow/core/kernels/non_max_suppression_op.cc

Lines changed: 93 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ limitations under the License.
3333
#include "tensorflow/core/platform/logging.h"
3434

3535
namespace tensorflow {
36+
namespace {
3637

3738
typedef Eigen::ThreadPoolDevice CPUDevice;
3839

@@ -89,6 +90,63 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes,
8990
return intersection_area / (area_i + area_j - intersection_area);
9091
}
9192

93+
// Shared implementation of greedy non-max suppression used by the
// NonMaxSuppression kernels in this file.
//
// Boxes are visited in the order produced by DecreasingArgSort over the
// scores. Each visited box that has not been suppressed is selected, and every
// later box whose IOU with it is strictly greater than `iou_threshold` is
// suppressed. Selection stops once min(max_output_size, num_boxes) boxes have
// been chosen or no candidates remain.
//
// Args:
//   context: kernel context; receives output 0 (a 1-D int tensor holding the
//     selected indices into `boxes`) or an error status.
//   boxes: box tensor, read as a rank-2 float tensor.
//   scores: per-box scores, read as a flat float tensor.
//   max_output_size: int scalar (callers verify that it is 0-D) giving the
//     maximum number of indices to emit. Must be non-negative.
//   iou_threshold: overlap threshold; must lie in [0, 1].
//
// On any validation failure the error is recorded on `context` and the
// function returns without allocating an output.
void DoNonMaxSuppressionOp(OpKernelContext* context,
                           const Tensor& boxes,
                           const Tensor& scores,
                           const Tensor& max_output_size,
                           const float iou_threshold) {
  OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
              errors::InvalidArgument("iou_threshold must be in [0, 1]"));

  int num_boxes = 0;
  ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
  if (!context->status().ok()) {
    return;
  }

  // A negative limit used to be compared against the unsigned
  // selected.size(), which converted it to a huge value and silently disabled
  // the cap. Reject it explicitly instead.
  const int max_output = max_output_size.scalar<int>()();
  OP_REQUIRES(context, max_output >= 0,
              errors::InvalidArgument(
                  "max_output_size must be non-negative, got ", max_output));

  const int output_size = std::min(max_output, num_boxes);
  TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();

  std::vector<float> scores_data(num_boxes);
  std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
  std::vector<int> sorted_indices;
  DecreasingArgSort(scores_data, &sorted_indices);

  // active[k] refers to the k-th box in sorted order, not to the original box
  // index. num_active counts the still-active entries at positions >= i.
  std::vector<bool> active(num_boxes, true);
  std::vector<int> selected;
  selected.reserve(output_size);
  int num_active = num_boxes;
  for (int i = 0; i < num_boxes; ++i) {
    if (num_active == 0 || static_cast<int>(selected.size()) >= output_size) {
      break;
    }
    if (!active[i]) {
      continue;
    }
    selected.push_back(sorted_indices[i]);
    // The selected box is no longer a candidate itself.
    --num_active;

    for (int j = i + 1; j < num_boxes; ++j) {
      if (!active[j]) {
        continue;
      }
      const float iou =
          ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
      if (iou > iou_threshold) {
        active[j] = false;
        --num_active;
      }
    }
  }

  // Allocate output tensor
  Tensor* output = nullptr;
  TensorShape output_shape({static_cast<int>(selected.size())});
  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  TTypes<int, 1>::Tensor selected_indices_data = output->tensor<int, 1>();
  std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
}
147+
148+
} // namespace
149+
92150
template <typename Device>
93151
class NonMaxSuppressionOp : public OpKernel {
94152
public:
@@ -98,9 +156,6 @@ class NonMaxSuppressionOp : public OpKernel {
98156
}
99157

100158
void Compute(OpKernelContext* context) override {
101-
OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
102-
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
103-
104159
// boxes: [num_boxes, 4]
105160
const Tensor& boxes = context->input(0);
106161
// scores: [num_boxes]
@@ -112,59 +167,48 @@ class NonMaxSuppressionOp : public OpKernel {
112167
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
113168
max_output_size.shape().DebugString()));
114169

115-
int num_boxes = 0;
116-
ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
117-
if (!context->status().ok()) {
118-
return;
119-
}
120-
121-
const int output_size =
122-
std::min(max_output_size.scalar<int>()(), num_boxes);
123-
typename TTypes<float, 2>::ConstTensor boxes_data =
124-
boxes.tensor<float, 2>();
125-
126-
std::vector<float> scores_data(num_boxes);
127-
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
128-
std::vector<int> sorted_indices;
129-
DecreasingArgSort(scores_data, &sorted_indices);
130-
131-
std::vector<bool> active(num_boxes, true);
132-
std::vector<int> selected;
133-
int num_active = active.size();
134-
for (int i = 0; i < num_boxes; ++i) {
135-
if (num_active == 0 || selected.size() >= output_size) break;
136-
if (active[i]) {
137-
selected.push_back(sorted_indices[i]);
138-
} else {
139-
continue;
140-
}
141-
for (int j = i + 1; j < num_boxes; ++j) {
142-
if (active[j]) {
143-
float iou =
144-
ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
145-
if (iou > iou_threshold_) {
146-
active[j] = false;
147-
num_active--;
148-
}
149-
}
150-
}
151-
}
152-
153-
// Allocate output tensor
154-
Tensor* output = nullptr;
155-
TensorShape output_shape({static_cast<int>(selected.size())});
156-
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
157-
typename TTypes<int, 1>::Tensor selected_indices_data =
158-
output->tensor<int, 1>();
159-
std::copy_n(selected.begin(), selected.size(),
160-
selected_indices_data.data());
170+
DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_);
161171
}
162172

163173
private:
164174
float iou_threshold_;
165175
};
166176

177+
template <typename Device>
178+
class NonMaxSuppressionV2Op : public OpKernel {
179+
public:
180+
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
181+
: OpKernel(context) {
182+
}
183+
184+
void Compute(OpKernelContext* context) override {
185+
// boxes: [num_boxes, 4]
186+
const Tensor& boxes = context->input(0);
187+
// scores: [num_boxes]
188+
const Tensor& scores = context->input(1);
189+
// max_output_size: scalar
190+
const Tensor& max_output_size = context->input(2);
191+
OP_REQUIRES(
192+
context, TensorShapeUtils::IsScalar(max_output_size.shape()),
193+
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
194+
max_output_size.shape().DebugString()));
195+
// iou_threshold: scalar
196+
const Tensor& iou_threshold = context->input(3);
197+
OP_REQUIRES(
198+
context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
199+
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
200+
iou_threshold.shape().DebugString()));
201+
202+
const float iou_threshold_val = iou_threshold.scalar<float>()();
203+
204+
DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_val);
205+
}
206+
};
207+
167208
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
168209
NonMaxSuppressionOp<CPUDevice>);
169210

211+
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
212+
NonMaxSuppressionV2Op<CPUDevice>);
213+
170214
} // namespace tensorflow

tensorflow/core/kernels/non_max_suppression_op_test.cc

Lines changed: 161 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,174 @@ TEST_F(NonMaxSuppressionOpTest, TestInconsistentBoxAndScoreShapes) {
141141
AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f});
142142
AddInputFromArray<int>(TensorShape({}), {30});
143143
Status s = RunOpKernel();
144+
145+
ASSERT_FALSE(s.ok());
146+
EXPECT_TRUE(
147+
StringPiece(s.ToString()).contains("scores has incompatible shape"))
148+
<< s;
149+
}
150+
151+
// An IOU threshold of 1.2 lies outside [0, 1], so the kernel must fail with
// the range error.
TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) {
  MakeOp(1.2);
  AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});  // boxes
  AddInputFromArray<float>(TensorShape({1}), {.9f});            // scores
  AddInputFromArray<int>(TensorShape({}), {3});  // max_output_size

  const Status status = RunOpKernel();
  ASSERT_FALSE(status.ok());

  const auto message = status.ToString();
  EXPECT_TRUE(StringPiece(message).contains("iou_threshold must be in [0, 1]"))
      << status;
}
163+
164+
// With zero boxes and zero scores the op succeeds and produces an empty index
// tensor.
TEST_F(NonMaxSuppressionOpTest, TestEmptyInput) {
  MakeOp(.5);
  AddInputFromArray<float>(TensorShape({0, 4}), {});  // boxes
  AddInputFromArray<float>(TensorShape({0}), {});     // scores
  AddInputFromArray<int>(TensorShape({}), {30});      // max_output_size
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({0}));
  test::FillValues<int>(&want_indices, {});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
175+
176+
//
177+
// NonMaxSuppressionV2Op Tests
178+
//
179+
180+
class NonMaxSuppressionV2OpTest : public OpsTestBase {
181+
protected:
182+
void MakeOp() {
183+
TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV2")
184+
.Input(FakeInput(DT_FLOAT))
185+
.Input(FakeInput(DT_FLOAT))
186+
.Input(FakeInput(DT_INT32))
187+
.Input(FakeInput(DT_FLOAT))
188+
.Finalize(node_def()));
189+
TF_EXPECT_OK(InitOp());
190+
}
191+
};
192+
193+
// Six boxes, a limit of three and a threshold of 0.5: the expected output is
// boxes 3, 0 and 5.
TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromThreeClusters) {
  MakeOp();

  const std::vector<float> box_corners = {
      0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
      0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101};
  const std::vector<float> box_scores = {.9f, .75f, .6f, .95f, .5f, .3f};

  AddInputFromArray<float>(TensorShape({6, 4}), box_corners);
  AddInputFromArray<float>(TensorShape({6}), box_scores);
  AddInputFromArray<int>(TensorShape({}), {3});      // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({3}));
  test::FillValues<int>(&want_indices, {3, 0, 5});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
207+
208+
// Same scenario as the three-cluster test, but several boxes list their
// corners in swapped order. The selection is expected to be unchanged.
TEST_F(NonMaxSuppressionV2OpTest,
       TestSelectFromThreeClustersFlippedCoordinates) {
  MakeOp();

  const std::vector<float> box_corners = {
      1, 1,  0, 0,  0, 0.1f,  1, 1.1f,  0, .9f, 1, -0.1f,
      0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100};
  const std::vector<float> box_scores = {.9f, .75f, .6f, .95f, .5f, .3f};

  AddInputFromArray<float>(TensorShape({6, 4}), box_corners);
  AddInputFromArray<float>(TensorShape({6}), box_scores);
  AddInputFromArray<int>(TensorShape({}), {3});      // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({3}));
  test::FillValues<int>(&want_indices, {3, 0, 5});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
222+
223+
// With max_output_size set to 2, only the first two selections (boxes 3 and 0)
// are returned.
TEST_F(NonMaxSuppressionV2OpTest, TestSelectAtMostTwoBoxesFromThreeClusters) {
  MakeOp();

  const std::vector<float> box_corners = {
      0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
      0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101};
  const std::vector<float> box_scores = {.9f, .75f, .6f, .95f, .5f, .3f};

  AddInputFromArray<float>(TensorShape({6, 4}), box_corners);
  AddInputFromArray<float>(TensorShape({6}), box_scores);
  AddInputFromArray<int>(TensorShape({}), {2});      // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({2}));
  test::FillValues<int>(&want_indices, {3, 0});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
237+
238+
// A max_output_size (30) larger than the number of boxes still yields only the
// three expected selections.
TEST_F(NonMaxSuppressionV2OpTest,
       TestSelectAtMostThirtyBoxesFromThreeClusters) {
  MakeOp();

  const std::vector<float> box_corners = {
      0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
      0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101};
  const std::vector<float> box_scores = {.9f, .75f, .6f, .95f, .5f, .3f};

  AddInputFromArray<float>(TensorShape({6, 4}), box_corners);
  AddInputFromArray<float>(TensorShape({6}), box_scores);
  AddInputFromArray<int>(TensorShape({}), {30});     // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({3}));
  test::FillValues<int>(&want_indices, {3, 0, 5});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
252+
253+
// A lone box is always selected, so the output is the single index 0.
TEST_F(NonMaxSuppressionV2OpTest, TestSelectSingleBox) {
  MakeOp();
  AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});  // boxes
  AddInputFromArray<float>(TensorShape({1}), {.9f});            // scores
  AddInputFromArray<int>(TensorShape({}), {3});      // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({1}));
  test::FillValues<int>(&want_indices, {0});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
265+
266+
// Ten copies of the same box with equal scores: the expected output is the
// single index 0.
TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromTenIdenticalBoxes) {
  MakeOp();

  int num_boxes = 10;
  const float kUnitBox[4] = {0, 0, 1, 1};
  std::vector<float> corners;
  corners.reserve(num_boxes * 4);
  for (int box = 0; box < num_boxes; ++box) {
    corners.insert(corners.end(), kUnitBox, kUnitBox + 4);
  }
  std::vector<float> scores(num_boxes, .9);

  AddInputFromArray<float>(TensorShape({num_boxes, 4}), corners);
  AddInputFromArray<float>(TensorShape({num_boxes}), scores);
  AddInputFromArray<int>(TensorShape({}), {3});      // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold
  TF_ASSERT_OK(RunOpKernel());

  Tensor want_indices(allocator(), DT_INT32, TensorShape({1}));
  test::FillValues<int>(&want_indices, {0});
  test::ExpectTensorEqual<int>(want_indices, *GetOutput(0));
}
289+
290+
// Six boxes but only five scores: the kernel must fail with the incompatible
// shape error.
TEST_F(NonMaxSuppressionV2OpTest, TestInconsistentBoxAndScoreShapes) {
  MakeOp();

  const std::vector<float> box_corners = {
      0, 0,  1, 1,  0, 0.1f,  1, 1.1f,  0, -0.1f, 1, 0.9f,
      0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100,   1, 101};
  const std::vector<float> box_scores = {.9f, .75f, .6f, .95f, .5f};

  AddInputFromArray<float>(TensorShape({6, 4}), box_corners);
  AddInputFromArray<float>(TensorShape({5}), box_scores);
  AddInputFromArray<int>(TensorShape({}), {30});     // max_output_size
  AddInputFromArray<float>(TensorShape({}), {.5f});  // iou_threshold

  const Status status = RunOpKernel();
  ASSERT_FALSE(status.ok());

  const auto message = status.ToString();
  EXPECT_TRUE(StringPiece(message).contains("scores has incompatible shape"))
      << status;
}
150305

151-
TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) {
152-
MakeOp(1.2);
306+
TEST_F(NonMaxSuppressionV2OpTest, TestInvalidIOUThreshold) {
307+
MakeOp();
153308
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
154309
AddInputFromArray<float>(TensorShape({1}), {.9f});
155310
AddInputFromArray<int>(TensorShape({}), {3});
311+
AddInputFromArray<float>(TensorShape({}), {1.2f});
156312
Status s = RunOpKernel();
157313

158314
ASSERT_FALSE(s.ok());
@@ -161,11 +317,12 @@ TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) {
161317
<< s;
162318
}
163319

164-
TEST_F(NonMaxSuppressionOpTest, TestEmptyInput) {
165-
MakeOp(.5);
320+
TEST_F(NonMaxSuppressionV2OpTest, TestEmptyInput) {
321+
MakeOp();
166322
AddInputFromArray<float>(TensorShape({0, 4}), {});
167323
AddInputFromArray<float>(TensorShape({0}), {});
168324
AddInputFromArray<int>(TensorShape({}), {30});
325+
AddInputFromArray<float>(TensorShape({}), {.5f});
169326
TF_ASSERT_OK(RunOpKernel());
170327

171328
Tensor expected(allocator(), DT_INT32, TensorShape({0}));

0 commit comments

Comments
 (0)