#######################################################################
# Copyright (C)                                                       #
# 2016 - 2019 Pinard Liu([email protected])                       #
# https://www.cnblogs.com/pinard                                      #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################
## reference from MorvanZhou's A3C code on Github, minor update: ##
## https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/10_A3C/A3C_discrete_action.py ##

## https://www.cnblogs.com/pinard/p/10334127.html ##
## Reinforcement Learning (Part 15): A3C ##

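# A3C on CartPole-v0: one global actor-critic network holds the shared
# parameters, and N_WORKERS threads each run a local copy of the network in
# their own environment. Workers compute gradients locally, apply them to the
# global network ("push"), then copy the global parameters back ("pull").
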
import threading
import tensorflow as tf
import numpy as np
import gym
import os
import shutil
import matplotlib.pyplot as plt


GAME = 'CartPole-v0'
OUTPUT_GRAPH = True
LOG_DIR = './log'
N_WORKERS = 3
MAX_GLOBAL_EP = 3000
GLOBAL_NET_SCOPE = 'Global_Net'
UPDATE_GLOBAL_ITER = 100
GAMMA = 0.9
ENTROPY_BETA = 0.001
LR_A = 0.001    # learning rate for actor
LR_C = 0.001    # learning rate for critic
GLOBAL_RUNNING_R = []
GLOBAL_EP = 0
STEP = 3000     # step limit per evaluation episode
TEST = 10       # number of evaluation episodes run after training

env = gym.make(GAME)
N_S = env.observation_space.shape[0]
N_A = env.action_space.n


class ACNet(object):
    def __init__(self, scope, globalAC=None):

        if scope == GLOBAL_NET_SCOPE:   # get global network
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_params, self.c_params = self._build_net(scope)[-2:]
        else:   # local net, calculate losses
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_his = tf.placeholder(tf.int32, [None, ], 'A')
                self.v_target = tf.placeholder(tf.float32, [None, 1], 'Vtarget')

                self.a_prob, self.v, self.a_params, self.c_params = self._build_net(scope)

                td = tf.subtract(self.v_target, self.v, name='TD_error')
                with tf.name_scope('c_loss'):
                    self.c_loss = tf.reduce_mean(tf.square(td))

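                # Actor loss: policy-gradient term log pi(a|s) * TD error (the TD
                # error serves as the advantage estimate; stop_gradient keeps the
                # critic out of the actor's gradient), plus an entropy bonus scaled
                # by ENTROPY_BETA to encourage exploration.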
                with tf.name_scope('a_loss'):
                    log_prob = tf.reduce_sum(tf.log(self.a_prob + 1e-5) * tf.one_hot(self.a_his, N_A, dtype=tf.float32), axis=1, keep_dims=True)
                    exp_v = log_prob * tf.stop_gradient(td)
                    entropy = -tf.reduce_sum(self.a_prob * tf.log(self.a_prob + 1e-5),
                                             axis=1, keep_dims=True)  # encourage exploration
                    self.exp_v = ENTROPY_BETA * entropy + exp_v
                    self.a_loss = tf.reduce_mean(-self.exp_v)

                with tf.name_scope('local_grad'):
                    self.a_grads = tf.gradients(self.a_loss, self.a_params)
                    self.c_grads = tf.gradients(self.c_loss, self.c_params)

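            # Sync ops: "pull" copies the global network's parameters into this
            # local network; "push" applies this worker's locally computed
            # gradients to the global parameters via the shared RMSProp optimizers.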
            with tf.name_scope('sync'):
                with tf.name_scope('pull'):
                    self.pull_a_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.a_params, globalAC.a_params)]
                    self.pull_c_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.c_params, globalAC.c_params)]
                with tf.name_scope('push'):
                    self.update_a_op = OPT_A.apply_gradients(zip(self.a_grads, globalAC.a_params))
                    self.update_c_op = OPT_C.apply_gradients(zip(self.c_grads, globalAC.c_params))

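    # Actor and critic are two separate one-hidden-layer networks: the actor maps
    # the state through 200 relu6 units to a softmax over the N_A actions, while
    # the critic maps it through 100 relu6 units to a scalar state value.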
    def _build_net(self, scope):
        w_init = tf.random_normal_initializer(0., .1)
        with tf.variable_scope('actor'):
            l_a = tf.layers.dense(self.s, 200, tf.nn.relu6, kernel_initializer=w_init, name='la')
            a_prob = tf.layers.dense(l_a, N_A, tf.nn.softmax, kernel_initializer=w_init, name='ap')
        with tf.variable_scope('critic'):
            l_c = tf.layers.dense(self.s, 100, tf.nn.relu6, kernel_initializer=w_init, name='lc')
            v = tf.layers.dense(l_c, 1, kernel_initializer=w_init, name='v')  # state value
        a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/actor')
        c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/critic')
        return a_prob, v, a_params, c_params

    def update_global(self, feed_dict):  # run by a local
        SESS.run([self.update_a_op, self.update_c_op], feed_dict)  # local grads applied to the global net

    def pull_global(self):  # run by a local
        SESS.run([self.pull_a_params_op, self.pull_c_params_op])

    def choose_action(self, s):  # run by a local
        prob_weights = SESS.run(self.a_prob, feed_dict={self.s: s[np.newaxis, :]})
        action = np.random.choice(range(prob_weights.shape[1]),
                                  p=prob_weights.ravel())  # sample an action from the policy probabilities
        return action


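# Each worker owns its own (unwrapped) copy of the environment and a local
# ACNet tied to the shared global network.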
class Worker(object):
    def __init__(self, name, globalAC):
        self.env = gym.make(GAME).unwrapped
        self.name = name
        self.AC = ACNet(name, globalAC)

    def work(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            while True:
                # if self.name == 'W_0':
                #     self.env.render()
                a = self.AC.choose_action(s)
                s_, r, done, info = self.env.step(a)
                if done: r = -5
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)

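                # Every UPDATE_GLOBAL_ITER steps, or at episode end, turn the
                # reward buffer into bootstrapped n-step targets: start from
                # V(s_) (0 if terminal) and walk the rewards in reverse with
                # v_s_ = r + GAMMA * v_s_. Then push the local gradients to the
                # global net and pull fresh parameters back.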
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:   # update global and assign to local net
                    if done:
                        v_s_ = 0   # terminal
                    else:
                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
                    buffer_v_target = []
                    for r in buffer_r[::-1]:    # reverse buffer r
                        v_s_ = r + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.array(buffer_a), np.vstack(buffer_v_target)
                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }
                    self.AC.update_global(feed_dict)

                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()

                s = s_
                total_step += 1
                if done:
                    if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
                        GLOBAL_RUNNING_R.append(ep_r)
                    else:
                        GLOBAL_RUNNING_R.append(0.99 * GLOBAL_RUNNING_R[-1] + 0.01 * ep_r)
                    print(
                        self.name,
                        "Ep:", GLOBAL_EP,
                        "| Ep_r: %i" % GLOBAL_RUNNING_R[-1],
                    )
                    GLOBAL_EP += 1
                    break

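# Main: build the global network and the workers on the CPU, launch one
# training thread per worker, wait for them to finish, then evaluate the
# trained global policy and plot the moving episode reward.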
if __name__ == "__main__":
    SESS = tf.Session()

    with tf.device("/cpu:0"):
        OPT_A = tf.train.RMSPropOptimizer(LR_A, name='RMSPropA')
        OPT_C = tf.train.RMSPropOptimizer(LR_C, name='RMSPropC')
        GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # we only need its params
        workers = []
        # Create workers
        for i in range(N_WORKERS):
            i_name = 'W_%i' % i   # worker name
            workers.append(Worker(i_name, GLOBAL_AC))

    COORD = tf.train.Coordinator()
    SESS.run(tf.global_variables_initializer())

    if OUTPUT_GRAPH:
        if os.path.exists(LOG_DIR):
            shutil.rmtree(LOG_DIR)
        tf.summary.FileWriter(LOG_DIR, SESS.graph)

    worker_threads = []
    for worker in workers:
        job = lambda w=worker: w.work()  # bind the current worker (avoid late binding in the loop)
        t = threading.Thread(target=job)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

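    # Evaluation: create a fresh worker, pull the trained parameters from the
    # global network into it, and run TEST episodes with actions sampled from
    # the learned policy.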
    testWorker = Worker("test", GLOBAL_AC)
    testWorker.AC.pull_global()

    total_reward = 0
    for i in range(TEST):
        state = env.reset()
        for j in range(STEP):
            env.render()
            action = testWorker.AC.choose_action(state)  # direct action for test
            state, reward, done, _ = env.step(action)
            total_reward += reward
            if done:
                break
    ave_reward = total_reward / TEST
    print('episode: ', GLOBAL_EP, 'Evaluation Average Reward:', ave_reward)

    plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
    plt.xlabel('episode')
    plt.ylabel('Total moving reward')
    plt.show()