@@ -68,17 +68,28 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torch.autograd as autograd
 import torch.nn.functional as F
+from torch.autograd import Variable
 import torchvision.transforms as T

+
 env = gym.make('CartPole-v0').unwrapped

+# set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
 if is_ipython:
     from IPython import display

 plt.ion()
+
+# if gpu is to be used
+use_cuda = torch.cuda.is_available()
+FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
+LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
+ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
+Tensor = FloatTensor
+
+
 ######################################################################
 # Replay Memory
 # -------------
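
The type aliases added above are the commit's device-dispatch mechanism: each name resolves to the CUDA tensor class when a GPU is present, so the rest of the script never has to branch on `use_cuda` to build a tensor. A minimal sketch of the pattern, assuming the pre-0.4 PyTorch API where CPU and CUDA tensors are distinct classes:

    import torch

    use_cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

    x = FloatTensor([[0.5, 1.5]])          # lives on the GPU when one is available
    y = torch.zeros(2).type(FloatTensor)   # .type() converts an existing CPU tensor
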
@@ -260,12 +271,12 @@ def get_screen():
     screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
     screen = torch.from_numpy(screen)
     # Resize, and add a batch dimension (BCHW)
-    return resize(screen).unsqueeze(0)
+    return resize(screen).unsqueeze(0).type(Tensor)

 env.reset()
 plt.figure()
-plt.imshow(get_screen().squeeze(0).permute(
-    1, 2, 0).numpy(), interpolation='none')
+plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
+           interpolation='none')
 plt.title('Example extracted screen')
 plt.show()

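`get_screen()` now hands back frames already converted to the training device via `.type(Tensor)`, and the plotting call gains a `.cpu()` because the numpy bridge only works on CPU tensors. A small sketch of that round trip, assuming the `Tensor` alias defined earlier:

    frame = Tensor([[0.1, 0.9]])   # possibly a torch.cuda.FloatTensor
    arr = frame.cpu().numpy()      # .cpu() is a no-op on CPU, a copy from CUDA
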
@@ -300,22 +311,14 @@ def get_screen():
 EPS_START = 0.9
 EPS_END = 0.05
 EPS_DECAY = 200
-USE_CUDA = torch.cuda.is_available()

 model = DQN()
-memory = ReplayMemory(10000)
-optimizer = optim.RMSprop(model.parameters())

-if USE_CUDA:
+if use_cuda:
     model.cuda()

-
-class Variable(autograd.Variable):
-
-    def __init__(self, data, *args, **kwargs):
-        if USE_CUDA:
-            data = data.cuda()
-        super(Variable, self).__init__(data, *args, **kwargs)
+optimizer = optim.RMSprop(model.parameters())
+memory = ReplayMemory(10000)


 steps_done = 0
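
Note that the optimizer is now constructed only after `model.cuda()`, in line with the torch.optim guideline that a model should be moved to its final device before its parameters are handed to an optimizer. A sketch of the general pattern, with a toy `nn.Linear` standing in for `DQN`:

    import torch.nn as nn
    import torch.optim as optim

    net = nn.Linear(4, 2)
    if use_cuda:
        net.cuda()                         # move parameters to the GPU first
    opt = optim.RMSprop(net.parameters())  # then build the optimizer state
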
@@ -328,9 +331,10 @@ def select_action(state):
         math.exp(-1. * steps_done / EPS_DECAY)
     steps_done += 1
     if sample > eps_threshold:
-        return model(Variable(state, volatile=True)).data.max(1)[1].cpu()
+        return model(
+            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1]
     else:
-        return torch.LongTensor([[random.randrange(2)]])
+        return LongTensor([[random.randrange(2)]])


 episode_durations = []
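
In this pre-0.4 API, `volatile=True` marks the forward pass as inference-only so autograd keeps no buffers, and `.type(FloatTensor)` puts the state on the model's device; the action index from `.data.max(1)[1]` then stays on that device rather than being pulled back with `.cpu()`. A standalone epsilon-greedy sketch of the same idea (`qnet` and `n_actions` are assumed names):

    import random
    from torch.autograd import Variable

    def epsilon_greedy(qnet, state, eps, n_actions):
        if random.random() > eps:
            q = qnet(Variable(state, volatile=True).type(FloatTensor))
            return q.data.max(1)[1]                  # index of the best action
        return LongTensor([[random.randrange(n_actions)]])
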
@@ -339,7 +343,7 @@ def select_action(state):
 def plot_durations():
     plt.figure(2)
     plt.clf()
-    durations_t = torch.Tensor(episode_durations)
+    durations_t = torch.FloatTensor(episode_durations)
     plt.title('Training...')
     plt.xlabel('Episode')
     plt.ylabel('Duration')
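
The explicit `torch.FloatTensor` here is deliberate: unlike the `Tensor` alias, it is always a CPU tensor, which is what the `.numpy()` plotting calls in this function require:

    durations_t = torch.FloatTensor([10, 22, 35])  # CPU regardless of use_cuda
    plt.plot(durations_t.numpy())                  # .numpy() needs a CPU tensor
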
@@ -349,6 +353,8 @@ def plot_durations():
         means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
         means = torch.cat((torch.zeros(99), means))
         plt.plot(means.numpy())
+
+    plt.pause(0.001)  # pause a bit so that plots are updated
     if is_ipython:
         display.clear_output(wait=True)
         display.display(plt.gcf())
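
`plt.pause(0.001)` briefly runs the GUI event loop, which is what lets an interactive (`plt.ion()`) figure actually redraw between episodes when the script runs outside a notebook:

    import matplotlib.pyplot as plt

    plt.ion()
    for step in range(10):
        plt.plot([step], [step ** 2], 'b.')
        plt.pause(0.001)   # yield to the GUI so each point appears immediately
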
@@ -370,6 +376,7 @@ def plot_durations():

 last_sync = 0

+
 def optimize_model():
     global last_sync
     if len(memory) < BATCH_SIZE:
@@ -380,10 +387,9 @@ def optimize_model():
     batch = Transition(*zip(*transitions))

     # Compute a mask of non-final states and concatenate the batch elements
-    non_final_mask = torch.ByteTensor(
-        tuple(map(lambda s: s is not None, batch.next_state)))
-    if USE_CUDA:
-        non_final_mask = non_final_mask.cuda()
+    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
+                                          batch.next_state)))
+
     # We don't want to backprop through the expected action values and volatile
     # will save us on temporarily changing the model parameters'
     # requires_grad to False!
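
With the `ByteTensor` alias the mask is created on the right device directly, so the explicit `.cuda()` branch disappears. In this API a ByteTensor doubles as a boolean mask for indexing and masked assignment, which is how it is used below; a sketch with the aliases from above:

    vals = Tensor([10.0, 20.0, 30.0])
    mask = ByteTensor([1, 0, 1])
    picked = vals[mask]              # selects entries where the mask is 1
    vals[mask] = Tensor([0.0, 0.0])  # masked assignment, as in optimize_model()
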
@@ -399,7 +405,7 @@ def optimize_model():
     state_action_values = model(state_batch).gather(1, action_batch)

     # Compute V(s_{t+1}) for all next states.
-    next_state_values = Variable(torch.zeros(BATCH_SIZE))
+    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
     next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
     # Now, we don't want to mess up the loss with a volatile flag, so let's
     # clear it. After this, we'll just end up with a Variable that has
@@ -440,7 +446,7 @@ def optimize_model():
         # Select and perform an action
         action = select_action(state)
         _, reward, done, _ = env.step(action[0, 0])
-        reward = torch.Tensor([reward])
+        reward = Tensor([reward])

         # Observe new state
         last_screen = current_screen
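
The last two hunks fix the same class of bug: bare constructors such as `torch.zeros(...)` and `torch.Tensor(...)` always allocate on the CPU, so their results must be converted before being mixed with model outputs that may live on the GPU. A sketch, assuming the `Tensor` alias and the `Variable` import from above:

    next_vals = Variable(torch.zeros(4).type(Tensor))  # device-matched zeros
    r = Tensor([1.0])                                  # device-matched reward
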
@@ -463,6 +469,8 @@ def optimize_model():
             plot_durations()
             break

+print('Complete')
+env.render(close=True)
 env.close()
 plt.ioff()
 plt.show()