
Commit bcafcee

Merge pull request tensorflow#12665 from bpiel/conv2d_maxpool_grads
Conv2DGrad & MaxPoolGradHelper
2 parents b92d538 + 5619b05

3 files changed: +109 −0 lines changed

tensorflow/cc/BUILD

Lines changed: 1 addition & 0 deletions
@@ -326,6 +326,7 @@ cc_library(
         ":cc_ops",
         ":cc_ops_internal",
         ":grad_op_registry",
+        ":gradients",
     ],
     alwayslink = 1,
 )
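Note: the new ":gradients" dependency makes tensorflow/cc/framework/gradients.h available to the gradient sources below; that header declares NoGradient(), which MaxPoolGradV2Helper uses to mark the ksize and strides inputs of MaxPoolV2 as non-differentiable.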

tensorflow/cc/gradients/nn_grad.cc

Lines changed: 81 additions & 0 deletions
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/standard_ops.h"
 
 #include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
 
 namespace tensorflow {
 namespace ops {
@@ -118,6 +119,86 @@ Status BiasAddGradHelper(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper);
 
+Status Conv2DGrad(const Scope& scope, const Operation& op,
+                  const std::vector<Output>& grad_inputs,
+                  std::vector<Output>* grad_outputs) {
+  string data_format;
+  string padding;
+  std::vector<int32> strides;
+  bool use_cudnn_on_gpu;
+  auto attrs = op.output(0).node()->attrs();
+  GetNodeAttr(attrs, "data_format", &data_format);
+  GetNodeAttr(attrs, "padding", &padding);
+  GetNodeAttr(attrs, "strides", &strides);
+  GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu);
+  Conv2DBackpropInput::Attrs input_attrs;
+  input_attrs.DataFormat(data_format);
+  input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
+  auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)),
+                                  op.input(1), grad_inputs[0],
+                                  strides, padding, input_attrs);
+  grad_outputs->push_back(dx_1);
+  Conv2DBackpropFilter::Attrs filter_attrs;
+  filter_attrs.DataFormat(data_format);
+  filter_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
+  auto dx_2 = Conv2DBackpropFilter(scope, op.input(0),
+                                   Shape(scope, op.input(1)), grad_inputs[0],
+                                   strides, padding, filter_attrs);
+  grad_outputs->push_back(dx_2);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Conv2D", Conv2DGrad);
+
+Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
+                         const std::vector<Output>& grad_inputs,
+                         std::vector<Output>* grad_outputs) {
+  string data_format;
+  string padding;
+  std::vector<int32> strides;
+  std::vector<int32> ksize;
+  auto attrs = op.output(0).node()->attrs();
+  GetNodeAttr(attrs, "data_format", &data_format);
+  GetNodeAttr(attrs, "ksize", &ksize);
+  GetNodeAttr(attrs, "padding", &padding);
+  GetNodeAttr(attrs, "strides", &strides);
+  internal::MaxPoolGrad::Attrs grad_attrs;
+  grad_attrs.DataFormat(data_format);
+  auto dx = internal::MaxPoolGrad(scope, op.input(0),
+                                  op.output(0),
+                                  grad_inputs[0],
+                                  ksize, strides,
+                                  padding, grad_attrs);
+  grad_outputs->push_back(dx);
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("MaxPool", MaxPoolGradHelper);
+
+Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
+                           const std::vector<Output>& grad_inputs,
+                           std::vector<Output>* grad_outputs) {
+  string data_format;
+  string padding;
+  auto attrs = op.output(0).node()->attrs();
+  GetNodeAttr(attrs, "data_format", &data_format);
+  GetNodeAttr(attrs, "padding", &padding);
+  MaxPoolGradV2::Attrs grad_attrs;
+  grad_attrs.DataFormat(data_format);
+  auto dx = MaxPoolGradV2(scope, op.input(0),
+                          op.output(0),
+                          grad_inputs[0],
+                          op.input(1),
+                          op.input(2),
+                          padding,
+                          grad_attrs);
+  grad_outputs->push_back(dx);
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(NoGradient());
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper);
+
+
+
 }  // anonymous namespace
 }  // namespace ops
 }  // namespace tensorflow
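For context, here is a minimal sketch of how these registered gradients get exercised through the C++ API. This snippet is not part of the commit; names like `root`, `x`, and `filter` are illustrative. AddSymbolicGradients, declared in the newly included gradients.h, looks up each op's type in the GradOpRegistry and calls the function registered above to build the backward graph.

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;

  Scope root = Scope::NewRootScope();
  auto x = Placeholder(root, DT_FLOAT, Placeholder::Shape({1, 2, 2, 1}));
  auto filter = Const(root, {{{{0.5f}}}});  // 1x1x1x1 filter, NHWC layout
  auto y = Conv2D(root, x, filter, {1, 1, 1, 1}, "SAME");

  std::vector<Output> grads;
  // Resolves "Conv2D" in the gradient registry and invokes Conv2DGrad,
  // adding Conv2DBackpropInput / Conv2DBackpropFilter nodes to the graph.
  TF_CHECK_OK(AddSymbolicGradients(root, {y}, {x, filter}, &grads));
  // grads[0] is the gradient w.r.t. x, grads[1] the gradient w.r.t. filter.
  return 0;
}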

tensorflow/cc/gradients/nn_grad_test.cc

Lines changed: 27 additions & 0 deletions
@@ -139,5 +139,32 @@ TEST_F(NNGradTest, BiasAddGradHelper) {
   RunTest({x,bias}, {shape, bias_shape}, {y}, {shape});
 }
 
+TEST_F(NNGradTest, Conv2DGrad) {
+  TensorShape shape({1, 2, 2, 1});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  Tensor filter = test::AsTensor<float>({0.5f}, {1, 1, 1, 1});
+  const std::vector<int> strides{1, 1, 1, 1};
+  auto y = Conv2D(scope_, x, filter, strides, "SAME");
+  RunTest(x, shape, y, shape);
+}
+
+TEST_F(NNGradTest, MaxPoolGradHelper) {
+  TensorShape shape({1, 2, 2, 1});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  const std::vector<int> ksize{1, 2, 2, 1};
+  const std::vector<int> strides{1, 1, 1, 1};
+  auto y = MaxPool(scope_, x, ksize, strides, "SAME");
+  RunTest(x, shape, y, shape);
+}
+
+TEST_F(NNGradTest, MaxPoolGradV2Helper) {
+  TensorShape shape({1, 2, 2, 1});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+  Tensor ksize = test::AsTensor<int>({1, 2, 2, 1}, {4});
+  Tensor strides = test::AsTensor<int>({1, 1, 1, 1}, {4});
+  auto y = MaxPoolV2(scope_, x, ksize, strides, "SAME");
+  RunTest(x, shape, y, shape);
+}
+
 }  // namespace
 }  // namespace tensorflow
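Each of the new tests bottoms out in the fixture's RunTest helper, which drives the C++ gradient checker: it compares the symbolic gradient produced by the functions registered in nn_grad.cc against a numeric finite-difference estimate. Roughly, reconstructed from the surrounding fixture rather than shown in this diff (the exact signature and tolerance may differ):

void NNGradTest::RunTest(const Output& x, const TensorShape& x_shape,
                         const Output& y, const TensorShape& y_shape) {
  float max_error;
  // ComputeGradientError (tensorflow/cc/framework/gradient_checker.h) builds
  // the symbolic gradient via the registry, estimates a numeric gradient by
  // finite differences, and reports the maximum elementwise difference.
  TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape},
                                    &max_error));
  EXPECT_LT(max_error, 1e-4);
}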
