|
59 | 59 | "from glob import glob\n", |
60 | 60 | "import boto3\n", |
61 | 61 | "\n", |
62 | | - "# download files from s3 \"directory\"\n", |
| 62 | + "# download train files from s3 \"directory\"\n", |
63 | 63 | "s3_bucket = 'jakechenawspublic'\n", |
64 | 64 | "s3_prefix = 'sample_data/mnist/train'\n", |
65 | 65 | "\n", |
|
70 | 70 | "s3_keys = [r['Key'] for r in response['Contents']]\n", |
71 | 71 | "for key_name in s3_keys:\n", |
72 | 72 | " fname = key_name.replace(s3_prefix, '').lstrip('/') # create local file name by removing prefix from key name\n", |
73 | | - " s3.download_file(bucket_name, key_name, fname) # download file into training dir\n", |
|     | 73  | +    "    s3.download_file(s3_bucket, key_name, fname) # download train file\n", |
74 | 74 | "\n", |
75 | | - "# load downloaded files into np.array\n", |
76 | | - "fnames = glob('*.csv')\n", |
| 75 | + "# load downloaded files (in this case, file) into np.array\n", |
| 76 | + "fnames = glob('*train.csv')\n", |
77 | 77 | "arrays = np.array([np.loadtxt(f, delimiter=',') for f in fnames])\n", |
78 | 78 | "\n", |
79 | 79 | "# join files into one array with shape [records, 785]\n", |
|
170 | 170 | "cell_type": "markdown", |
171 | 171 | "metadata": {}, |
172 | 172 | "source": [ |
173 | | - "#### Test Network" |
| 173 | + "### Phase 5: Model Evaluation" |
174 | 174 | ] |
175 | 175 | }, |
176 | 176 | { |
|
187 | 187 | } |
188 | 188 | ], |
189 | 189 | "source": [ |
190 | | - "bucket_name = 'jakechenawspublic'\n", |
191 | | - "key_name = 'sample_data/mnist/test/mnist_test.csv'\n", |
| 190 | + "batch_size = 5\n", |
| 191 | + "\n", |
| 192 | + "# download test files from s3 \"directory\"\n", |
| 193 | + "s3_bucket = 'jakechenawspublic'\n", |
| 194 | + "s3_prefix = 'sample_data/mnist/test'\n", |
192 | 195 | "\n", |
193 | 196 | "s3 = boto3.client('s3')\n", |
194 | | - "s3.download_file(bucket_name, key_name, 'mnist_test.csv')\n", |
195 | 197 | "\n", |
196 | | - "mnist_test = np.loadtxt('mnist_test.csv', delimiter=',')\n", |
| 198 | + "response = s3.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)\n", |
| 199 | + "\n", |
| 200 | + "s3_keys = [r['Key'] for r in response['Contents']]\n", |
| 201 | + "for key_name in s3_keys:\n", |
| 202 | + " fname = key_name.replace(s3_prefix, '').lstrip('/') # create local file name by removing prefix from key name\n", |
|     | 203 | +    "    s3.download_file(s3_bucket, key_name, fname) # download test file\n", |
| 204 | + "\n", |
| 205 | + "fnames = glob('*_test.csv')\n", |
| 206 | + "arrays = np.array([np.loadtxt(f, delimiter=',') for f in fnames])\n", |
| 207 | + "\n", |
| 208 | + "mnist_test = arrays.reshape(-1, 785)\n", |
197 | 209 | "X_test = mnist_test.T[1:].T.reshape(-1,1,28,28)\n", |
198 | 210 | "y_test = mnist_test.T[:1].T.reshape(-1)\n", |
199 | 211 | "\n", |
|
203 | 215 | "acc = mx.metric.Accuracy()\n", |
204 | 216 | "lenet_model.score(test_iter, acc)\n", |
205 | 217 | "print(acc)\n", |
206 | | - "assert acc.get()[1] > 0.98" |
| 218 | + "assert acc.get()[1] > 0.98\n", |
| 219 | + "\n", |
| 220 | + "# predict function for lenet\n", |
| 221 | + "lenet_model.predict(test_iter)" |
207 | 222 | ] |
208 | 223 | }, |
209 | 224 | { |
|
0 commit comments