Commit 1cc1b43

Author: pinard.liu (committed)
Commit message: add a3c code
1 parent 15764a3 commit 1cc1b43

File tree

1 file changed: +217 −0 lines changed


reinforcement-learning/a3c.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
#######################################################################
# Copyright (C)                                                       #
# 2016 - 2019 Pinard Liu([email protected])                            #
# https://www.cnblogs.com/pinard                                      #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################
## Adapted from MorvanZhou's A3C code on GitHub, with minor updates: ##
## https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/10_A3C/A3C_discrete_action.py ##

## https://www.cnblogs.com/pinard/p/10334127.html ##
## Reinforcement Learning (15): A3C ##

import threading
import tensorflow as tf
import numpy as np
import gym
import os
import shutil
import matplotlib.pyplot as plt


GAME = 'CartPole-v0'
OUTPUT_GRAPH = True
LOG_DIR = './log'
N_WORKERS = 3
MAX_GLOBAL_EP = 3000
GLOBAL_NET_SCOPE = 'Global_Net'
UPDATE_GLOBAL_ITER = 100
GAMMA = 0.9
ENTROPY_BETA = 0.001
LR_A = 0.001    # learning rate for actor
LR_C = 0.001    # learning rate for critic
GLOBAL_RUNNING_R = []
GLOBAL_EP = 0
STEP = 3000     # step limit per test episode
TEST = 10       # number of test episodes for the final evaluation

env = gym.make(GAME)
N_S = env.observation_space.shape[0]
N_A = env.action_space.n

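# A3C structure: the global ACNet (scope 'Global_Net') only holds the shared
# parameters. Each worker builds its own local ACNet, computes gradients from
# its local losses, pushes them to the global net ('push') and then copies the
# updated global parameters back into its local net ('pull').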
class ACNet(object):
    def __init__(self, scope, globalAC=None):

        if scope == GLOBAL_NET_SCOPE:   # get global network
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_params, self.c_params = self._build_net(scope)[-2:]
        else:   # local net, calculate losses
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_his = tf.placeholder(tf.int32, [None, ], 'A')
                self.v_target = tf.placeholder(tf.float32, [None, 1], 'Vtarget')

                self.a_prob, self.v, self.a_params, self.c_params = self._build_net(scope)

                td = tf.subtract(self.v_target, self.v, name='TD_error')
                with tf.name_scope('c_loss'):
                    self.c_loss = tf.reduce_mean(tf.square(td))

                with tf.name_scope('a_loss'):
                    log_prob = tf.reduce_sum(tf.log(self.a_prob + 1e-5) * tf.one_hot(self.a_his, N_A, dtype=tf.float32), axis=1, keep_dims=True)
                    exp_v = log_prob * tf.stop_gradient(td)
                    entropy = -tf.reduce_sum(self.a_prob * tf.log(self.a_prob + 1e-5),
                                             axis=1, keep_dims=True)  # encourage exploration
                    self.exp_v = ENTROPY_BETA * entropy + exp_v
                    self.a_loss = tf.reduce_mean(-self.exp_v)

                with tf.name_scope('local_grad'):
                    self.a_grads = tf.gradients(self.a_loss, self.a_params)
                    self.c_grads = tf.gradients(self.c_loss, self.c_params)

            with tf.name_scope('sync'):
                with tf.name_scope('pull'):
                    self.pull_a_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.a_params, globalAC.a_params)]
                    self.pull_c_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.c_params, globalAC.c_params)]
                with tf.name_scope('push'):
                    self.update_a_op = OPT_A.apply_gradients(zip(self.a_grads, globalAC.a_params))
                    self.update_c_op = OPT_C.apply_gradients(zip(self.c_grads, globalAC.c_params))

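    # The network has two heads on top of the state input: the actor head
    # outputs a softmax over the N_A discrete actions, the critic head outputs
    # a scalar state value. Their trainable variables are collected separately
    # so the actor and critic can be updated by different optimizers.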
    def _build_net(self, scope):
        w_init = tf.random_normal_initializer(0., .1)
        with tf.variable_scope('actor'):
            l_a = tf.layers.dense(self.s, 200, tf.nn.relu6, kernel_initializer=w_init, name='la')
            a_prob = tf.layers.dense(l_a, N_A, tf.nn.softmax, kernel_initializer=w_init, name='ap')
        with tf.variable_scope('critic'):
            l_c = tf.layers.dense(self.s, 100, tf.nn.relu6, kernel_initializer=w_init, name='lc')
            v = tf.layers.dense(l_c, 1, kernel_initializer=w_init, name='v')  # state value
        a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/actor')
        c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/critic')
        return a_prob, v, a_params, c_params

    def update_global(self, feed_dict):  # run by a local net
        SESS.run([self.update_a_op, self.update_c_op], feed_dict)  # local grads applied to the global net

    def pull_global(self):  # run by a local net
        SESS.run([self.pull_a_params_op, self.pull_c_params_op])

    def choose_action(self, s):  # run by a local net
        prob_weights = SESS.run(self.a_prob, feed_dict={self.s: s[np.newaxis, :]})
        action = np.random.choice(range(prob_weights.shape[1]),
                                  p=prob_weights.ravel())  # sample action w.r.t. the action probabilities
        return action

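# Each Worker owns a private copy of the environment and a local ACNet and
# runs in its own thread. Every UPDATE_GLOBAL_ITER steps (or at episode end)
# it turns the collected rewards into bootstrapped discounted targets
# (v_target = r + GAMMA * v_target, walking the buffer backwards), pushes the
# resulting gradients to the global net and pulls back the updated parameters.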
class Worker(object):
    def __init__(self, name, globalAC):
        self.env = gym.make(GAME).unwrapped
        self.name = name
        self.AC = ACNet(name, globalAC)

    def work(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            while True:
                # if self.name == 'W_0':
                #     self.env.render()
                a = self.AC.choose_action(s)
                s_, r, done, info = self.env.step(a)
                if done: r = -5    # penalize the final transition of an episode
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:   # update global and assign to local net
                    if done:
                        v_s_ = 0   # terminal state has zero value
                    else:
                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
                    buffer_v_target = []
                    for r in buffer_r[::-1]:    # reverse buffer r
                        v_s_ = r + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.array(buffer_a), np.vstack(buffer_v_target)
                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }
                    self.AC.update_global(feed_dict)

                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()

                s = s_
                total_step += 1
                if done:
                    if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
                        GLOBAL_RUNNING_R.append(ep_r)
                    else:
                        GLOBAL_RUNNING_R.append(0.99 * GLOBAL_RUNNING_R[-1] + 0.01 * ep_r)
                    print(
                        self.name,
                        "Ep:", GLOBAL_EP,
                        "| Ep_r: %i" % GLOBAL_RUNNING_R[-1],
                    )
                    GLOBAL_EP += 1
                    break

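# Main entry point: build the global net and the shared RMSProp optimizers on
# the CPU, launch N_WORKERS training threads and wait for them through a
# tf.train.Coordinator, then run TEST evaluation episodes with the trained
# global policy and plot the moving average of the episode reward.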
if __name__ == "__main__":
    SESS = tf.Session()

    with tf.device("/cpu:0"):
        OPT_A = tf.train.RMSPropOptimizer(LR_A, name='RMSPropA')
        OPT_C = tf.train.RMSPropOptimizer(LR_C, name='RMSPropC')
        GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # we only need its params
        workers = []
        # Create workers
        for i in range(N_WORKERS):
            i_name = 'W_%i' % i   # worker name
            workers.append(Worker(i_name, GLOBAL_AC))

    COORD = tf.train.Coordinator()
    SESS.run(tf.global_variables_initializer())

    if OUTPUT_GRAPH:
        if os.path.exists(LOG_DIR):
            shutil.rmtree(LOG_DIR)
        tf.summary.FileWriter(LOG_DIR, SESS.graph)

    worker_threads = []
    for worker in workers:
        # bind the thread directly to its worker; a `lambda: worker.work()`
        # closure would late-bind the loop variable and could start several
        # threads on the same worker
        t = threading.Thread(target=worker.work)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

    testWorker = Worker("test", GLOBAL_AC)
    testWorker.AC.pull_global()

    total_reward = 0
    for i in range(TEST):
        state = env.reset()
        for j in range(STEP):
            env.render()
            action = testWorker.AC.choose_action(state)  # sample action from the trained policy
            state, reward, done, _ = env.step(action)
            total_reward += reward
            if done:
                break
    ave_reward = total_reward / TEST
    print('episode: ', GLOBAL_EP, 'Evaluation Average Reward:', ave_reward)

    plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
    plt.xlabel('step')
    plt.ylabel('Total moving reward')
    plt.show()
