| 
16 | 16 |   },  | 
17 | 17 |   {  | 
18 | 18 |    "cell_type": "code",  | 
19 |  | -   "execution_count": 2,  | 
 | 19 | +   "execution_count": 1,  | 
20 | 20 |    "metadata": {},  | 
21 | 21 |    "outputs": [],  | 
22 | 22 |    "source": [  | 
23 | 23 |     "from sklearn.datasets import load_diabetes\n",  | 
24 | 24 |     "from sklearn.linear_model import Ridge\n",  | 
25 | 25 |     "from sklearn.metrics import mean_squared_error\n",  | 
26 | 26 |     "from sklearn.model_selection import train_test_split\n",  | 
27 |  | -    "import joblib"  | 
 | 27 | +    "import joblib\n",  | 
 | 28 | +    "import pandas as pd"  | 
28 | 29 |    ]  | 
29 | 30 |   },  | 
30 | 31 |   {  | 
 | 
36 | 37 |   },  | 
37 | 38 |   {  | 
38 | 39 |    "cell_type": "code",  | 
39 |  | -   "execution_count": 3,  | 
 | 40 | +   "execution_count": 6,  | 
40 | 41 |    "metadata": {},  | 
41 | 42 |    "outputs": [],  | 
42 | 43 |    "source": [  | 
43 |  | -    "X, y = load_diabetes(return_X_y=True)"  | 
 | 44 | +    "sample_data = load_diabetes()\n",  | 
 | 45 | +    "\n",  | 
 | 46 | +    "df = pd.DataFrame(\n",  | 
 | 47 | +    "    data=sample_data.data,\n",  | 
 | 48 | +    "    columns=sample_data.feature_names)\n",  | 
 | 49 | +    "df['Y'] = sample_data.target"  | 
44 | 50 |    ]  | 
45 | 51 |   },  | 
46 | 52 |   {  | 
47 | 53 |    "cell_type": "code",  | 
48 |  | -   "execution_count": 4,  | 
 | 54 | +   "execution_count": 7,  | 
49 | 55 |    "metadata": {},  | 
50 | 56 |    "outputs": [  | 
51 | 57 |     {  | 
 | 
57 | 63 |     }  | 
58 | 64 |    ],  | 
59 | 65 |    "source": [  | 
60 |  | -    "print(X.shape)"  | 
61 |  | -   ]  | 
62 |  | -  },  | 
63 |  | -  {  | 
64 |  | -   "cell_type": "code",  | 
65 |  | -   "execution_count": 5,  | 
66 |  | -   "metadata": {},  | 
67 |  | -   "outputs": [  | 
68 |  | -    {  | 
69 |  | -     "name": "stdout",  | 
70 |  | -     "output_type": "stream",  | 
71 |  | -     "text": [  | 
72 |  | -      "(442,)\n"  | 
73 |  | -     ]  | 
74 |  | -    }  | 
75 |  | -   ],  | 
76 |  | -   "source": [  | 
77 |  | -    "print(y.shape)"  | 
 | 66 | +    "print(df.shape)"  | 
78 | 67 |    ]  | 
79 | 68 |   },  | 
80 | 69 |   {  | 
81 | 70 |    "cell_type": "code",  | 
82 |  | -   "execution_count": 8,  | 
 | 71 | +   "execution_count": 11,  | 
83 | 72 |    "metadata": {},  | 
84 | 73 |    "outputs": [  | 
85 | 74 |     {  | 
 | 
103 | 92 |        "  <thead>\n",  | 
104 | 93 |        "    <tr style=\"text-align: right;\">\n",  | 
105 | 94 |        "      <th></th>\n",  | 
106 |  | -       "      <th>0</th>\n",  | 
107 |  | -       "      <th>1</th>\n",  | 
108 |  | -       "      <th>2</th>\n",  | 
109 |  | -       "      <th>3</th>\n",  | 
110 |  | -       "      <th>4</th>\n",  | 
111 |  | -       "      <th>5</th>\n",  | 
112 |  | -       "      <th>6</th>\n",  | 
113 |  | -       "      <th>7</th>\n",  | 
114 |  | -       "      <th>8</th>\n",  | 
115 |  | -       "      <th>9</th>\n",  | 
 | 95 | +       "      <th>age</th>\n",  | 
 | 96 | +       "      <th>sex</th>\n",  | 
 | 97 | +       "      <th>bmi</th>\n",  | 
 | 98 | +       "      <th>bp</th>\n",  | 
 | 99 | +       "      <th>s1</th>\n",  | 
 | 100 | +       "      <th>s2</th>\n",  | 
 | 101 | +       "      <th>s3</th>\n",  | 
 | 102 | +       "      <th>s4</th>\n",  | 
 | 103 | +       "      <th>s5</th>\n",  | 
 | 104 | +       "      <th>s6</th>\n",  | 
 | 105 | +       "      <th>Y</th>\n",  | 
116 | 106 |        "    </tr>\n",  | 
117 | 107 |        "  </thead>\n",  | 
118 | 108 |        "  <tbody>\n",  | 
 | 
128 | 118 |        "      <td>4.420000e+02</td>\n",  | 
129 | 119 |        "      <td>4.420000e+02</td>\n",  | 
130 | 120 |        "      <td>4.420000e+02</td>\n",  | 
 | 121 | +       "      <td>442.000000</td>\n",  | 
131 | 122 |        "    </tr>\n",  | 
132 | 123 |        "    <tr>\n",  | 
133 | 124 |        "      <td>mean</td>\n",  | 
134 |  | -       "      <td>-3.639623e-16</td>\n",  | 
135 |  | -       "      <td>1.309912e-16</td>\n",  | 
136 |  | -       "      <td>-8.013951e-16</td>\n",  | 
137 |  | -       "      <td>1.289818e-16</td>\n",  | 
138 |  | -       "      <td>-9.042540e-17</td>\n",  | 
139 |  | -       "      <td>1.301121e-16</td>\n",  | 
140 |  | -       "      <td>-4.563971e-16</td>\n",  | 
141 |  | -       "      <td>3.863174e-16</td>\n",  | 
142 |  | -       "      <td>-3.848103e-16</td>\n",  | 
143 |  | -       "      <td>-3.398488e-16</td>\n",  | 
 | 125 | +       "      <td>-3.634285e-16</td>\n",  | 
 | 126 | +       "      <td>1.308343e-16</td>\n",  | 
 | 127 | +       "      <td>-8.045349e-16</td>\n",  | 
 | 128 | +       "      <td>1.281655e-16</td>\n",  | 
 | 129 | +       "      <td>-8.835316e-17</td>\n",  | 
 | 130 | +       "      <td>1.327024e-16</td>\n",  | 
 | 131 | +       "      <td>-4.574646e-16</td>\n",  | 
 | 132 | +       "      <td>3.777301e-16</td>\n",  | 
 | 133 | +       "      <td>-3.830854e-16</td>\n",  | 
 | 134 | +       "      <td>-3.412882e-16</td>\n",  | 
 | 135 | +       "      <td>152.133484</td>\n",  | 
144 | 136 |        "    </tr>\n",  | 
145 | 137 |        "    <tr>\n",  | 
146 | 138 |        "      <td>std</td>\n",  | 
 | 
154 | 146 |        "      <td>4.761905e-02</td>\n",  | 
155 | 147 |        "      <td>4.761905e-02</td>\n",  | 
156 | 148 |        "      <td>4.761905e-02</td>\n",  | 
 | 149 | +       "      <td>77.093005</td>\n",  | 
157 | 150 |        "    </tr>\n",  | 
158 | 151 |        "    <tr>\n",  | 
159 | 152 |        "      <td>min</td>\n",  | 
 | 
167 | 160 |        "      <td>-7.639450e-02</td>\n",  | 
168 | 161 |        "      <td>-1.260974e-01</td>\n",  | 
169 | 162 |        "      <td>-1.377672e-01</td>\n",  | 
 | 163 | +       "      <td>25.000000</td>\n",  | 
170 | 164 |        "    </tr>\n",  | 
171 | 165 |        "    <tr>\n",  | 
172 | 166 |        "      <td>25%</td>\n",  | 
 | 
180 | 174 |        "      <td>-3.949338e-02</td>\n",  | 
181 | 175 |        "      <td>-3.324879e-02</td>\n",  | 
182 | 176 |        "      <td>-3.317903e-02</td>\n",  | 
 | 177 | +       "      <td>87.000000</td>\n",  | 
183 | 178 |        "    </tr>\n",  | 
184 | 179 |        "    <tr>\n",  | 
185 | 180 |        "      <td>50%</td>\n",  | 
 | 
193 | 188 |        "      <td>-2.592262e-03</td>\n",  | 
194 | 189 |        "      <td>-1.947634e-03</td>\n",  | 
195 | 190 |        "      <td>-1.077698e-03</td>\n",  | 
 | 191 | +       "      <td>140.500000</td>\n",  | 
196 | 192 |        "    </tr>\n",  | 
197 | 193 |        "    <tr>\n",  | 
198 | 194 |        "      <td>75%</td>\n",  | 
 | 
206 | 202 |        "      <td>3.430886e-02</td>\n",  | 
207 | 203 |        "      <td>3.243323e-02</td>\n",  | 
208 | 204 |        "      <td>2.791705e-02</td>\n",  | 
 | 205 | +       "      <td>211.500000</td>\n",  | 
209 | 206 |        "    </tr>\n",  | 
210 | 207 |        "    <tr>\n",  | 
211 | 208 |        "      <td>max</td>\n",  | 
 | 
219 | 216 |        "      <td>1.852344e-01</td>\n",  | 
220 | 217 |        "      <td>1.335990e-01</td>\n",  | 
221 | 218 |        "      <td>1.356118e-01</td>\n",  | 
 | 219 | +       "      <td>346.000000</td>\n",  | 
222 | 220 |        "    </tr>\n",  | 
223 | 221 |        "  </tbody>\n",  | 
224 | 222 |        "</table>\n",  | 
225 | 223 |        "</div>"  | 
226 | 224 |       ],  | 
227 | 225 |       "text/plain": [  | 
228 |  | -       "                  0             1             2             3             4  \\\n",  | 
 | 226 | +       "                age           sex           bmi            bp            s1  \\\n",  | 
229 | 227 |        "count  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02   \n",  | 
230 |  | -       "mean  -3.639623e-16  1.309912e-16 -8.013951e-16  1.289818e-16 -9.042540e-17   \n",  | 
 | 228 | +       "mean  -3.634285e-16  1.308343e-16 -8.045349e-16  1.281655e-16 -8.835316e-17   \n",  | 
231 | 229 |        "std    4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02   \n",  | 
232 | 230 |        "min   -1.072256e-01 -4.464164e-02 -9.027530e-02 -1.123996e-01 -1.267807e-01   \n",  | 
233 | 231 |        "25%   -3.729927e-02 -4.464164e-02 -3.422907e-02 -3.665645e-02 -3.424784e-02   \n",  | 
234 | 232 |        "50%    5.383060e-03 -4.464164e-02 -7.283766e-03 -5.670611e-03 -4.320866e-03   \n",  | 
235 | 233 |        "75%    3.807591e-02  5.068012e-02  3.124802e-02  3.564384e-02  2.835801e-02   \n",  | 
236 | 234 |        "max    1.107267e-01  5.068012e-02  1.705552e-01  1.320442e-01  1.539137e-01   \n",  | 
237 | 235 |        "\n",  | 
238 |  | -       "                  5             6             7             8             9  \n",  | 
239 |  | -       "count  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02  \n",  | 
240 |  | -       "mean   1.301121e-16 -4.563971e-16  3.863174e-16 -3.848103e-16 -3.398488e-16  \n",  | 
241 |  | -       "std    4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02  \n",  | 
242 |  | -       "min   -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01  \n",  | 
243 |  | -       "25%   -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02  \n",  | 
244 |  | -       "50%   -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03  \n",  | 
245 |  | -       "75%    2.984439e-02  2.931150e-02  3.430886e-02  3.243323e-02  2.791705e-02  \n",  | 
246 |  | -       "max    1.987880e-01  1.811791e-01  1.852344e-01  1.335990e-01  1.356118e-01  "  | 
 | 236 | +       "                 s2            s3            s4            s5            s6  \\\n",  | 
 | 237 | +       "count  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02   \n",  | 
 | 238 | +       "mean   1.327024e-16 -4.574646e-16  3.777301e-16 -3.830854e-16 -3.412882e-16   \n",  | 
 | 239 | +       "std    4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02   \n",  | 
 | 240 | +       "min   -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01   \n",  | 
 | 241 | +       "25%   -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02   \n",  | 
 | 242 | +       "50%   -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03   \n",  | 
 | 243 | +       "75%    2.984439e-02  2.931150e-02  3.430886e-02  3.243323e-02  2.791705e-02   \n",  | 
 | 244 | +       "max    1.987880e-01  1.811791e-01  1.852344e-01  1.335990e-01  1.356118e-01   \n",  | 
 | 245 | +       "\n",  | 
 | 246 | +       "                Y  \n",  | 
 | 247 | +       "count  442.000000  \n",  | 
 | 248 | +       "mean   152.133484  \n",  | 
 | 249 | +       "std     77.093005  \n",  | 
 | 250 | +       "min     25.000000  \n",  | 
 | 251 | +       "25%     87.000000  \n",  | 
 | 252 | +       "50%    140.500000  \n",  | 
 | 253 | +       "75%    211.500000  \n",  | 
 | 254 | +       "max    346.000000  "  | 
247 | 255 |       ]  | 
248 | 256 |      },  | 
249 |  | -     "execution_count": 8,  | 
 | 257 | +     "execution_count": 11,  | 
250 | 258 |      "metadata": {},  | 
251 | 259 |      "output_type": "execute_result"  | 
252 | 260 |     }  | 
253 | 261 |    ],  | 
254 | 262 |    "source": [  | 
255 |  | -    "import pandas as pd\n",  | 
256 |  | -    "features = pd.DataFrame(X)\n",  | 
257 |  | -    "features.describe()"  | 
 | 263 | +    "# All data in a single dataframe\n",  | 
 | 264 | +    "df.describe()"  | 
258 | 265 |    ]  | 
259 | 266 |   },  | 
260 | 267 |   {  | 
 | 
266 | 273 |   },  | 
267 | 274 |   {  | 
268 | 275 |    "cell_type": "code",  | 
269 |  | -   "execution_count": 3,  | 
 | 276 | +   "execution_count": 12,  | 
270 | 277 |    "metadata": {},  | 
271 | 278 |    "outputs": [],  | 
272 | 279 |    "source": [  | 
273 |  | -    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n",  | 
 | 280 | +    "X = df.drop('Y', axis=1).values\n",  | 
 | 281 | +    "y = df['Y'].values\n",  | 
 | 282 | +    "\n",  | 
 | 283 | +    "X_train, X_test, y_train, y_test = train_test_split(\n",  | 
 | 284 | +    "    X, y, test_size=0.2, random_state=0)\n",  | 
274 | 285 |     "data = {\"train\": {\"X\": X_train, \"y\": y_train},\n",  | 
275 | 286 |     "        \"test\": {\"X\": X_test, \"y\": y_test}}"  | 
276 | 287 |    ]  | 
 | 
284 | 295 |   },  | 
285 | 296 |   {  | 
286 | 297 |    "cell_type": "code",  | 
287 |  | -   "execution_count": 4,  | 
 | 298 | +   "execution_count": 16,  | 
288 | 299 |    "metadata": {},  | 
289 | 300 |    "outputs": [  | 
290 | 301 |     {  | 
 | 
294 | 305 |        "      normalize=False, random_state=None, solver='auto', tol=0.001)"  | 
295 | 306 |       ]  | 
296 | 307 |      },  | 
297 |  | -     "execution_count": 4,  | 
 | 308 | +     "execution_count": 16,  | 
298 | 309 |      "metadata": {},  | 
299 | 310 |      "output_type": "execute_result"  | 
300 | 311 |     }  | 
301 | 312 |    ],  | 
302 | 313 |    "source": [  | 
303 |  | -    "alpha = 0.5\n",  | 
 | 314 | +    "# experiment parameters\n",  | 
 | 315 | +    "args = {\n",  | 
 | 316 | +    "    \"alpha\": 0.5\n",  | 
 | 317 | +    "}\n",  | 
304 | 318 |     "\n",  | 
305 |  | -    "reg = Ridge(alpha=alpha)\n",  | 
306 |  | -    "reg.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])"  | 
 | 319 | +    "reg_model = Ridge(**args)\n",  | 
 | 320 | +    "reg_model.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])"  | 
307 | 321 |    ]  | 
308 | 322 |   },  | 
309 | 323 |   {  | 
 | 
315 | 329 |   },  | 
316 | 330 |   {  | 
317 | 331 |    "cell_type": "code",  | 
318 |  | -   "execution_count": 6,  | 
 | 332 | +   "execution_count": 18,  | 
319 | 333 |    "metadata": {},  | 
320 | 334 |    "outputs": [  | 
321 | 335 |     {  | 
322 | 336 |      "name": "stdout",  | 
323 | 337 |      "output_type": "stream",  | 
324 | 338 |      "text": [  | 
325 |  | -      "mse:  3298.9096058070622\n"  | 
 | 339 | +      "{'mse': 3298.9096058070622}\n"  | 
326 | 340 |      ]  | 
327 | 341 |     }  | 
328 | 342 |    ],  | 
329 | 343 |    "source": [  | 
330 |  | -    "preds = reg.predict(data[\"test\"][\"X\"])\n",  | 
331 |  | -    "print(\"mse: \", mean_squared_error(preds, y_test))"  | 
 | 344 | +    "preds = reg_model.predict(data[\"test\"][\"X\"])\n",  | 
 | 345 | +    "mse = mean_squared_error(preds, y_test)\n",  | 
 | 346 | +    "metrics = {\"mse\": mse}\n",  | 
 | 347 | +    "print(metrics)"  | 
332 | 348 |    ]  | 
333 | 349 |   },  | 
334 | 350 |   {  | 
 | 
363 | 379 |  ],  | 
364 | 380 |  "metadata": {  | 
365 | 381 |   "kernelspec": {  | 
366 |  | -   "display_name": "Python (storedna)",  | 
 | 382 | +   "display_name": "Python 3",  | 
367 | 383 |    "language": "python",  | 
368 |  | -   "name": "storedna"  | 
 | 384 | +   "name": "python3"  | 
369 | 385 |   },  | 
370 | 386 |   "language_info": {  | 
371 | 387 |    "codemirror_mode": {  | 
 | 
377 | 393 |    "name": "python",  | 
378 | 394 |    "nbconvert_exporter": "python",  | 
379 | 395 |    "pygments_lexer": "ipython3",  | 
380 |  | -   "version": "3.6.9"  | 
 | 396 | +   "version": "3.7.4"  | 
381 | 397 |   }  | 
382 | 398 |  },  | 
383 | 399 |  "nbformat": 4,  | 
 | 
0 commit comments