#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Train a small fully connected MNIST classifier and export a frozen graph.

This script is based on:
https://www.tensorflow.org/get_started/mnist/pros
"""
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util as gu
from tensorflow.tools.graph_transforms import TransformGraph

FLAGS = None


# helper functions
def weight_variable(shape, name):
    """weight_variable generates a weight variable of a given shape."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    # pass the name as a keyword argument: the second positional argument
    # of tf.Variable is `trainable`, not `name`
    return tf.Variable(initial, name=name)


def bias_variable(shape, name):
    """bias_variable generates a bias variable of a given shape."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)


# Fully connected 2-layer NN
def deepnn(x):
    # hidden layer: 784 -> 64, ReLU activation, dropout
    # (note: the dropout rate is a constant, so it is also applied at eval time)
    W_fc1 = weight_variable([784, 64], name='W_fc1')
    b_fc1 = bias_variable([64], name='b_fc1')
    a_fc1 = tf.add(tf.matmul(x, W_fc1), b_fc1, name="zscore")
    h_fc1 = tf.nn.relu(a_fc1)
    layer1 = tf.nn.dropout(h_fc1, 0.50)

    # output layer: 64 -> 10 logits, predicted class via argmax
    W_fc2 = weight_variable([64, 10], name='W_fc2')
    b_fc2 = bias_variable([10], name='b_fc2')
    logits = tf.add(tf.matmul(layer1, W_fc2), b_fc2, name="logits")
    y_pred = tf.argmax(logits, 1, name='y_pred')

    return y_pred, logits


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Specify inputs, outputs, and a cost function
    # placeholders
    x = tf.placeholder(tf.float32, [None, 784], name="x")
    y_ = tf.placeholder(tf.float32, [None, 10], name="y")

    # Build the graph for the deep net
    y_pred, logits = deepnn(x)

    with tf.name_scope("Loss"):
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,
                                                                    logits=logits)
        loss = tf.reduce_mean(cross_entropy, name="cross_entropy_loss")
    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, name="train_step")

    with tf.name_scope("Prediction"):
        correct_prediction = tf.equal(y_pred, tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                                  name="accuracy")

    # Start training session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # SGD training loop
        for i in range(1, FLAGS.num_iter + 1):
            batch_images, batch_labels = mnist.train.next_batch(FLAGS.batch_size)
            feed_dict = {x: batch_images, y_: batch_labels}
            train_step.run(feed_dict=feed_dict)
            if i % FLAGS.log_iter == 0:
                train_accuracy = accuracy.eval(feed_dict=feed_dict)
                print('step %d, training accuracy %g' % (i, train_accuracy))

        print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images,
                                                            y_: mnist.test.labels}))
        # Save a checkpoint and serialize the graph
        ckpt_path = saver.save(sess, FLAGS.chkp)
        print('saving checkpoint: %s' % ckpt_path)
        out_nodes = [y_pred.op.name]
        # Freeze the graph and remove training nodes
        sub_graph_def = gu.remove_training_nodes(sess.graph_def)
        sub_graph_def = gu.convert_variables_to_constants(sess, sub_graph_def, out_nodes)
        if FLAGS.no_quant:
            graph_path = tf.train.write_graph(sub_graph_def,
                                              FLAGS.output_dir,
                                              FLAGS.pb_fname,
                                              as_text=False)
        else:
            # quantize the graph
            quant_graph_def = TransformGraph(sub_graph_def,
                                             [],
                                             out_nodes,
                                             ["quantize_weights", "quantize_nodes"])
            graph_path = tf.train.write_graph(quant_graph_def,
                                              FLAGS.output_dir,
                                              FLAGS.pb_fname,
                                              as_text=False)
        print('written graph to: %s' % graph_path)
        print('the output nodes: {!s}'.format(out_nodes))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str,
                        default='mnist_data',
                        help='directory for storing input data (default: %(default)s)')
    parser.add_argument('--chkp', default='chkps/mnist_model',
                        help='session checkpoint path (default: %(default)s)')
    parser.add_argument('-n', '--num-iteration', type=int,
                        dest='num_iter',
                        default=20000,
                        help='number of training iterations (default: %(default)s)')
    parser.add_argument('--batch-size', dest='batch_size',
                        default=50, type=int,
                        help='batch size (default: %(default)s)')
    parser.add_argument('--log-every-iters', type=int,
                        dest='log_iter', default=1000,
                        help='log the training accuracy every N iterations (default: %(default)s)')
    parser.add_argument('--output-dir', default='mnist_model',
                        dest='output_dir',
                        help='output directory (default: %(default)s)')
    parser.add_argument('--no-quantization', action='store_true',
                        dest='no_quant',
                        help='save the output graph pb file without quantization')
    parser.add_argument('-o', '--output', default='deep_mlp.pb',
                        dest='pb_fname',
                        help='output pb file name (default: %(default)s)')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
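
# Example invocation (a sketch using this script's default flag values; the
# filename deep_mlp.py is assumed here, adjust paths and hyper-parameters as needed):
#
#   python deep_mlp.py --data_dir mnist_data \
#                      --output-dir mnist_model \
#                      -o deep_mlp.pb \
#                      -n 20000 --batch-size 50
#
# The frozen (and, unless --no-quantization is given, quantized) graph is
# written to mnist_model/deep_mlp.pb with 'y_pred' as the output node.
# A minimal sketch of loading it back for inference in TF 1.x might look like:
#
#   with tf.gfile.GFile('mnist_model/deep_mlp.pb', 'rb') as f:
#       graph_def = tf.GraphDef()
#       graph_def.ParseFromString(f.read())
#   tf.import_graph_def(graph_def, name='')
#   y_pred = tf.get_default_graph().get_tensor_by_name('y_pred:0')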