| 
55 | 55 |    },  | 
56 | 56 |    "outputs": [],  | 
57 | 57 |    "source": [  | 
58 |  | -    "# 2.3 Get input data: the sets of images and labels for training, validation, and\n",  | 
 | 58 | +    "# 2.3 Get input data: get the sets of images and labels for training, validation, and\n",  | 
59 | 59 |     "# test on MNIST.\n",  | 
60 | 60 |     "data_sets = read_data_sets(TRAIN_DIR, False)"  | 
61 | 61 |    ]  | 
 | 
176 | 176 |    "cell_type": "code",  | 
177 | 177 |    "execution_count": null,  | 
178 | 178 |    "metadata": {  | 
179 |  | -    "collapsed": false  | 
 | 179 | +    "collapsed": true  | 
180 | 180 |    },  | 
181 | 181 |    "outputs": [],  | 
182 | 182 |    "source": [  | 
183 |  | -    "# 2.7 Train MNIST for a number of steps.\n",  | 
184 |  | -    "# Generate placeholders for the images and labels.\n",  | 
185 |  | -    "with tf.Graph().as_default():\n",  | 
186 |  | -    "    images_placeholder = tf.placeholder(tf.float32, shape=(BATCH_SIZE,\n",  | 
187 |  | -    "                                                           IMAGE_PIXELS))\n",  | 
 | 183 | +    "# 2.7 Build the complete graph for feeding inputs, training, and saving checkpoints.\n",  | 
 | 184 | +    "mnist_graph = tf.Graph()\n",  | 
 | 185 | +    "with mnist_graph.as_default():\n",  | 
 | 186 | +    "    # Generate placeholders for the images and labels.\n",  | 
 | 187 | +    "    images_placeholder = tf.placeholder(tf.float32,\n",  | 
 | 188 | +    "                                        shape=(BATCH_SIZE, IMAGE_PIXELS))\n",  | 
188 | 189 |     "    labels_placeholder = tf.placeholder(tf.int32, shape=(BATCH_SIZE))\n",  | 
189 |  | -    "    tf.add_to_collection(\"images\", images_placeholder)\n",  | 
190 |  | -    "    tf.add_to_collection(\"labels\", labels_placeholder)\n",  | 
 | 190 | +    "    tf.add_to_collection(\"images\", images_placeholder)  # Remember this Op.\n",  | 
 | 191 | +    "    tf.add_to_collection(\"labels\", labels_placeholder)  # Remember this Op.\n",  | 
191 | 192 |     "\n",  | 
192 | 193 |     "    # Build a Graph that computes predictions from the inference model.\n",  | 
193 | 194 |     "    logits = mnist_inference(images_placeholder,\n",  | 
194 | 195 |     "                             HIDDEN1_UNITS,\n",  | 
195 | 196 |     "                             HIDDEN2_UNITS)\n",  | 
 | 197 | +    "    tf.add_to_collection(\"logits\", logits)  # Remember this Op.\n",  | 
196 | 198 |     "\n",  | 
197 | 199 |     "    # Add to the Graph the Ops that calculate and apply gradients.\n",  | 
198 | 200 |     "    train_op, loss = mnist_training(logits, labels_placeholder, 0.01)\n",  | 
199 | 201 |     "\n",  | 
200 |  | -    "    # Add the Op to compare the logits to the labels during evaluation.\n",  | 
201 |  | -    "    eval_correct = mnist_evaluation(logits, labels_placeholder)\n",  | 
202 |  | -    "    tf.add_to_collection(\"eval_op\", eval_correct)\n",  | 
203 |  | -    "\n",  | 
204 | 202 |     "    # Add the variable initializer Op.\n",  | 
205 | 203 |     "    init = tf.initialize_all_variables()\n",  | 
206 | 204 |     "\n",  | 
207 | 205 |     "    # Create a saver for writing training checkpoints.\n",  | 
208 |  | -    "    saver = tf.train.Saver()\n",  | 
209 |  | -    "\n",  | 
210 |  | -    "    with tf.Session() as sess:\n",  | 
211 |  | -    "        # Run the Op to initialize the variables.\n",  | 
212 |  | -    "        sess.run(init)\n",  | 
 | 206 | +    "    saver = tf.train.Saver()"  | 
 | 207 | +   ]  | 
 | 208 | +  },  | 
 | 209 | +  {  | 
 | 210 | +   "cell_type": "code",  | 
 | 211 | +   "execution_count": null,  | 
 | 212 | +   "metadata": {  | 
 | 213 | +    "collapsed": false  | 
 | 214 | +   },  | 
 | 215 | +   "outputs": [],  | 
 | 216 | +   "source": [  | 
 | 217 | +    "# 2.8 Run training for MAX_STEPS and save checkpoint at the end.\n",  | 
 | 218 | +    "with tf.Session(graph=mnist_graph) as sess:\n",  | 
 | 219 | +    "    # Run the Op to initialize the variables.\n",  | 
 | 220 | +    "    sess.run(init)\n",  | 
213 | 221 |     "\n",  | 
214 |  | -    "        # Start the training loop.\n",  | 
215 |  | -    "        for step in xrange(MAX_STEPS):\n",  | 
216 |  | -    "          start_time = time.time()\n",  | 
 | 222 | +    "    # Start the training loop.\n",  | 
 | 223 | +    "    for step in xrange(MAX_STEPS):\n",  | 
 | 224 | +    "      start_time = time.time()\n",  | 
217 | 225 |     "\n",  | 
218 |  | -    "          # Read a batch of images and labels.\n",  | 
219 |  | -    "          images_feed, labels_feed = data_sets.train.next_batch(BATCH_SIZE)\n",  | 
 | 226 | +    "      # Read a batch of images and labels.\n",  | 
 | 227 | +    "      images_feed, labels_feed = data_sets.train.next_batch(BATCH_SIZE)\n",  | 
220 | 228 |     "\n",  | 
221 |  | -    "          # Run one step of the model.  The return values are the activations\n",  | 
222 |  | -    "          # from the `train_op` (which is discarded) and the `loss` Op.  To\n",  | 
223 |  | -    "          # inspect the values of your Ops or variables, you may include them\n",  | 
224 |  | -    "          # in the list passed to sess.run() and the value tensors will be\n",  | 
225 |  | -    "          # returned in the tuple from the call.\n",  | 
226 |  | -    "          _, loss_value = sess.run([train_op, loss],\n",  | 
227 |  | -    "                                   feed_dict={images_placeholder: images_feed,\n",  | 
228 |  | -    "                                              labels_placeholder: labels_feed})\n",  | 
 | 229 | +    "      # Run one step of the model.  The return values are the activations\n",  | 
 | 230 | +    "      # from the `train_op` (which is discarded) and the `loss` Op.  To\n",  | 
 | 231 | +    "      # inspect the values of your Ops or variables, you may include them\n",  | 
 | 232 | +    "      # in the list passed to sess.run() and the value tensors will be\n",  | 
 | 233 | +    "      # returned in the tuple from the call.\n",  | 
 | 234 | +    "      _, loss_value = sess.run([train_op, loss],\n",  | 
 | 235 | +    "                               feed_dict={images_placeholder: images_feed,\n",  | 
 | 236 | +    "                                          labels_placeholder: labels_feed})\n",  | 
229 | 237 |     "\n",  | 
230 |  | -    "          # Print out loss value.\n",  | 
231 |  | -    "          if step % 1000 == 0:\n",  | 
232 |  | -    "            print('Step %d: loss = %.2f (%.3f sec)' %\n",  | 
233 |  | -    "                  (step, loss_value, time.time() - start_time))\n",  | 
 | 238 | +    "      # Print out loss value.\n",  | 
 | 239 | +    "      if step % 1000 == 0:\n",  | 
 | 240 | +    "        print('Step %d: loss = %.2f (%.3f sec)' %\n",  | 
 | 241 | +    "              (step, loss_value, time.time() - start_time))\n",  | 
234 | 242 |     "\n",  | 
235 |  | -    "        # Write a checkpoint.\n",  | 
236 |  | -    "        checkpoint_file = os.path.join(TRAIN_DIR, 'checkpoint')\n",  | 
237 |  | -    "        saver.save(sess, checkpoint_file, global_step=step)"  | 
 | 243 | +    "    # Write a checkpoint.\n",  | 
 | 244 | +    "    checkpoint_file = os.path.join(TRAIN_DIR, 'checkpoint')\n",  | 
 | 245 | +    "    saver.save(sess, checkpoint_file, global_step=step)"  | 
238 | 246 |    ]  | 
239 | 247 |   },  | 
240 | 248 |   {  | 
 | 
245 | 253 |    },  | 
246 | 254 |    "outputs": [],  | 
247 | 255 |    "source": [  | 
248 |  | -    "# 2.8 Run evaluation based on the saved checkpoint.\n",  | 
 | 256 | +    "# 2.9 Run evaluation based on the saved checkpoint.\n",  | 
249 | 257 |     "with tf.Session(graph=tf.Graph()) as sess:\n",  | 
250 | 258 |     "    saver = tf.train.import_meta_graph(\n",  | 
251 | 259 |     "        os.path.join(TRAIN_DIR, \"checkpoint-1999.meta\"))\n",  | 
 | 
257 | 265 |     "    true_count = 0  # Counts the number of correct predictions.\n",  | 
258 | 266 |     "    steps_per_epoch = data_sets.validation.num_examples\n",  | 
259 | 267 |     "    num_examples = steps_per_epoch * BATCH_SIZE\n",  | 
260 |  | -    "    eval_op = tf.get_collection(\"eval_op\")[0]\n",  | 
 | 268 | +    "    # Retrieve the Ops we 'remembered'.\n",  | 
 | 269 | +    "    logits = tf.get_collection(\"logits\")[0]\n",  | 
261 | 270 |     "    images_placeholder = tf.get_collection(\"images\")[0]\n",  | 
262 | 271 |     "    labels_placeholder = tf.get_collection(\"labels\")[0]\n",  | 
 | 272 | +    "  \n",  | 
 | 273 | +    "    # Add the Op to compare the logits to the labels during evaluation.\n",  | 
 | 274 | +    "    eval_op = mnist_evaluation(logits, labels_placeholder)\n",  | 
 | 275 | +    "\n",  | 
263 | 276 |     "    for step in xrange(steps_per_epoch):\n",  | 
264 | 277 |     "        images_feed, labels_feed = data_sets.validation.next_batch(BATCH_SIZE)\n",  | 
265 | 278 |     "        true_count += sess.run(eval_op,\n",  | 
 | 
0 commit comments