
Commit 593b072

More tuning
1 parent 25b3a10 commit 593b072

File tree

1 file changed (+53, -40 lines)


2_mnist.ipynb

Lines changed: 53 additions & 40 deletions
@@ -55,7 +55,7 @@
    },
    "outputs": [],
    "source": [
-    "# 2.3 Get input data: the sets of images and labels for training, validation, and\n",
+    "# 2.3 Get input data: get the sets of images and labels for training, validation, and\n",
     "# test on MNIST.\n",
     "data_sets = read_data_sets(TRAIN_DIR, False)"
    ]
@@ -176,65 +176,73 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": false
+    "collapsed": true
    },
    "outputs": [],
    "source": [
-    "# 2.7 Train MNIST for a number of steps.\n",
-    "# Generate placeholders for the images and labels.\n",
-    "with tf.Graph().as_default():\n",
-    "    images_placeholder = tf.placeholder(tf.float32, shape=(BATCH_SIZE,\n",
-    "                                                            IMAGE_PIXELS))\n",
+    "# 2.7 Build the complete graph for feeding inputs, training, and saving checkpoints.\n",
+    "mnist_graph = tf.Graph()\n",
+    "with mnist_graph.as_default():\n",
+    "    # Generate placeholders for the images and labels.\n",
+    "    images_placeholder = tf.placeholder(tf.float32,\n",
+    "                                        shape=(BATCH_SIZE, IMAGE_PIXELS))\n",
     "    labels_placeholder = tf.placeholder(tf.int32, shape=(BATCH_SIZE))\n",
-    "    tf.add_to_collection(\"images\", images_placeholder)\n",
-    "    tf.add_to_collection(\"labels\", labels_placeholder)\n",
+    "    tf.add_to_collection(\"images\", images_placeholder)  # Remember this Op.\n",
+    "    tf.add_to_collection(\"labels\", labels_placeholder)  # Remember this Op.\n",
     "\n",
     "    # Build a Graph that computes predictions from the inference model.\n",
     "    logits = mnist_inference(images_placeholder,\n",
     "                             HIDDEN1_UNITS,\n",
     "                             HIDDEN2_UNITS)\n",
+    "    tf.add_to_collection(\"logits\", logits)  # Remember this Op.\n",
     "\n",
     "    # Add to the Graph the Ops that calculate and apply gradients.\n",
     "    train_op, loss = mnist_training(logits, labels_placeholder, 0.01)\n",
     "\n",
-    "    # Add the Op to compare the logits to the labels during evaluation.\n",
-    "    eval_correct = mnist_evaluation(logits, labels_placeholder)\n",
-    "    tf.add_to_collection(\"eval_op\", eval_correct)\n",
-    "\n",
     "    # Add the variable initializer Op.\n",
     "    init = tf.initialize_all_variables()\n",
     "\n",
     "    # Create a saver for writing training checkpoints.\n",
-    "    saver = tf.train.Saver()\n",
-    "\n",
-    "    with tf.Session() as sess:\n",
-    "        # Run the Op to initialize the variables.\n",
-    "        sess.run(init)\n",
+    "    saver = tf.train.Saver()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# 2.8 Run training for MAX_STEPS and save checkpoint at the end.\n",
+    "with tf.Session(graph=mnist_graph) as sess:\n",
+    "    # Run the Op to initialize the variables.\n",
+    "    sess.run(init)\n",
     "\n",
-    "        # Start the training loop.\n",
-    "        for step in xrange(MAX_STEPS):\n",
-    "            start_time = time.time()\n",
+    "    # Start the training loop.\n",
+    "    for step in xrange(MAX_STEPS):\n",
+    "        start_time = time.time()\n",
     "\n",
-    "            # Read a batch of images and labels.\n",
-    "            images_feed, labels_feed = data_sets.train.next_batch(BATCH_SIZE)\n",
+    "        # Read a batch of images and labels.\n",
+    "        images_feed, labels_feed = data_sets.train.next_batch(BATCH_SIZE)\n",
     "\n",
-    "            # Run one step of the model. The return values are the activations\n",
-    "            # from the `train_op` (which is discarded) and the `loss` Op. To\n",
-    "            # inspect the values of your Ops or variables, you may include them\n",
-    "            # in the list passed to sess.run() and the value tensors will be\n",
-    "            # returned in the tuple from the call.\n",
-    "            _, loss_value = sess.run([train_op, loss],\n",
-    "                                     feed_dict={images_placeholder: images_feed,\n",
-    "                                                labels_placeholder: labels_feed})\n",
+    "        # Run one step of the model. The return values are the activations\n",
+    "        # from the `train_op` (which is discarded) and the `loss` Op. To\n",
+    "        # inspect the values of your Ops or variables, you may include them\n",
+    "        # in the list passed to sess.run() and the value tensors will be\n",
+    "        # returned in the tuple from the call.\n",
+    "        _, loss_value = sess.run([train_op, loss],\n",
+    "                                 feed_dict={images_placeholder: images_feed,\n",
+    "                                            labels_placeholder: labels_feed})\n",
     "\n",
-    "            # Print out loss value.\n",
-    "            if step % 1000 == 0:\n",
-    "                print('Step %d: loss = %.2f (%.3f sec)' %\n",
-    "                      (step, loss_value, time.time() - start_time))\n",
+    "        # Print out loss value.\n",
+    "        if step % 1000 == 0:\n",
+    "            print('Step %d: loss = %.2f (%.3f sec)' %\n",
+    "                  (step, loss_value, time.time() - start_time))\n",
     "\n",
-    "        # Write a checkpoint.\n",
-    "        checkpoint_file = os.path.join(TRAIN_DIR, 'checkpoint')\n",
-    "        saver.save(sess, checkpoint_file, global_step=step)"
+    "    # Write a checkpoint.\n",
+    "    checkpoint_file = os.path.join(TRAIN_DIR, 'checkpoint')\n",
+    "    saver.save(sess, checkpoint_file, global_step=step)"
    ]
   },
   {
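The change above splits the old all-in-one training cell into two: cell 2.7 only builds the graph, "remembers" the placeholder and logits Ops with tf.add_to_collection(), and creates the Saver, while the new cell 2.8 opens a Session on mnist_graph, runs the training loop, and writes the checkpoint. The sketch below shows the same build-then-run pattern in a self-contained form; the toy model, toy loss, and SKETCH_DIR are hypothetical stand-ins for the notebook's mnist_inference()/mnist_training() helpers and TRAIN_DIR, and it assumes the pre-1.0 TensorFlow API the notebook targets.

import os
import tensorflow as tf  # assumes the pre-1.0 TF API used by the notebook

SKETCH_DIR = "/tmp/mnist_sketch"  # hypothetical stand-in for TRAIN_DIR

graph = tf.Graph()
with graph.as_default():
    # Placeholders for a batch of flattened 28x28 images and their labels.
    images = tf.placeholder(tf.float32, shape=(None, 784))
    labels = tf.placeholder(tf.int32, shape=(None,))
    tf.add_to_collection("images", images)  # Remember these Ops so a later
    tf.add_to_collection("labels", labels)  # import_meta_graph() can find them.

    # Toy linear model standing in for mnist_inference().
    weights = tf.Variable(tf.zeros([784, 10]), name="weights")
    biases = tf.Variable(tf.zeros([10]), name="biases")
    logits = tf.matmul(images, weights) + biases
    tf.add_to_collection("logits", logits)  # Remember this Op too.

    # Toy objective standing in for mnist_training(); only the wiring matters here.
    loss = tf.reduce_mean(tf.square(logits - tf.one_hot(labels, 10)))
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

    init = tf.initialize_all_variables()
    saver = tf.train.Saver()  # Saves variables and writes a .meta graph file.

# Running the graph happens in a separate Session, as in the new cell 2.8.
if not os.path.isdir(SKETCH_DIR):
    os.makedirs(SKETCH_DIR)
with tf.Session(graph=graph) as sess:
    sess.run(init)
    # ... feed real batches here and run [train_op, loss] in a loop ...
    saver.save(sess, os.path.join(SKETCH_DIR, "checkpoint"), global_step=0)

Because the Saver also writes a .meta file describing the graph, anything placed in a collection here can be looked up again after tf.train.import_meta_graph(), which is what lets the build, train, and evaluation cells be separated.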
@@ -245,7 +253,7 @@
    },
    "outputs": [],
    "source": [
-    "# 2.8 Run evaluation based on the saved checkpoint.\n",
+    "# 2.9 Run evaluation based on the saved checkpoint.\n",
     "with tf.Session(graph=tf.Graph()) as sess:\n",
     "    saver = tf.train.import_meta_graph(\n",
     "        os.path.join(TRAIN_DIR, \"checkpoint-1999.meta\"))\n",
@@ -257,9 +265,14 @@
     "    true_count = 0  # Counts the number of correct predictions.\n",
     "    steps_per_epoch = data_sets.validation.num_examples\n",
     "    num_examples = steps_per_epoch * BATCH_SIZE\n",
-    "    eval_op = tf.get_collection(\"eval_op\")[0]\n",
+    "    # Retrieve the Ops we 'remembered'.\n",
+    "    logits = tf.get_collection(\"logits\")[0]\n",
     "    images_placeholder = tf.get_collection(\"images\")[0]\n",
     "    labels_placeholder = tf.get_collection(\"labels\")[0]\n",
+    "    \n",
+    "    # Add the Op to compare the logits to the labels during evaluation.\n",
+    "    eval_op = mnist_evaluation(logits, labels_placeholder)\n",
+    "\n",
     "    for step in xrange(steps_per_epoch):\n",
     "        images_feed, labels_feed = data_sets.validation.next_batch(BATCH_SIZE)\n",
     "        true_count += sess.run(eval_op,\n",

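With eval_op no longer stored in the saved graph, the evaluation cell now restores the meta graph, pulls the remembered tensors back out with tf.get_collection(), and rebuilds the comparison Op at evaluation time via mnist_evaluation(). Below is a rough sketch of that flow, continuing the assumptions of the previous sketch; the in_top_k count is only a guess at what mnist_evaluation() computes, and the checkpoint-0 paths match the global_step=0 save above.

import os
import tensorflow as tf

SKETCH_DIR = "/tmp/mnist_sketch"  # hypothetical; the notebook uses TRAIN_DIR

with tf.Session(graph=tf.Graph()) as sess:
    # Recreate the saved graph structure; import_meta_graph() returns a Saver
    # bound to the restored graph.
    saver = tf.train.import_meta_graph(
        os.path.join(SKETCH_DIR, "checkpoint-0.meta"))
    # Restore the variable values from the matching checkpoint.
    saver.restore(sess, os.path.join(SKETCH_DIR, "checkpoint-0"))

    # Retrieve the Ops that were 'remembered' with tf.add_to_collection().
    logits = tf.get_collection("logits")[0]
    images_placeholder = tf.get_collection("images")[0]
    labels_placeholder = tf.get_collection("labels")[0]

    # Rebuild the evaluation Op from the restored logits: count how many
    # predictions rank the true label first (an assumption about what
    # mnist_evaluation() does).
    correct = tf.nn.in_top_k(logits, labels_placeholder, 1)
    eval_op = tf.reduce_sum(tf.cast(correct, tf.int32))

    # ... feed validation batches and accumulate
    #     sess.run(eval_op, feed_dict={...}) into true_count, as above ...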