Skip to content

Commit 0943136

Browse files
wangpengmittensorflower-gardener
authored andcommitted
Adds back a DoCopy in StridedSliceAssignOp that was accidentally deleted.
PiperOrigin-RevId: 323074864 Change-Id: Ie46f51353a85aa423e562da7e2f3009238cca07e
1 parent 2a5e737 commit 0943136

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

tensorflow/core/framework/op_kernel.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -885,10 +885,14 @@ class OpKernelContext {
885885

886886
// Tries to forward one of the inputs given in input_indices to
887887
// output[output_index]. If none of the given inputs can be forwarded, calls
888-
// allocate_output() to allocate a new output buffer.
888+
// allocate_output() to allocate a new output buffer. The index of the
889+
// forwarded input will be assign to output argument forwarded_input (if it's
890+
// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
891+
// -1.
889892
Status forward_input_or_allocate_output(
890893
gtl::ArraySlice<int> candidate_input_indices, int output_index,
891-
const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
894+
const TensorShape& output_shape, Tensor** output,
895+
int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
892896
Status forward_input_or_allocate_output(
893897
gtl::ArraySlice<StringPiece> candidate_input_names,
894898
StringPiece output_name, const TensorShape& output_shape,
@@ -1636,13 +1640,19 @@ inline TensorValue OpKernelContext::release_output(int index) {
16361640

16371641
inline Status OpKernelContext::forward_input_or_allocate_output(
16381642
gtl::ArraySlice<int> candidate_input_indices, int output_index,
1639-
const TensorShape& output_shape, Tensor** output) {
1643+
const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
16401644
for (int input_index : candidate_input_indices) {
16411645
if (forward_input_to_output_with_shape(input_index, output_index,
16421646
output_shape, output)) {
1647+
if (forwarded_input != nullptr) {
1648+
*forwarded_input = input_index;
1649+
}
16431650
return Status::OK();
16441651
}
16451652
}
1653+
if (forwarded_input != nullptr) {
1654+
*forwarded_input = -1;
1655+
}
16461656
return allocate_output(output_index, output_shape, output);
16471657
}
16481658

tensorflow/core/kernels/strided_slice_op.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,15 @@ class StridedSliceAssignOp : public OpKernel {
306306
if (isTensor) {
307307
const Tensor& input = context->input(0);
308308

309-
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
310-
{0}, 0, input.shape(), &old_lhs));
309+
int forwarded_input;
310+
OP_REQUIRES_OK(context,
311+
context->forward_input_or_allocate_output(
312+
{0}, 0, input.shape(), &old_lhs, &forwarded_input));
313+
if (forwarded_input < 0) {
314+
OP_REQUIRES_OK(context,
315+
tensorflow::functor::DoCopy(
316+
context->eigen_device<Device>(), input, old_lhs));
317+
}
311318
} else {
312319
if (context->input_dtype(0) == DT_RESOURCE) {
313320
core::RefCountPtr<Var> v;

tensorflow/python/kernel_tests/array_ops_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,13 +1228,25 @@ def testTypeErrorResource(self):
12281228
sess.run(v[:].assign(too_small_val))
12291229

12301230
@test_util.run_in_graph_and_eager_modes
1231-
def testTensorStridedSliceAssign(self):
1231+
def testTensorStridedSliceAssignWithInputForward(self):
1232+
"""Tests tensor_strided_slice_update with input-forwarding taking effect."""
12321233
@def_function.function
12331234
def assign(x):
12341235
y = x + 1
12351236
return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0])
12361237
self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2]))))
12371238

1239+
@test_util.run_in_graph_and_eager_modes
1240+
def testTensorStridedSliceAssignNoInputForward(self):
1241+
"""Tests tensor_strided_slice_update with no input-forwarding."""
1242+
x = constant_op.constant([0.2, 0.3])
1243+
y = x + 1
1244+
# y's buffer won't be forwarded to z because y and z will be alive at the
1245+
# same time later.
1246+
z = gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0.4])
1247+
ans = y + z
1248+
self.assertAllClose([1.6, 2.6], self.evaluate(ans))
1249+
12381250

12391251
class ShapeSizeRankTest(test_util.TensorFlowTestCase):
12401252

0 commit comments

Comments
 (0)