Commit 25f2774

- reduced the network size
- updated the generated code
1 parent 82b7495

File tree

3 files changed: +373 −315 lines changed

deep_mlp.py

Lines changed: 139 additions & 0 deletions
#!/usr/bin/python
# -*- coding: utf8 -*-
"""
# This script is based on:
# https://www.tensorflow.org/get_started/mnist/pros
"""
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util as gu
from tensorflow.tools.graph_transforms import TransformGraph

FLAGS = None


# helper functions
def weight_variable(shape, name):
    """weight_variable generates a weight variable of a given shape."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    # pass name as a keyword argument: the second positional argument
    # of tf.Variable is `trainable`, not `name`
    return tf.Variable(initial, name=name)


def bias_variable(shape, name):
    """bias_variable generates a bias variable of a given shape."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)


# Fully connected 2-layer NN
def deepnn(x, keep_prob):
    W_fc1 = weight_variable([784, 64], name='W_fc1')
    b_fc1 = bias_variable([64], name='b_fc1')
    a_fc1 = tf.add(tf.matmul(x, W_fc1), b_fc1, name="zscore")
    h_fc1 = tf.nn.relu(a_fc1)
    # keep_prob is fed as 0.5 during training and defaults to 1.0
    # otherwise, so the frozen inference graph is deterministic
    layer1 = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([64, 10], name='W_fc2')
    b_fc2 = bias_variable([10], name='b_fc2')
    logits = tf.add(tf.matmul(layer1, W_fc2), b_fc2, name="logits")
    y_pred = tf.argmax(logits, 1, name='y_pred')

    return y_pred, logits


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Specify inputs, outputs, and a cost function
    # placeholders
    x = tf.placeholder(tf.float32, [None, 784], name="x")
    y_ = tf.placeholder(tf.float32, [None, 10], name="y")
    # dropout keep probability; defaults to 1.0 (no dropout) at inference
    keep_prob = tf.placeholder_with_default(1.0, shape=(), name="keep_prob")

    # Build the graph for the deep net
    y_pred, logits = deepnn(x, keep_prob)

    with tf.name_scope("Loss"):
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,
                                                                   logits=logits)
        loss = tf.reduce_mean(cross_entropy, name="cross_entropy_loss")
    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, name="train_step")

    with tf.name_scope("Prediction"):
        correct_prediction = tf.equal(y_pred,
                                      tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")

    # Start training session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # SGD
        for i in range(1, FLAGS.num_iter + 1):
            batch_images, batch_labels = mnist.train.next_batch(FLAGS.batch_size)
            feed_dict = {x: batch_images, y_: batch_labels, keep_prob: 0.5}
            train_step.run(feed_dict=feed_dict)
            if i % FLAGS.log_iter == 0:
                train_accuracy = accuracy.eval(feed_dict=feed_dict)
                print('step %d, training accuracy %g' % (i, train_accuracy))

        print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images,
                                                            y_: mnist.test.labels}))
        # Save a checkpoint and serialize the graph
        ckpt_path = saver.save(sess, FLAGS.chkp)
        print('saving checkpoint: %s' % ckpt_path)
        out_nodes = [y_pred.op.name]
        # Freeze the graph and remove training nodes
        sub_graph_def = gu.remove_training_nodes(sess.graph_def)
        sub_graph_def = gu.convert_variables_to_constants(sess, sub_graph_def, out_nodes)
        if FLAGS.no_quant:
            graph_path = tf.train.write_graph(sub_graph_def,
                                              FLAGS.output_dir,
                                              FLAGS.pb_fname,
                                              as_text=False)
        else:
            # quantize the frozen graph (skipped when --no-quantization is given)
            quant_graph_def = TransformGraph(sub_graph_def,
                                             [],
                                             out_nodes,
                                             ["quantize_weights", "quantize_nodes"])
            graph_path = tf.train.write_graph(quant_graph_def,
                                              FLAGS.output_dir,
                                              FLAGS.pb_fname,
                                              as_text=False)
        print('written graph to: %s' % graph_path)
        print('the output nodes: {!s}'.format(out_nodes))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='mnist_data',
                        help='directory for storing input data (default: %(default)s)')
    parser.add_argument('--chkp', default='chkps/mnist_model',
                        help='session checkpoint path (default: %(default)s)')
    parser.add_argument('-n', '--num-iteration', type=int,
                        dest='num_iter',
                        default=20000,
                        help='number of training iterations (default: %(default)s)')
    parser.add_argument('--batch-size', dest='batch_size',
                        default=50, type=int,
                        help='batch size (default: %(default)s)')
    parser.add_argument('--log-every-iters', type=int,
                        dest='log_iter', default=1000,
                        help='log the training accuracy every N iterations (default: %(default)s)')
    parser.add_argument('--output-dir', default='mnist_model',
                        dest='output_dir',
                        help='output directory (default: %(default)s)')
    parser.add_argument('--no-quantization', action='store_true',
                        dest='no_quant',
                        help='save the output graph pb file without quantization')
    parser.add_argument('-o', '--output', default='deep_mlp.pb',
                        dest='pb_fname',
                        help='output pb file (default: %(default)s)')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
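
A quick way to sanity-check the frozen graph this script writes is to load the protobuf back and run the y_pred output node on a few test images. The sketch below is illustrative, not part of the commit; it assumes the default flags (output at mnist_model/deep_mlp.pb, MNIST data in mnist_data) and a TensorFlow 1.x environment, and the file name load_deep_mlp.py is made up for the example. The tensor names 'x:0' and 'y_pred:0' come from the names set in deep_mlp.py above.

#!/usr/bin/python
# load_deep_mlp.py -- illustrative sketch, not part of this commit
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# read the frozen GraphDef written by deep_mlp.py (default output path)
graph_def = tf.GraphDef()
with tf.gfile.GFile('mnist_model/deep_mlp.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# import it into a fresh graph; name='' keeps the original node names
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

# look up the input and output tensors by the names set in deep_mlp.py
x = graph.get_tensor_by_name('x:0')
y_pred = graph.get_tensor_by_name('y_pred:0')

mnist = input_data.read_data_sets('mnist_data', one_hot=True)
with tf.Session(graph=graph) as sess:
    preds = sess.run(y_pred, feed_dict={x: mnist.test.images[:10]})
    print('predicted digits:', preds)

The same snippet loads either output variant, though a graph written with the default quantization path will contain quantized ops, while one written with --no-quantization stays in float32.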
