| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "A goal of supervised learning is to build a model that performs well on new data. If you had new data, you could measure how your model performs on it. Often you don't have any, but you can simulate the experience with a train test split. In this video, I'll show you how `train_test_split` works in scikit-learn." |
| 8 | + ] |
| 9 | + }, |
| 10 | + { |
| 11 | + "cell_type": "markdown", |
| 12 | + "metadata": {}, |
| 13 | + "source": [ |
| 14 | + "## What is `train_test_split`" |
| 15 | + ] |
| 16 | + }, |
| 17 | + { |
| 18 | + "cell_type": "markdown", |
| 19 | + "metadata": {}, |
| 20 | + "source": [ |
| 21 | + "1. Split the dataset into two pieces: a **training set** and a **testing set**. Typically, about 75% of the data goes to the training set and 25% goes to the testing set.\n", |
| 22 | + "2. Train the model on the **training set**.\n", |
| 23 | + "3. Test the model on the **testing set** and evaluate its performance.\n", |
| 24 | + "\n" |
| 25 | + ] |
| 26 | + }, |
| 27 | + { |
| 28 | + "cell_type": "markdown", |
| 29 | + "metadata": {}, |
| 30 | + "source": [ |
| 31 | + "## Import Libraries" |
| 32 | + ] |
| 33 | + }, |
| 34 | + { |
| 35 | + "cell_type": "code", |
| 36 | + "execution_count": 1, |
| 37 | + "metadata": {}, |
| 38 | + "outputs": [], |
| 39 | + "source": [ |
| 40 | + "%matplotlib inline\n", |
| 41 | + "\n", |
| 42 | + "import pandas as pd\n", |
| 43 | + "import matplotlib.pyplot as plt\n", |
| 44 | + "\n", |
| 45 | + "from sklearn.model_selection import train_test_split\n", |
| 46 | + "\n", |
| 47 | + "from sklearn.linear_model import LinearRegression" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "markdown", |
| 52 | + "metadata": {}, |
| 53 | + "source": [ |
| 54 | + "## Load the Dataset\n", |
| 55 | + "The code below loads and displays the Boston dataset." |
| 56 | + ] |
| 57 | + }, |
| 58 | + { |
| 59 | + "cell_type": "code", |
| 60 | + "execution_count": 2, |
| 61 | + "metadata": {}, |
| 62 | + "outputs": [], |
| 63 | + "source": [ |
| 64 | + "df = pd.read_csv(\"https://raw.githubusercontent.com/mGalarnyk/Tutorial_Data/master/Boston_Housing/bostonHousing.csv\")" |
| 65 | + ] |
| 66 | + }, |
| 67 | + { |
| 68 | + "cell_type": "code", |
| 69 | + "execution_count": 3, |
| 70 | + "metadata": {}, |
| 71 | + "outputs": [ |
| 72 | + { |
| 73 | + "data": { |
| 74 | + "text/html": [ |
| 75 | + "<div>\n", |
| 76 | + "<style scoped>\n", |
| 77 | + " .dataframe tbody tr th:only-of-type {\n", |
| 78 | + " vertical-align: middle;\n", |
| 79 | + " }\n", |
| 80 | + "\n", |
| 81 | + " .dataframe tbody tr th {\n", |
| 82 | + " vertical-align: top;\n", |
| 83 | + " }\n", |
| 84 | + "\n", |
| 85 | + " .dataframe thead th {\n", |
| 86 | + " text-align: right;\n", |
| 87 | + " }\n", |
| 88 | + "</style>\n", |
| 89 | + "<table border=\"1\" class=\"dataframe\">\n", |
| 90 | + " <thead>\n", |
| 91 | + " <tr style=\"text-align: right;\">\n", |
| 92 | + " <th></th>\n", |
| 93 | + " <th>CRIM</th>\n", |
| 94 | + " <th>ZN</th>\n", |
| 95 | + " <th>INDUS</th>\n", |
| 96 | + " <th>CHAS</th>\n", |
| 97 | + " <th>NOX</th>\n", |
| 98 | + " <th>RM</th>\n", |
| 99 | + " <th>AGE</th>\n", |
| 100 | + " <th>DIS</th>\n", |
| 101 | + " <th>RAD</th>\n", |
| 102 | + " <th>TAX</th>\n", |
| 103 | + " <th>PTRATIO</th>\n", |
| 104 | + " <th>B</th>\n", |
| 105 | + " <th>LSTAT</th>\n", |
| 106 | + " <th>target</th>\n", |
| 107 | + " </tr>\n", |
| 108 | + " </thead>\n", |
| 109 | + " <tbody>\n", |
| 110 | + " <tr>\n", |
| 111 | + " <th>0</th>\n", |
| 112 | + " <td>0.00632</td>\n", |
| 113 | + " <td>18.0</td>\n", |
| 114 | + " <td>2.31</td>\n", |
| 115 | + " <td>0.0</td>\n", |
| 116 | + " <td>0.538</td>\n", |
| 117 | + " <td>6.575</td>\n", |
| 118 | + " <td>65.2</td>\n", |
| 119 | + " <td>4.0900</td>\n", |
| 120 | + " <td>1.0</td>\n", |
| 121 | + " <td>296.0</td>\n", |
| 122 | + " <td>15.3</td>\n", |
| 123 | + " <td>396.90</td>\n", |
| 124 | + " <td>4.98</td>\n", |
| 125 | + " <td>24.0</td>\n", |
| 126 | + " </tr>\n", |
| 127 | + " <tr>\n", |
| 128 | + " <th>1</th>\n", |
| 129 | + " <td>0.02731</td>\n", |
| 130 | + " <td>0.0</td>\n", |
| 131 | + " <td>7.07</td>\n", |
| 132 | + " <td>0.0</td>\n", |
| 133 | + " <td>0.469</td>\n", |
| 134 | + " <td>6.421</td>\n", |
| 135 | + " <td>78.9</td>\n", |
| 136 | + " <td>4.9671</td>\n", |
| 137 | + " <td>2.0</td>\n", |
| 138 | + " <td>242.0</td>\n", |
| 139 | + " <td>17.8</td>\n", |
| 140 | + " <td>396.90</td>\n", |
| 141 | + " <td>9.14</td>\n", |
| 142 | + " <td>21.6</td>\n", |
| 143 | + " </tr>\n", |
| 144 | + " <tr>\n", |
| 145 | + " <th>2</th>\n", |
| 146 | + " <td>0.02729</td>\n", |
| 147 | + " <td>0.0</td>\n", |
| 148 | + " <td>7.07</td>\n", |
| 149 | + " <td>0.0</td>\n", |
| 150 | + " <td>0.469</td>\n", |
| 151 | + " <td>7.185</td>\n", |
| 152 | + " <td>61.1</td>\n", |
| 153 | + " <td>4.9671</td>\n", |
| 154 | + " <td>2.0</td>\n", |
| 155 | + " <td>242.0</td>\n", |
| 156 | + " <td>17.8</td>\n", |
| 157 | + " <td>392.83</td>\n", |
| 158 | + " <td>4.03</td>\n", |
| 159 | + " <td>34.7</td>\n", |
| 160 | + " </tr>\n", |
| 161 | + " <tr>\n", |
| 162 | + " <th>3</th>\n", |
| 163 | + " <td>0.03237</td>\n", |
| 164 | + " <td>0.0</td>\n", |
| 165 | + " <td>2.18</td>\n", |
| 166 | + " <td>0.0</td>\n", |
| 167 | + " <td>0.458</td>\n", |
| 168 | + " <td>6.998</td>\n", |
| 169 | + " <td>45.8</td>\n", |
| 170 | + " <td>6.0622</td>\n", |
| 171 | + " <td>3.0</td>\n", |
| 172 | + " <td>222.0</td>\n", |
| 173 | + " <td>18.7</td>\n", |
| 174 | + " <td>394.63</td>\n", |
| 175 | + " <td>2.94</td>\n", |
| 176 | + " <td>33.4</td>\n", |
| 177 | + " </tr>\n", |
| 178 | + " <tr>\n", |
| 179 | + " <th>4</th>\n", |
| 180 | + " <td>0.06905</td>\n", |
| 181 | + " <td>0.0</td>\n", |
| 182 | + " <td>2.18</td>\n", |
| 183 | + " <td>0.0</td>\n", |
| 184 | + " <td>0.458</td>\n", |
| 185 | + " <td>7.147</td>\n", |
| 186 | + " <td>54.2</td>\n", |
| 187 | + " <td>6.0622</td>\n", |
| 188 | + " <td>3.0</td>\n", |
| 189 | + " <td>222.0</td>\n", |
| 190 | + " <td>18.7</td>\n", |
| 191 | + " <td>396.90</td>\n", |
| 192 | + " <td>5.33</td>\n", |
| 193 | + " <td>36.2</td>\n", |
| 194 | + " </tr>\n", |
| 195 | + " </tbody>\n", |
| 196 | + "</table>\n", |
| 197 | + "</div>" |
| 198 | + ], |
| 199 | + "text/plain": [ |
| 200 | + " CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX \\\n", |
| 201 | + "0 0.00632 18.0 2.31 0.0 0.538 6.575 65.2 4.0900 1.0 296.0 \n", |
| 202 | + "1 0.02731 0.0 7.07 0.0 0.469 6.421 78.9 4.9671 2.0 242.0 \n", |
| 203 | + "2 0.02729 0.0 7.07 0.0 0.469 7.185 61.1 4.9671 2.0 242.0 \n", |
| 204 | + "3 0.03237 0.0 2.18 0.0 0.458 6.998 45.8 6.0622 3.0 222.0 \n", |
| 205 | + "4 0.06905 0.0 2.18 0.0 0.458 7.147 54.2 6.0622 3.0 222.0 \n", |
| 206 | + "\n", |
| 207 | + " PTRATIO B LSTAT target \n", |
| 208 | + "0 15.3 396.90 4.98 24.0 \n", |
| 209 | + "1 17.8 396.90 9.14 21.6 \n", |
| 210 | + "2 17.8 392.83 4.03 34.7 \n", |
| 211 | + "3 18.7 394.63 2.94 33.4 \n", |
| 212 | + "4 18.7 396.90 5.33 36.2 " |
| 213 | + ] |
| 214 | + }, |
| 215 | + "execution_count": 3, |
| 216 | + "metadata": {}, |
| 217 | + "output_type": "execute_result" |
| 218 | + } |
| 219 | + ], |
| 220 | + "source": [ |
| 221 | + "df.head()" |
| 222 | + ] |
| 223 | + }, |
| 224 | + { |
| 225 | + "cell_type": "code", |
| 226 | + "execution_count": 4, |
| 227 | + "metadata": {}, |
| 228 | + "outputs": [], |
| 229 | + "source": [ |
| 230 | + "X = df.loc[:, ['RM', 'LSTAT', 'PTRATIO']].values" |
| 231 | + ] |
| 232 | + }, |
| 233 | + { |
| 234 | + "cell_type": "code", |
| 235 | + "execution_count": 5, |
| 236 | + "metadata": {}, |
| 237 | + "outputs": [], |
| 238 | + "source": [ |
| 239 | + "y = df.loc[:, 'target'].values" |
| 240 | + ] |
| 241 | + }, |
| 242 | + { |
| 243 | + "cell_type": "markdown", |
| 244 | + "metadata": {}, |
| 245 | + "source": [ |
| 246 | + "## Train Test Split " |
| 247 | + ] |
| 248 | + }, |
| 249 | + { |
| 250 | + "cell_type": "markdown", |
| 251 | + "metadata": {}, |
| 252 | + "source": [ |
| 253 | + "\n", |
| 254 | + "The colors in the image show which variable (`X_train`, `X_test`, `y_train`, or `y_test`) each part of the dataframe `df` goes to for one particular train test split (not necessarily the exact split produced by the code below)." |
| 255 | + ] |
| 256 | + }, |
| 257 | + { |
| 258 | + "cell_type": "code", |
| 259 | + "execution_count": 6, |
| 260 | + "metadata": {}, |
| 261 | + "outputs": [], |
| 262 | + "source": [ |
| 263 | + "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=3)  # default split is 75% train / 25% test; random_state makes the shuffle reproducible" |
| 264 | + ] |
| 265 | + }, |
| 266 | + { |
| 267 | + "cell_type": "markdown", |
| 268 | + "metadata": {}, |
| 269 | + "source": [ |
| 270 | + "## Linear Regression Model" |
| 271 | + ] |
| 272 | + }, |
| 273 | + { |
| 274 | + "cell_type": "code", |
| 275 | + "execution_count": 7, |
| 276 | + "metadata": {}, |
| 277 | + "outputs": [ |
| 278 | + { |
| 279 | + "data": { |
| 280 | + "text/html": [ |
| 281 | + "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LinearRegression()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LinearRegression</label><div class=\"sk-toggleable__content\"><pre>LinearRegression()</pre></div></div></div></div></div>" |
| 282 | + ], |
| 283 | + "text/plain": [ |
| 284 | + "LinearRegression()" |
| 285 | + ] |
| 286 | + }, |
| 287 | + "execution_count": 7, |
| 288 | + "metadata": {}, |
| 289 | + "output_type": "execute_result" |
| 290 | + } |
| 291 | + ], |
| 292 | + "source": [ |
| 293 | + "# Make a linear regression instance\n", |
| 294 | + "reg = LinearRegression(fit_intercept=True)\n", |
| 295 | + "\n", |
| 296 | + "# Train the model on the training set.\n", |
| 297 | + "reg.fit(X_train, y_train)" |
| 298 | + ] |
| 299 | + }, |
| 300 | + { |
| 301 | + "cell_type": "markdown", |
| 302 | + "metadata": {}, |
| 303 | + "source": [ |
| 304 | + "## Measuring Model Performance\n", |
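| 305 | + "By measuring model performance on the test set, you can estimate how well your model is likely to perform on new (out-of-sample) data. For a regression model such as `LinearRegression`, `score` returns the coefficient of determination (R²)." |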
| 306 | + ] |
| 307 | + }, |
| 308 | + { |
| 309 | + "cell_type": "code", |
| 310 | + "execution_count": 8, |
| 311 | + "metadata": {}, |
| 312 | + "outputs": [ |
| 313 | + { |
| 314 | + "name": "stdout", |
| 315 | + "output_type": "stream", |
| 316 | + "text": [ |
| 317 | + "0.7155620757319656\n" |
| 318 | + ] |
| 319 | + } |
| 320 | + ], |
| 321 | + "source": [ |
| 322 | + "# Test the model on the testing set and evaluate the performance\n", |
| 323 | + "score = reg.score(X_test, y_test)\n", |
| 324 | + "print(score)" |
| 325 | + ] |
| 326 | + }, |
| 327 | + { |
| 328 | + "cell_type": "markdown", |
| 329 | + "metadata": {}, |
| 330 | + "source": [ |
| 331 | + "So that's it: `train_test_split` helps you simulate how well a model would perform on new data." |
| 332 | + ] |
| 333 | + } |
| 334 | + ], |
| 335 | + "metadata": { |
| 336 | + "anaconda-cloud": {}, |
| 337 | + "kernelspec": { |
| 338 | + "display_name": "Python 3 (ipykernel)", |
| 339 | + "language": "python", |
| 340 | + "name": "python3" |
| 341 | + }, |
| 342 | + "language_info": { |
| 343 | + "codemirror_mode": { |
| 344 | + "name": "ipython", |
| 345 | + "version": 3 |
| 346 | + }, |
| 347 | + "file_extension": ".py", |
| 348 | + "mimetype": "text/x-python", |
| 349 | + "name": "python", |
| 350 | + "nbconvert_exporter": "python", |
| 351 | + "pygments_lexer": "ipython3", |
| 352 | + "version": "3.9.7" |
| 353 | + } |
| 354 | + }, |
| 355 | + "nbformat": 4, |
| 356 | + "nbformat_minor": 2 |
| 357 | +} |
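
The three steps listed in the notebook (split, train, test) rest on a shuffle-and-slice of the rows. As a rough illustration of that idea, here is a minimal standard-library sketch. `simple_train_test_split` is a hypothetical helper written for this note, not scikit-learn's implementation; the 25% default mirrors the proportion mentioned in the notebook, and rounding the test count up is an assumption of this sketch.

```python
import math
import random


def simple_train_test_split(X, y, test_size=0.25, seed=None):
    """Shuffle row indices, then carve off a test portion (hypothetical helper)."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    n_samples = len(X)
    # Assumption of this sketch: round the test count up.
    n_test = math.ceil(n_samples * test_size)
    indices = list(range(n_samples))
    # A seeded generator makes the shuffle repeatable, like random_state.
    random.Random(seed).shuffle(indices)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    # Use the same indices for X and y so rows and targets stay aligned.
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test


# Toy data: 8 rows with one feature each; the target is twice the feature.
X = [[i] for i in range(8)]
y = [2 * i for i in range(8)]
X_train, X_test, y_train, y_test = simple_train_test_split(X, y, seed=3)
print(len(X_train), len(X_test))
```

Indexing `X` and `y` with the same shuffled indices is what keeps every feature row paired with its target after the split. In practice the notebook's `train_test_split` call does this work for you.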