Skip to content

Commit c89f56d

Browse files
author
EC2 Default User
committed
wip
1 parent 9d86f32 commit c89f56d

File tree

3 files changed

+140
-38
lines changed

3 files changed

+140
-38
lines changed

sagemaker/part1_model_research.ipynb

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"from glob import glob\n",
6060
"import boto3\n",
6161
"\n",
62-
"# download files from s3 \"directory\"\n",
62+
"# download train files from s3 \"directory\"\n",
6363
"s3_bucket = 'jakechenawspublic'\n",
6464
"s3_prefix = 'sample_data/mnist/train'\n",
6565
"\n",
@@ -70,10 +70,10 @@
7070
"s3_keys = [r['Key'] for r in response['Contents']]\n",
7171
"for key_name in s3_keys:\n",
7272
" fname = key_name.replace(s3_prefix, '').lstrip('/') # create local file name by removing prefix from key name\n",
73-
" s3.download_file(bucket_name, key_name, fname) # download file into training dir\n",
73+
" s3.download_file(bucket_name, key_name, fname) # download train file\n",
7474
"\n",
75-
"# load downloaded files into np.array\n",
76-
"fnames = glob('*.csv')\n",
75+
"# load downloaded files (in this case, file) into np.array\n",
76+
"fnames = glob('*train.csv')\n",
7777
"arrays = np.array([np.loadtxt(f, delimiter=',') for f in fnames])\n",
7878
"\n",
7979
"# join files into one array with shape [records, 785]\n",
@@ -170,7 +170,7 @@
170170
"cell_type": "markdown",
171171
"metadata": {},
172172
"source": [
173-
"#### Test Network"
173+
"### Phase 5: Model Evaluation"
174174
]
175175
},
176176
{
@@ -187,13 +187,25 @@
187187
}
188188
],
189189
"source": [
190-
"bucket_name = 'jakechenawspublic'\n",
191-
"key_name = 'sample_data/mnist/test/mnist_test.csv'\n",
190+
"batch_size = 5\n",
191+
"\n",
192+
"# download test files from s3 \"directory\"\n",
193+
"s3_bucket = 'jakechenawspublic'\n",
194+
"s3_prefix = 'sample_data/mnist/test'\n",
192195
"\n",
193196
"s3 = boto3.client('s3')\n",
194-
"s3.download_file(bucket_name, key_name, 'mnist_test.csv')\n",
195197
"\n",
196-
"mnist_test = np.loadtxt('mnist_test.csv', delimiter=',')\n",
198+
"response = s3.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)\n",
199+
"\n",
200+
"s3_keys = [r['Key'] for r in response['Contents']]\n",
201+
"for key_name in s3_keys:\n",
202+
" fname = key_name.replace(s3_prefix, '').lstrip('/') # create local file name by removing prefix from key name\n",
203+
" s3.download_file(bucket_name, key_name, fname) # download test file\n",
204+
"\n",
205+
"fnames = glob('*_test.csv')\n",
206+
"arrays = np.array([np.loadtxt(f, delimiter=',') for f in fnames])\n",
207+
"\n",
208+
"mnist_test = arrays.reshape(-1, 785)\n",
197209
"X_test = mnist_test.T[1:].T.reshape(-1,1,28,28)\n",
198210
"y_test = mnist_test.T[:1].T.reshape(-1)\n",
199211
"\n",
@@ -203,7 +215,10 @@
203215
"acc = mx.metric.Accuracy()\n",
204216
"lenet_model.score(test_iter, acc)\n",
205217
"print(acc)\n",
206-
"assert acc.get()[1] > 0.98"
218+
"assert acc.get()[1] > 0.98\n",
219+
"\n",
220+
"# predict function for lenet\n",
221+
"lenet_model.predict(test_iter)"
207222
]
208223
},
209224
{

sagemaker/part2_sm_mnist.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,7 @@ def train(
122122
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
123123
# softmax loss
124124
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
125-
# define training batch size
126-
batch_size = 100
127125

128-
# create iterator around training and validation data
129-
train_iter = mx.io.NDArrayIter(mnist['train_data'][:ntrain], mnist['train_label'][:ntrain], batch_size, shuffle=True)
130-
val_iter = mx.io.NDArrayIter(mnist['train_data'][ntrain:], mnist['train_label'][ntrain:], batch_size)
131-
132126
"""
133127
End copy/paste from tutorial part 1
134128
"""
@@ -153,7 +147,7 @@ def train(
153147
kvstore=kvstore # added kvstore argument
154148
)
155149

156-
return lenet
150+
return lenet_model
157151

158152

159153
# ---------------------------------------------------------------------------- #

0 commit comments

Comments
 (0)