|
11 | 11 | "%matplotlib inline\n", |
12 | 12 | "\n", |
13 | 13 | "import gym\n", |
| 14 | + "from gym.wrappers import Monitor\n", |
14 | 15 | "import itertools\n", |
15 | 16 | "import numpy as np\n", |
16 | 17 | "import os\n", |
|
67 | 68 | " self.output = tf.image.rgb_to_grayscale(self.input_state)\n", |
68 | 69 | " self.output = tf.image.crop_to_bounding_box(self.output, 34, 0, 160, 160)\n", |
69 | 70 | " self.output = tf.image.resize_images(\n", |
70 | | - " self.output, 84, 84, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", |
| 71 | + " self.output, [84, 84], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", |
71 | 72 | " self.output = tf.squeeze(self.output)\n", |
72 | 73 | "\n", |
73 | 74 | " def process(self, sess, state):\n", |
|
107 | 108 | " summary_dir = os.path.join(summaries_dir, \"summaries_{}\".format(scope))\n", |
108 | 109 | " if not os.path.exists(summary_dir):\n", |
109 | 110 | " os.makedirs(summary_dir)\n", |
110 | | - " self.summary_writer = tf.train.SummaryWriter(summary_dir)\n", |
| 111 | + " self.summary_writer = tf.summary.FileWriter(summary_dir)\n", |
111 | 112 | "\n", |
112 | 113 | " def _build_model(self):\n", |
113 | 114 | " \"\"\"\n", |
|
151 | 152 | " self.train_op = self.optimizer.minimize(self.loss, global_step=tf.contrib.framework.get_global_step())\n", |
152 | 153 | "\n", |
153 | 154 | " # Summaries for Tensorboard\n", |
154 | | - " self.summaries = tf.merge_summary([\n", |
155 | | - " tf.scalar_summary(\"loss\", self.loss),\n", |
156 | | - " tf.histogram_summary(\"loss_hist\", self.losses),\n", |
157 | | - " tf.histogram_summary(\"q_values_hist\", self.predictions),\n", |
158 | | - " tf.scalar_summary(\"max_q_value\", tf.reduce_max(self.predictions))\n", |
| 155 | + " self.summaries = tf.summary.merge([\n", |
| 156 | + " tf.summary.scalar(\"loss\", self.loss),\n", |
| 157 | + " tf.summary.histogram(\"loss_hist\", self.losses),\n", |
| 158 | + " tf.summary.histogram(\"q_values_hist\", self.predictions),\n", |
| 159 | + " tf.summary.scalar(\"max_q_value\", tf.reduce_max(self.predictions))\n", |
159 | 160 | " ])\n", |
160 | 161 | "\n", |
161 | | - "\n", |
162 | 162 | " def predict(self, sess, s):\n", |
163 | 163 | " \"\"\"\n", |
164 | 164 | " Predicts action values.\n", |
|
212 | 212 | "sp = StateProcessor()\n", |
213 | 213 | "\n", |
214 | 214 | "with tf.Session() as sess:\n", |
215 | | - " sess.run(tf.initialize_all_variables())\n", |
| 215 | + " sess.run(tf.global_variables_initializer())\n", |
216 | 216 | " \n", |
217 | 217 | " # Example observation batch\n", |
218 | 218 | " observation = env.reset()\n", |
|
357 | 357 | " checkpoint_dir = os.path.join(experiment_dir, \"checkpoints\")\n", |
358 | 358 | " checkpoint_path = os.path.join(checkpoint_dir, \"model\")\n", |
359 | 359 | " monitor_path = os.path.join(experiment_dir, \"monitor\")\n", |
360 | | - "\n", |
| 360 | + " \n", |
361 | 361 | " if not os.path.exists(checkpoint_dir):\n", |
362 | 362 | " os.makedirs(checkpoint_dir)\n", |
363 | 363 | " if not os.path.exists(monitor_path):\n", |
|
400 | 400 | " else:\n", |
401 | 401 | " state = next_state\n", |
402 | 402 | "\n", |
| 403 | + "\n", |
403 | 404 | " # Record videos\n", |
404 | | - " env.monitor.start(monitor_path,\n", |
405 | | - " resume=True,\n", |
406 | | - " video_callable=lambda count: count % record_video_every == 0)\n", |
| 405 | + " # Add env Monitor wrapper\n", |
| 406 | + " env = Monitor(env, directory=monitor_path, video_callable=lambda count: count % record_video_every == 0, resume=True)\n", |
407 | 407 | "\n", |
408 | 408 | " for i_episode in range(num_episodes):\n", |
409 | 409 | "\n", |
|
484 | 484 | " episode_lengths=stats.episode_lengths[:i_episode+1],\n", |
485 | 485 | " episode_rewards=stats.episode_rewards[:i_episode+1])\n", |
486 | 486 | "\n", |
487 | | - " env.monitor.close()\n", |
488 | 487 | " return stats" |
489 | 488 | ] |
490 | 489 | }, |
|
513 | 512 | "\n", |
514 | 513 | "# Run it!\n", |
515 | 514 | "with tf.Session() as sess:\n", |
516 | | - " sess.run(tf.initialize_all_variables())\n", |
| 515 | + " sess.run(tf.global_variables_initializer())\n", |
517 | 516 | " for t, stats in deep_q_learning(sess,\n", |
518 | 517 | " env,\n", |
519 | 518 | " q_estimator=q_estimator,\n", |
|
550 | 549 | "name": "python", |
551 | 550 | "nbconvert_exporter": "python", |
552 | 551 | "pygments_lexer": "ipython3", |
553 | | - "version": "3.5.1" |
| 552 | + "version": "3.4.3" |
554 | 553 | } |
555 | 554 | }, |
556 | 555 | "nbformat": 4, |
|
0 commit comments