@@ -15,6 +15,7 @@ limitations under the License.
1515
1616#include " tensorflow/cc/ops/const_op.h"
1717#include " tensorflow/cc/ops/image_ops.h"
18+ #include " tensorflow/cc/ops/math_ops.h"
1819#include " tensorflow/cc/ops/nn_ops.h"
1920#include " tensorflow/cc/ops/sendrecv_ops.h"
2021#include " tensorflow/cc/ops/standard_ops.h"
@@ -35,7 +36,7 @@ Status FoldBatchNorms(const GraphDef& input_graph_def,
3536
3637class FoldBatchNormsTest : public ::testing::Test {
3738 protected:
38- void TestFoldBatchNorms () {
39+ void TestFoldBatchNormsConv2D () {
3940 auto root = tensorflow::Scope::NewRootScope ();
4041 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
4142
@@ -85,9 +86,64 @@ class FoldBatchNormsTest : public ::testing::Test {
8586 EXPECT_NE (" Mul" , node.op ());
8687 }
8788 }
89+
90+ void TestFoldBatchNormsMatMul () {
91+ auto root = tensorflow::Scope::NewRootScope ();
92+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
93+
94+ Tensor input_data (DT_FLOAT, TensorShape ({6 , 2 }));
95+ test::FillValues<float >(
96+ &input_data, {1 .0f , 4 .0f , 2 .0f , 5 .0f , 3 .0f , 6 .0f , -1 .0f , -4 .0f , -2 .0f ,
97+ -5 .0f , -3 .0f , -6 .0f });
98+ Output input_op =
99+ Const (root.WithOpName (" input_op" ), Input::Initializer (input_data));
100+
101+ Tensor weights_data (DT_FLOAT, TensorShape ({2 , 2 }));
102+ test::FillValues<float >(&weights_data, {1 .0f , 2 .0f , 0 .3f , 0 .4f });
103+ Output weights_op =
104+ Const (root.WithOpName (" weights_op" ), Input::Initializer (weights_data));
105+
106+ Output matmul_op =
107+ MatMul (root.WithOpName (" matmul_op" ), input_op, weights_op);
108+
109+ Tensor mul_values_data (DT_FLOAT, TensorShape ({2 }));
110+ test::FillValues<float >(&mul_values_data, {2 .0f , 3 .0f });
111+ Output mul_values_op = Const (root.WithOpName (" mul_values" ),
112+ Input::Initializer (mul_values_data));
113+
114+ Output mul_op = Mul (root.WithOpName (" output" ), matmul_op, mul_values_op);
115+
116+ GraphDef original_graph_def;
117+ TF_ASSERT_OK (root.ToGraphDef (&original_graph_def));
118+
119+ std::unique_ptr<Session> original_session (NewSession (SessionOptions ()));
120+ TF_ASSERT_OK (original_session->Create (original_graph_def));
121+ std::vector<Tensor> original_outputs;
122+ TF_ASSERT_OK (original_session->Run ({}, {" output" }, {}, &original_outputs));
123+
124+ GraphDef fused_graph_def;
125+ TF_ASSERT_OK (
126+ FoldBatchNorms (original_graph_def, {{}, {" output" }}, &fused_graph_def));
127+
128+ std::unique_ptr<Session> fused_session (NewSession (SessionOptions ()));
129+ TF_ASSERT_OK (fused_session->Create (fused_graph_def));
130+ std::vector<Tensor> fused_outputs;
131+ TF_ASSERT_OK (fused_session->Run ({}, {" output" }, {}, &fused_outputs));
132+
133+ test::ExpectTensorNear<float >(original_outputs[0 ], fused_outputs[0 ], 1e-5 );
134+
135+ for (const NodeDef& node : fused_graph_def.node ()) {
136+ EXPECT_NE (" Mul" , node.op ());
137+ }
138+ }
88139};
89140
90- TEST_F (FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms (); }
141+ TEST_F (FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
142+ TestFoldBatchNormsConv2D ();
143+ }
144+ TEST_F (FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
145+ TestFoldBatchNormsMatMul ();
146+ }
91147
92148} // namespace graph_transforms
93149} // namespace tensorflow
0 commit comments