Skip to content

Commit 9b4cf97

Browse files
committed
RL Step by step Q-Learning sample
1 parent a14202f commit 9b4cf97

File tree

1 file changed

+281
-0
lines changed

1 file changed

+281
-0
lines changed
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Q-Learning Step By Step Example\n",
8+
"\n",
9+
"A simple example of Q learning in a step by step fashion using a simple 2x2 gridworld type problem\n",
10+
"\n",
11+
"State 0 | State 1\n",
12+
"--------|--------\n",
13+
"State 2 | State 3\n",
14+
"\n",
15+
"State 0 = Start<br />\n",
16+
"State 1 = Safe<br />\n",
17+
"State 2 = Hole<br />\n",
18+
"State 3 = Goal<br />\n",
19+
"\n",
20+
"For each state we can move up, down, left, right or stay put - not excluding invalid moves at edges.\n",
21+
"\n",
22+
"Each hole gives a reward of -10, reaching the goal gives +10, all other states give a reward of -1.\n",
23+
"\n",
24+
"So the optimal path is 0-1-3."
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": 12,
30+
"metadata": {},
31+
"outputs": [
32+
{
33+
"name": "stdout",
34+
"output_type": "stream",
35+
"text": [
36+
"[[ 0. -9.2 0. 0. 0. ]\n",
37+
" [ 0. 0. 0. 0. 0. ]\n",
38+
" [ 0. 0. 0. 0. 0. ]\n",
39+
" [ 0. 0. 0. 0. 0. ]]\n",
40+
"[[ 0. -9.2 0. 0. 0. ]\n",
41+
" [ 0. 0. 0. 0. 0. ]\n",
42+
" [ 0. 0. 0. 0. -0.2]\n",
43+
" [ 0. 0. 0. 0. 0. ]]\n",
44+
"goal state reached\n",
45+
"[[ 0. -9.2 0. -0.2 0. ]\n",
46+
" [ 0. 0. 0. 0. 0. ]\n",
47+
" [ 0. 0. 0. 0. -0.2]\n",
48+
" [ 0. 0. 0. 0. 0. ]]\n",
49+
"[[ 0. -9.2 0. -0.2 0. ]\n",
50+
" [ 0. 0. 0. 0. -0.2]\n",
51+
" [ 0. 0. 0. 0. -0.2]\n",
52+
" [ 0. 0. 0. 0. 0. ]]\n",
53+
"[[ 0. -9.2 0. -0.2 0. ]\n",
54+
" [ 0. 0. 0. 0. -0.2]\n",
55+
" [ 0. 0. 0. 10.8 -0.2]\n",
56+
" [ 0. 0. 0. 0. 0. ]]\n",
57+
"goal state reached\n",
58+
"[[ 0. -9.2 0. -0.2 0. ]\n",
59+
" [ 0. 0. 0. 0. -0.2]\n",
60+
" [ 0. 0. 0. 10.8 -0.2]\n",
61+
" [ 0. 0. 0. 0. 0. ]]\n",
62+
"[[ 0. -9.2 0. -0.2 0. ]\n",
63+
" [ 0. 10.8 0. 0. -0.2]\n",
64+
" [ 0. 0. 0. 10.8 -0.2]\n",
65+
" [ 0. 0. 0. 0. 0. ]]\n",
66+
"goal state reached\n",
67+
"[[ 0. -9.2 0. 10.6 0. ]\n",
68+
" [ 0. 10.8 0. 0. -0.2]\n",
69+
" [ 0. 0. 0. 10.8 -0.2]\n",
70+
" [ 0. 0. 0. 0. 0. ]]\n",
71+
"[[ 0. -9.2 0. 10.6 0. ]\n",
72+
" [ 0. 10.8 0. 0. -0.2]\n",
73+
" [ 0. 0. 0. 10.8 -0.2]\n",
74+
" [ 0. 0. 0. 0. 0. ]]\n",
75+
"goal state reached\n",
76+
"[[ 0. -9.2 0. 10.6 10.6]\n",
77+
" [ 0. 10.8 0. 0. -0.2]\n",
78+
" [ 0. 0. 0. 10.8 -0.2]\n",
79+
" [ 0. 0. 0. 0. 0. ]]\n",
80+
"[[ 0. -9.2 0. 10.6 10.6]\n",
81+
" [ 0. 10.8 0. 0. -0.2]\n",
82+
" [ 0. 0. 0. 10.8 -0.2]\n",
83+
" [ 0. 0. 0. 0. 0. ]]\n",
84+
"goal state reached\n",
85+
"[[ 0. -9.2 0. 10.6 10.6]\n",
86+
" [ 0. 10.8 0. 0. -0.2]\n",
87+
" [ 0. 0. 0. 10.8 -0.2]\n",
88+
" [ 0. 0. 0. 0. 0. ]]\n",
89+
"[[ 0. -9.2 0. 10.6 10.6]\n",
90+
" [ 0. 10.8 0. 0. 10.6]\n",
91+
" [ 0. 0. 0. 10.8 -0.2]\n",
92+
" [ 0. 0. 0. 0. 0. ]]\n",
93+
"[[ 0. -9.2 0. 10.6 10.6]\n",
94+
" [ 0. 10.8 0. 0. 10.6]\n",
95+
" [ 0. 0. 0. 10.8 -0.2]\n",
96+
" [ 0. 0. 0. 0. 0. ]]\n",
97+
"goal state reached\n",
98+
"[[ 0. 1.6 0. 10.6 10.6]\n",
99+
" [ 0. 10.8 0. 0. 10.6]\n",
100+
" [ 0. 0. 0. 10.8 -0.2]\n",
101+
" [ 0. 0. 0. 0. 0. ]]\n",
102+
"[[ 0. 1.6 0. 10.6 10.6]\n",
103+
" [ 0. 10.8 0. 0. 10.6]\n",
104+
" [ 0. 0. 0. 10.8 -0.2]\n",
105+
" [ 0. 0. 0. 0. 0. ]]\n",
106+
"goal state reached\n",
107+
"[[ 0. 1.6 0. 10.6 10.6]\n",
108+
" [ 0. 10.8 0. 0. 10.6]\n",
109+
" [ 0. 0. 0. 10.8 -0.2]\n",
110+
" [ 0. 0. 0. 0. 0. ]]\n",
111+
"[[ 0. 1.6 0. 10.6 10.6]\n",
112+
" [ 0. 10.8 0. 0. 10.6]\n",
113+
" [10.4 0. 0. 10.8 -0.2]\n",
114+
" [ 0. 0. 0. 0. 0. ]]\n",
115+
"[[ 0. 1.6 0. 10.6 10.6]\n",
116+
" [ 0. 10.8 0. 0. 10.6]\n",
117+
" [10.4 0. 0. 10.8 -0.2]\n",
118+
" [ 0. 0. 0. 0. 0. ]]\n",
119+
"[[ 0. 1.6 0. 10.6 10.6]\n",
120+
" [ 0. 10.8 0. 0. 10.6]\n",
121+
" [10.4 0. 0. 10.8 -0.2]\n",
122+
" [ 0. 0. 0. 0. 0. ]]\n",
123+
"[[ 0. 1.6 0. 10.6 10.6]\n",
124+
" [ 0. 10.8 0. 0. 10.6]\n",
125+
" [10.4 0. 0. 10.8 -0.2]\n",
126+
" [ 0. 0. 0. 0. 0. ]]\n",
127+
"[[ 0. 1.6 0. 10.6 10.6]\n",
128+
" [ 0. 10.8 0. 0. 10.6]\n",
129+
" [10.4 0. 0. 10.8 -0.2]\n",
130+
" [ 0. 0. 0. 0. 0. ]]\n",
131+
"[[ 0. 1.6 0. 10.6 10.6]\n",
132+
" [ 0. 10.8 0. 0. 10.6]\n",
133+
" [10.4 0. 0. 10.8 -0.2]\n",
134+
" [ 0. 0. 0. 0. 0. ]]\n",
135+
"goal state reached\n",
136+
"[[ 0. 1.6 0. 10.6 10.6]\n",
137+
" [ 0. 10.8 0. 0. 10.6]\n",
138+
" [10.4 0. 0. 10.8 -0.2]\n",
139+
" [ 0. 0. 0. 0. 0. ]]\n",
140+
"[[ 0. 1.6 0. 10.6 10.6]\n",
141+
" [ 0. 10.8 10.4 0. 10.6]\n",
142+
" [10.4 0. 0. 10.8 -0.2]\n",
143+
" [ 0. 0. 0. 0. 0. ]]\n",
144+
"[[ 0. 1.6 0. 10.6 10.6]\n",
145+
" [ 0. 10.8 10.4 0. 10.6]\n",
146+
" [10.4 0. 0. 10.8 -0.2]\n",
147+
" [ 0. 0. 0. 0. 0. ]]\n",
148+
"[[ 0. 1.6 0. 10.6 10.6]\n",
149+
" [ 0. 10.8 10.4 0. 10.6]\n",
150+
" [10.4 0. 0. 10.8 -0.2]\n",
151+
" [ 0. 0. 0. 0. 0. ]]\n",
152+
"goal state reached\n",
153+
"[[ 0. 1.6 0. 10.6 10.6]\n",
154+
" [ 0. 10.8 10.4 0. 10.6]\n",
155+
" [10.4 0. 0. 10.8 -0.2]\n",
156+
" [ 0. 0. 0. 0. 0. ]]\n",
157+
"[[ 0. 1.6 0. 10.6 10.6]\n",
158+
" [ 0. 10.8 10.4 0. 10.6]\n",
159+
" [10.4 0. 0. 10.8 -0.2]\n",
160+
" [ 0. 0. 0. 0. 0. ]]\n",
161+
"goal state reached\n"
162+
]
163+
}
164+
],
165+
"source": [
166+
"import numpy as np\n",
167+
"import random\n",
168+
"import matplotlib.pyplot as plt\n",
169+
"\n",
170+
"gamma = 0.8\n",
171+
"\n",
172+
"# each matrix below has states as rows, columns in order (U, D, L, R, N) unless otherwise stated\n",
173+
"\n",
174+
"# rewards for each state / action. 0 represents no such transition possible\n",
175+
"rewards = np.array([[0, -10, 0, -1, -1],\n",
176+
" [0, 10, -1, 0, -1],\n",
177+
" [-1, 0, 0, 10, -1],\n",
178+
" [-1, 0, -10, 0, 0]])\n",
179+
"\n",
180+
"q_matrix = np.zeros((4,5))\n",
181+
"\n",
182+
"# valid actions for each state encoded as 0=up,1=down, 2=left, 3?right, 4=no action\n",
183+
"valid_actions = np.array([[1, 3, 4],\n",
184+
" [1, 2, 4],\n",
185+
" [0, 3, 4],\n",
186+
" [0, 2, 4]])\n",
187+
"\n",
188+
"# what states we move to for each state / action. -1 represents invalid transaction\n",
189+
"transition_matrix = np.array([[-1, 2, -1, 1, 1 ],\n",
190+
" [-1, 3, 0, -1, 2 ],\n",
191+
" [0, -1, -1, 3, 3 ],\n",
192+
" [1, -1, 2, -1, -1]])\n",
193+
"\n",
194+
"\n",
195+
"for i in range(100): # 10 episodes\n",
196+
" current_state = 0\n",
197+
" while current_state != 3:\n",
198+
" # chose a random action - could use epsilon-greedy here\n",
199+
" action = random.choice(valid_actions[current_state])\n",
200+
"\n",
201+
" # record next state and reward (r, s')\n",
202+
" next_state = transition_matrix[current_state][action]\n",
203+
" reward = rewards[current_state][action]\n",
204+
"\n",
205+
" # get possible rewards for all valid actions\n",
206+
" future_rewards = []\n",
207+
" for action_next in valid_actions[next_state]:\n",
208+
" future_rewards.append(q_matrix[next_state][action_next])\n",
209+
"\n",
210+
" # q update\n",
211+
" q_state = reward + gamma + max(future_rewards)\n",
212+
" q_matrix[current_state][action] = q_state\n",
213+
" print(q_matrix)\n",
214+
"\n",
215+
" current_state = next_state\n",
216+
" if current_state == 3:\n",
217+
" print('goal state reached')"
218+
]
219+
},
220+
{
221+
"cell_type": "markdown",
222+
"metadata": {},
223+
"source": [
224+
"If this works then we would expect to:\n",
225+
"\n",
226+
"1. go right (q value for row 1, column 4 to be highest)\n",
227+
"2. go down (q value for row 2, column 2 to be highest) "
228+
]
229+
},
230+
{
231+
"cell_type": "code",
232+
"execution_count": 13,
233+
"metadata": {},
234+
"outputs": [
235+
{
236+
"name": "stdout",
237+
"output_type": "stream",
238+
"text": [
239+
"Final q-matrix\n",
240+
"[[ 0. 1.6 0. 10.6 10.6]\n",
241+
" [ 0. 10.8 10.4 0. 10.6]\n",
242+
" [10.4 0. 0. 10.8 -0.2]\n",
243+
" [ 0. 0. 0. 0. 0. ]]\n"
244+
]
245+
}
246+
],
247+
"source": [
248+
"print(\"Final q-matrix\")\n",
249+
"print(q_matrix)"
250+
]
251+
},
252+
{
253+
"cell_type": "code",
254+
"execution_count": null,
255+
"metadata": {},
256+
"outputs": [],
257+
"source": []
258+
}
259+
],
260+
"metadata": {
261+
"kernelspec": {
262+
"display_name": "Python 3",
263+
"language": "python",
264+
"name": "python3"
265+
},
266+
"language_info": {
267+
"codemirror_mode": {
268+
"name": "ipython",
269+
"version": 3
270+
},
271+
"file_extension": ".py",
272+
"mimetype": "text/x-python",
273+
"name": "python",
274+
"nbconvert_exporter": "python",
275+
"pygments_lexer": "ipython3",
276+
"version": "3.6.6"
277+
}
278+
},
279+
"nbformat": 4,
280+
"nbformat_minor": 2
281+
}

0 commit comments

Comments
 (0)