Skip to content

Commit ec86b03

Browse files
Extends fold_batch_norms transform to also fold the Mul introduced by batch normalization after fully connected layers (MatMul).
Change: 148868461
1 parent 7e48bad commit ec86b03

File tree

3 files changed

+75
-16
lines changed

3 files changed

+75
-16
lines changed

tensorflow/tools/graph_transforms/README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,13 @@ Args: None \
341341
Prerequisites: [fold_constants](#fold_constants)
342342

343343
This transform tries to optimize away the Mul that's introduced after a Conv2D
344-
when batch normalization has been used during training. It scans the graph for
345-
any channel-wise multiplies immediately after convolutions, and multiplies the
346-
convolution's weights with the Mul instead so this can be omitted at inference
347-
time. You'll need to make sure you run [fold_constants](#fold_constants) first,
348-
since the pattern can only be spotted if the normal complex expression that's
349-
produced by training for the Mul input is collapsed down into a simple constant.
344+
(or a MatMul) when batch normalization has been used during training. It scans
345+
the graph for any channel-wise multiplies immediately after convolutions, and
346+
multiplies the convolution's (or matrix multiplication's) weights with the Mul
347+
instead so this can be omitted at inference time. You'll need to make sure you
348+
run [fold_constants](#fold_constants) first, since the pattern can only be
349+
spotted if the normal complex expression that's produced by training for the Mul
350+
input is collapsed down into a simple constant.
350351

351352
### fold_constants
352353

tensorflow/tools/graph_transforms/fold_batch_norms.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,24 @@ limitations under the License.
2727
namespace tensorflow {
2828
namespace graph_transforms {
2929

30-
// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the
31-
// Mul baked into the convolution weights, to save computation during inference.
30+
// Converts Conv2D or MatMul ops followed by column-wise Muls into equivalent
31+
// ops with the Mul baked into the convolution weights, to save computation
32+
// during inference.
3233
Status FoldBatchNorms(const GraphDef& input_graph_def,
3334
const TransformFuncContext& context,
3435
GraphDef* output_graph_def) {
3536
GraphDef replaced_graph_def;
3637
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
3738
input_graph_def, // clang-format off
38-
{"Mul", // mul_node
39+
{"Mul", // mul_node
3940
{
40-
{"Conv2D", // conv_node
41+
{"Conv2D|MatMul", // conv_node
4142
{
42-
{"*"}, // input_node
43-
{"Const"}, // weights_node
43+
{"*"}, // input_node
44+
{"Const"}, // weights_node
4445
}
4546
},
46-
{"Const"}, // mul_values_node
47+
{"Const"}, // mul_values_node
4748
}
4849
}, // clang-format on
4950
[](const NodeMatch& match, const std::set<string>& input_nodes,
@@ -61,7 +62,8 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
6162

6263
// Make sure all the inputs really are vectors, with as many entries as
6364
// there are columns in the weights.
64-
const int64 weights_cols = weights.shape().dim_size(3);
65+
const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1;
66+
const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
6567
if ((mul_values.shape().dims() != 1) ||
6668
(mul_values.shape().dim_size(0) != weights_cols)) {
6769
return errors::InvalidArgument(

tensorflow/tools/graph_transforms/fold_batch_norms_test.cc

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "tensorflow/cc/ops/const_op.h"
1717
#include "tensorflow/cc/ops/image_ops.h"
18+
#include "tensorflow/cc/ops/math_ops.h"
1819
#include "tensorflow/cc/ops/nn_ops.h"
1920
#include "tensorflow/cc/ops/sendrecv_ops.h"
2021
#include "tensorflow/cc/ops/standard_ops.h"
@@ -35,7 +36,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
3536

3637
class FoldBatchNormsTest : public ::testing::Test {
3738
protected:
38-
void TestFoldBatchNorms() {
39+
void TestFoldBatchNormsConv2D() {
3940
auto root = tensorflow::Scope::NewRootScope();
4041
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
4142

@@ -85,9 +86,64 @@ class FoldBatchNormsTest : public ::testing::Test {
8586
EXPECT_NE("Mul", node.op());
8687
}
8788
}
89+
90+
void TestFoldBatchNormsMatMul() {
91+
auto root = tensorflow::Scope::NewRootScope();
92+
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
93+
94+
Tensor input_data(DT_FLOAT, TensorShape({6, 2}));
95+
test::FillValues<float>(
96+
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
97+
-5.0f, -3.0f, -6.0f});
98+
Output input_op =
99+
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
100+
101+
Tensor weights_data(DT_FLOAT, TensorShape({2, 2}));
102+
test::FillValues<float>(&weights_data, {1.0f, 2.0f, 0.3f, 0.4f});
103+
Output weights_op =
104+
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
105+
106+
Output matmul_op =
107+
MatMul(root.WithOpName("matmul_op"), input_op, weights_op);
108+
109+
Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
110+
test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
111+
Output mul_values_op = Const(root.WithOpName("mul_values"),
112+
Input::Initializer(mul_values_data));
113+
114+
Output mul_op = Mul(root.WithOpName("output"), matmul_op, mul_values_op);
115+
116+
GraphDef original_graph_def;
117+
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
118+
119+
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
120+
TF_ASSERT_OK(original_session->Create(original_graph_def));
121+
std::vector<Tensor> original_outputs;
122+
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
123+
124+
GraphDef fused_graph_def;
125+
TF_ASSERT_OK(
126+
FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
127+
128+
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
129+
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
130+
std::vector<Tensor> fused_outputs;
131+
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
132+
133+
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
134+
135+
for (const NodeDef& node : fused_graph_def.node()) {
136+
EXPECT_NE("Mul", node.op());
137+
}
138+
}
88139
};
89140

90-
TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
141+
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
142+
TestFoldBatchNormsConv2D();
143+
}
144+
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
145+
TestFoldBatchNormsMatMul();
146+
}
91147

92148
} // namespace graph_transforms
93149
} // namespace tensorflow

0 commit comments

Comments (0)