Skip to content

Commit 00ab9b3

Browse files
committed
Logistic regression model
1 parent e20e518 commit 00ab9b3

File tree

3 files changed

+293
-0
lines changed

3 files changed

+293
-0
lines changed

.DS_Store

6 KB
Binary file not shown.
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"slideshow": {
7+
"slide_type": "slide"
8+
}
9+
},
10+
"source": [
11+
    "# Logistic regression with TensorFlow\n",
12+
"\n"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": 4,
18+
"metadata": {
19+
"collapsed": true
20+
},
21+
"outputs": [],
22+
"source": [
23+
"import numpy as np\n",
24+
"import tensorflow as tf\n",
25+
    "import matplotlib.pyplot as plt\n",
    "s = tf.InteractiveSession()"
26+
]
27+
},
28+
{
29+
"cell_type": "markdown",
30+
"metadata": {},
31+
"source": [
32+
    "## Logistic regression\n",
33+
"\n",
34+
"Plan:\n",
35+
"* Use a shared variable for weights\n",
36+
"* Use a matrix placeholder for `X`\n",
37+
" \n",
38+
    "Train on a two-class MNIST dataset."
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": 31,
44+
"metadata": {},
45+
"outputs": [
46+
{
47+
"name": "stdout",
48+
"output_type": "stream",
49+
"text": [
50+
"y [shape - (360,)]: [0 1 0 1 0 1 0 0 1 1]\n",
51+
"X [shape - (360, 64)]:\n"
52+
]
53+
}
54+
],
55+
"source": [
56+
"from sklearn.datasets import load_digits\n",
57+
"mnist = load_digits(2)\n",
58+
"\n",
59+
"X, y = mnist.data, mnist.target\n",
60+
"\n",
61+
"print(\"y [shape - %s]:\" % (str(y.shape)), y[:10])\n",
62+
"print(\"X [shape - %s]:\" % (str(X.shape)))"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 32,
68+
"metadata": {},
69+
"outputs": [
70+
{
71+
"name": "stdout",
72+
"output_type": "stream",
73+
"text": [
74+
"X:\n",
75+
" [[ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0.]\n",
76+
" [ 0. 0. 0. 12. 13. 5. 0. 0. 0. 0.]\n",
77+
" [ 0. 0. 1. 9. 15. 11. 0. 0. 0. 0.]]\n",
78+
"y:\n",
79+
" [0 1 0 1 0 1 0 0 1 1]\n"
80+
]
81+
},
82+
{
83+
"data": {
84+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD8CAYAAABaQGkdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACxNJREFUeJzt3fuLXPUZx/HPp5vErRqTYqxKNjShaEAqNZqmhIjQBEus\nokJL3YCWSmGhoCiGihZL239A0h+KIFErmBpsVBDrBVsVK6QxF1M1txKDJRvURLwHTLLm6Q87gShp\n92zmnO+ZeXy/YHEvw36fQd45Z2ZnztcRIQA5fa3tAQA0h8CBxAgcSIzAgcQIHEiMwIHECBxIjMCB\nxAgcSGxKE790mk+JQZ3WxK9u1dissvfpnHPeL7bWvoMzi601OHqk2FpxZKzYWiV9poM6HIc80e0a\nCXxQp+n7XtbEr27Vez9eXHS9X61cW2yt32y+ptha59/2drG1xt55t9haJW2Iv1e6HafoQGIEDiRG\n4EBiBA4kRuBAYgQOJEbgQGIEDiRWKXDby23vsr3b9h1NDwWgHhMGbntA0h8lXSHpAkkrbF/Q9GAA\nulflCL5I0u6I2BMRhyWtlVTudY0ATlqVwGdL2nvc16Od7wHocbW92cT2iKQRSRrUqXX9WgBdqHIE\n3ydpznFfD3W+9wURcW9ELIyIhVN1Sl3zAehClcA3SjrP9jzb0yQNS3qi2bEA1GHCU/SIGLN9k6Rn\nJQ1Iuj8itjU+GYCuVXoMHhFPSXqq4VkA1IxXsgGJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQWCM7\nm2RVcqcRSRqe/kGxtVbN/LTYWn/d8myxtS753S+LrSVJs+5dX3S9iXAEBxIjcCAxAgcSI3AgMQIH\nEiNwIDECBxIjcCAxAgcSq7Kzyf2299t+o8RAAOpT5Qj+J0nLG54DQAMmDDwiXpL0foFZANSMx+BA\nYmxdBCRW2xGcrYuA3sMpOpBYlT+TPSxpvaT5tkdt/6L5sQDUocreZCtKDAKgfpyiA4kROJAYgQOJ\nETiQGIEDiRE4kBiBA4kROJBY329dNLb0kmJrDU/fWmwtSbpi+XCxtWa8trPYWj99eVmxtd5f8Hmx\ntSRpVtHVJsYRHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxKpcdHGO7Rdsb7e9zfYt\nJQYD0L0qr0Ufk7QyIrbYni5ps+3nImJ7w7MB6FKVvcnejogtnc8/kbRD0uymBwPQvUm9m8z2XEkL\nJG04wc/YugjoMZWfZLN9uqRHJd0aER9/+edsXQT0nkqB256q8bjXRMRjzY4EoC5VnkW3pPsk7YiI\nu5sfCUBdqhzBl0i6QdJS21s7Hz9qeC4ANaiyN9nLklxgFgA145VsQGIEDiRG4EBiBA4kRuBAYgQO\nJEbgQGIEDiTW93uTfXZmubtw1/4Li60lSUcL7hdW0sbXv932CF8ZHMGBxAgcSIzAgcQIHEiMwIHE\nCBxIjMCBxAgcSIzAgcSqXHRx0PYrtv/V2bro9yUGA9C9Kq/zPCRpaUR82rl88su2n46IfzY8G4Au\nVbnoYkj6tPPl1M5HNDkUgHpU3fhgwPZWSfslPRcRJ9y6yPYm25uO6FDdcwI4CZUCj4jPI+IiSUOS\nFtn+zgluw9ZFQI+Z1LPoEfGhpBckLW9mHAB1qvIs+lm2Z3Y+/7qkyyXlfKMykEyVZ9HPlfSg7QGN\n/4PwSEQ82exYAOpQ5Vn01zS+JziAPsMr2YDECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIrP+3LvpG\nuX+j1qxfXGwtSTpfrxRdr5QpMw4XW2vso2nF1upFHMGBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgc\nSIzAgcQqB965NvqrtrkeG9AnJnMEv0XSjqYGAVC/qjubDEm6UtLqZscBUKeqR/BVkm6XdLTBWQDU\nrMrGB1dJ2h8Rmye4HXuTAT2myhF8iaSrbb8l
aa2kpbYf+vKN2JsM6D0TBh4Rd0bEUETMlTQs6fmI\nuL7xyQB0jb+DA4lN6oouEfGipBcbmQRA7TiCA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJBY329d\nNPhBuTe4fe/CN4utJUkfFVxryjlnF1vrugv+7/uWavXI05cWW6sXcQQHEiNwIDECBxIjcCAxAgcS\nI3AgMQIHEiNwIDECBxKr9Eq2zhVVP5H0uaSxiFjY5FAA6jGZl6r+ICLea2wSALXjFB1IrGrgIelv\ntjfbHmlyIAD1qXqKfmlE7LP9TUnP2d4ZES8df4NO+COSNKhTax4TwMmodASPiH2d/+6X9LikRSe4\nDVsXAT2myuaDp9mefuxzST+U9EbTgwHoXpVT9LMlPW772O3/HBHPNDoVgFpMGHhE7JH03QKzAKgZ\nfyYDEiNwIDECBxIjcCAxAgcSI3AgMQIHEiNwILG+37rojF3lNvj57dCTxdaSpJ+N3FZsranXHii2\nVknz7lzf9git4ggOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRWKXDbM22vs73T9g7bi5se\nDED3qr5U9Q+SnomIn9ieJnHhc6AfTBi47RmSLpP0c0mKiMOSDjc7FoA6VDlFnyfpgKQHbL9qe3Xn\n+ugAelyVwKdIuljSPRGxQNJBSXd8+Ua2R2xvsr3piA7VPCaAk1El8FFJoxGxofP1Oo0H/wVsXQT0\nngkDj4h3JO21Pb/zrWWStjc6FYBaVH0W/WZJazrPoO+RdGNzIwGoS6XAI2KrpIUNzwKgZrySDUiM\nwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIrO/3Jjv62s5ia113z8pia0nSXSsfLrbWqjeXFVtr\n40UDxdb6quMIDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kNmHgtufb3nrcx8e2by0xHIDu\nTPhS1YjYJekiSbI9IGmfpMcbngtADSZ7ir5M0psR8Z8mhgFQr8m+2WRY0gnfAWF7RNKIJA2y+SjQ\nEyofwTubHlwt6S8n+jlbFwG9ZzKn6FdI2hIR7zY1DIB6TSbwFfofp+cAelOlwDv7gV8u6bFmxwFQ\np6p7kx2UdGbDswCoGa9kAxIjcCAxAgcSI3AgMQIHEiNwIDECBxIjcCAxR0T9v9Q+IGmybymdJem9\n2ofpDVnvG/erPd+KiLMmulEjgZ8M25siYmHbczQh633jfvU+TtGBxAgcSKyXAr+37QEalPW+cb96\nXM88BgdQv146ggOoWU8Ebnu57V22d9u+o+156mB7ju0XbG+3vc32LW3PVCfbA7Zftf1k27PUyfZM\n2+ts77S9w/bitmfqRuun6J1rrf9b41eMGZW0UdKKiNje6mBdsn2upHMjYovt6ZI2S7q23+/XMbZv\nk7RQ0hkRcVXb89TF9oOS/hERqzsXGj01Ij5se66T1QtH8EWSdkfEnog4LGmtpGtanqlrEfF2RGzp\nfP6JpB2SZrc7VT1sD0m6UtLqtmepk+0Zki6TdJ8kRcThfo5b6o3AZ0vae9zXo0oSwjG250paIGlD\nu5PUZpWk2yUdbXuQms2TdEDSA52HH6s71yPsW70QeGq2T5f0qKRbI+Ljtufplu2rJO2PiM1tz9KA\nKZIulnRPRCyQdFBSXz8n1AuB75M057ivhzrf63u2p2o87jURkeWKtEskXW37LY0/nFpq+6F2R6rN\nqKTRiDh2prVO48H3rV4IfKOk82zP6zypMSzpiZZn6ppta/yx3I6IuLvteeoSEXdGxFBEzNX4/6vn\nI+L6lseqRUS8I2mv7fmdby2T1NdPik52b7LaRcSY7ZskPStpQNL9EbGt5bHqsETSDZJet721871f\nR8RTLc6Eid0saU3nYLNH0o0tz9OV1v9MBqA5vXCKDqAhBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4k\n9l+8Q5/p
EyhkXAAAAABJRU5ErkJggg==\n",
85+
"text/plain": [
86+
"<matplotlib.figure.Figure at 0x7f8ab6f9f7f0>"
87+
]
88+
},
89+
"metadata": {},
90+
"output_type": "display_data"
91+
}
92+
],
93+
"source": [
94+
"print('X:\\n',X[:3,:10])\n",
95+
"print('y:\\n',y[:10])\n",
96+
"plt.imshow(X[0].reshape([8,8]));"
97+
]
98+
},
99+
{
100+
"cell_type": "markdown",
101+
"metadata": {},
102+
"source": [
103+
"It's your turn now!\n",
104+
"Just a small reminder of the relevant math:\n",
105+
"\n",
106+
"$$\n",
107+
"P(y=1|X) = \\sigma(X \\cdot W + b)\n",
108+
"$$\n",
109+
"$$\n",
110+
"\\text{loss} = -\\log\\left(P\\left(y_\\text{predicted} = 1\\right)\\right)\\cdot y_\\text{true} - \\log\\left(1 - P\\left(y_\\text{predicted} = 1\\right)\\right)\\cdot\\left(1 - y_\\text{true}\\right)\n",
111+
"$$\n",
112+
"\n",
113+
"$\\sigma(x)$ is available via `tf.nn.sigmoid` and matrix multiplication via `tf.matmul`"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": 33,
119+
"metadata": {
120+
"collapsed": true
121+
},
122+
"outputs": [],
123+
"source": [
124+
"from sklearn.model_selection import train_test_split\n",
125+
"X_train, X_test, y_train, y_test = train_test_split(\n",
126+
" X, y, random_state=42)"
127+
]
128+
},
129+
{
130+
"cell_type": "markdown",
131+
"metadata": {},
132+
"source": [
133+
"__Your code goes here.__ For the training and testing scaffolding to work, please stick to the names in comments."
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": 62,
139+
"metadata": {
140+
"collapsed": true
141+
},
142+
"outputs": [],
143+
"source": [
144+
"# Model parameters - weights and bias\n",
145+
"weights = tf.get_variable(shape=(X.shape[1], 1), dtype=tf.float64,name=\"w\")\n",
146+
"b=tf.Variable(0,dtype=tf.float64,name='bias')"
147+
]
148+
},
149+
{
150+
"cell_type": "code",
151+
"execution_count": 63,
152+
"metadata": {
153+
"collapsed": true
154+
},
155+
"outputs": [],
156+
"source": [
157+
"# Placeholders for the input data\n",
158+
"input_X = tf.placeholder('float64', shape=(None, X.shape[1]))\n",
159+
"input_y = tf.placeholder('float64')"
160+
]
161+
},
162+
{
163+
"cell_type": "code",
164+
"execution_count": 66,
165+
"metadata": {
166+
"collapsed": true
167+
},
168+
"outputs": [],
169+
"source": [
170+
"# The model code\n",
171+
"\n",
172+
"# Compute a vector of predictions, resulting shape should be [input_X.shape[0],]\n",
173+
"# This is 1D, if you have extra dimensions, you can get rid of them with tf.squeeze .\n",
174+
"# Don't forget the sigmoid.\n",
175+
"predicted_y = tf.squeeze(tf.nn.sigmoid(tf.matmul(input_X,weights)+b))\n",
176+
"\n",
177+
"# Loss. Should be a scalar number - average loss over all the objects\n",
178+
"# tf.reduce_mean is your friend here\n",
179+
"loss = tf.reduce_mean(-input_y * tf.log(predicted_y)-(1-input_y) * tf.log(1-predicted_y))\n",
180+
" #<logistic loss (scalar, mean over sample)>\n",
181+
"\n",
182+
"# See above for an example. tf.train.*Optimizer\n",
183+
"optimizer = tf.train.MomentumOptimizer(0.01, 0.5).minimize(loss)"
184+
]
185+
},
186+
{
187+
"cell_type": "markdown",
188+
"metadata": {},
189+
"source": [
190+
"A test to help with the debugging"
191+
]
192+
},
193+
{
194+
"cell_type": "code",
195+
"execution_count": 67,
196+
"metadata": {
197+
"collapsed": true
198+
},
199+
"outputs": [],
200+
"source": [
201+
"validation_weights = 1e-3 * np.fromiter(map(lambda x:\n",
202+
" s.run(weird_psychotic_function, {my_scalar:x, my_vector:[1, 0.1, 2]}),\n",
203+
" 0.15 * np.arange(1, X.shape[1] + 1)),\n",
204+
" count=X.shape[1], dtype=np.float32)[:, np.newaxis]\n",
205+
"# Compute predictions for given weights and bias\n",
206+
"prediction_validation = s.run(\n",
207+
" predicted_y, {\n",
208+
" input_X: X,\n",
209+
" weights: validation_weights,\n",
210+
" b: 1e-1})\n",
211+
"\n",
212+
"# Load the reference values for the predictions\n",
213+
"validation_true_values = np.loadtxt(\"validation_predictons.txt\")\n",
214+
"\n",
215+
"assert prediction_validation.shape == (X.shape[0],),\\\n",
216+
" \"Predictions must be a 1D array with length equal to the number \" \\\n",
217+
" \"of examples in input_X\"\n",
218+
"assert np.allclose(validation_true_values, prediction_validation)\n",
219+
"loss_validation = s.run(\n",
220+
" loss, {\n",
221+
" input_X: X[:100],\n",
222+
" input_y: y[-100:],\n",
223+
" weights: validation_weights+1.21e-3,\n",
224+
" b: -1e-1})\n",
225+
"assert np.allclose(loss_validation, 0.728689)"
226+
]
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": 68,
231+
"metadata": {},
232+
"outputs": [
233+
{
234+
"name": "stdout",
235+
"output_type": "stream",
236+
"text": [
237+
"loss at iter 0:0.4043\n",
238+
"train auc: 0.948232323232\n",
239+
"test auc: 0.980731225296\n",
240+
"loss at iter 1:1.2870\n",
241+
"train auc: 0.973429951691\n",
242+
"test auc: 0.991600790514\n",
243+
"loss at iter 2:0.1875\n",
244+
"train auc: 0.993302591129\n",
245+
"test auc: 1.0\n",
246+
"loss at iter 3:0.0827\n",
247+
"train auc: 0.997419850681\n",
248+
"test auc: 1.0\n",
249+
"loss at iter 4:0.0921\n",
250+
"train auc: 0.998407992973\n",
251+
"test auc: 1.0\n"
252+
]
253+
}
254+
],
255+
"source": [
256+
"from sklearn.metrics import roc_auc_score\n",
257+
"s.run(tf.global_variables_initializer())\n",
258+
"for i in range(5):\n",
259+
" s.run(optimizer, {input_X: X_train, input_y: y_train})\n",
260+
" loss_i = s.run(loss, {input_X: X_train, input_y: y_train})\n",
261+
" print(\"loss at iter %i:%.4f\" % (i, loss_i))\n",
262+
" print(\"train auc:\", roc_auc_score(y_train, s.run(predicted_y, {input_X:X_train})))\n",
263+
" print(\"test auc:\", roc_auc_score(y_test, s.run(predicted_y, {input_X:X_test})))"
264+
]
265+
}
266+
],
267+
"metadata": {
268+
"kernelspec": {
269+
"display_name": "Python 3",
270+
"language": "python",
271+
"name": "python3"
272+
},
273+
"language_info": {
274+
"codemirror_mode": {
275+
"name": "ipython",
276+
"version": 3
277+
},
278+
"file_extension": ".py",
279+
"mimetype": "text/x-python",
280+
"name": "python",
281+
"nbconvert_exporter": "python",
282+
"pygments_lexer": "ipython3",
283+
"version": "3.6.4"
284+
}
285+
},
286+
"nbformat": 4,
287+
"nbformat_minor": 2
288+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
numpy==1.13.3
2+
pandas==0.21.0
3+
matplotlib==2.1.0
4+
scikit-learn==0.19.1
5+
tensorflow>=1.0

0 commit comments

Comments
 (0)