Skip to content

Commit 57af5c4

Browse files
author
EC2 Default User
committed
Custom train() works, trains, and deploys.
1 parent 9950bc8 commit 57af5c4

File tree

2 files changed

+135
-64
lines changed

2 files changed

+135
-64
lines changed

sagemaker/part2_sm_mnist.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#
22
#
3-
# Templates for this file can be found here:
3+
# Templates for required and optional functions for this file can be found here:
44
# http://docs.aws.amazon.com/sagemaker/latest/dg/mxnet-training-inference-code-template.html
55
#
66
# More information can be found here:
@@ -17,7 +17,7 @@
1717
# ---------------------------------------------------------------------------- #
1818
# Training functions #
1919
# ---------------------------------------------------------------------------- #
20-
def train(channel_input_dirs, **kwargs):
20+
def train_working(channel_input_dirs, **kwargs):
2121
import mxnet as mx
2222
mnist = mx.test_utils.get_mnist()
2323

@@ -57,16 +57,16 @@ def train(channel_input_dirs, **kwargs):
5757
return lenet_model
5858

5959

60-
def train_old(
61-
# hyperparameters,
62-
# input_data_config,
60+
def train(
61+
# hyperparameters, # not used in tutorial
62+
# input_data_config, # not used in tutorial
6363
channel_input_dirs,
64-
# output_data_dir,
65-
# model_dir,
66-
# num_gpus,
67-
# num_cpus,
68-
# hosts,
69-
# current_host,
64+
# output_data_dir, # not used in tutorial
65+
# model_dir, # not used in tutorial
66+
# num_gpus, # not used in tutorial
67+
# num_cpus, # not used in tutorial
68+
# hosts, # not used in tutorial
69+
# current_host, # not used in tutorial
7070
**kwargs):
7171

7272
"""
@@ -133,8 +133,8 @@ def train_old(
133133
y_arrays = np.loadtxt(f_name, delimiter=',')
134134

135135
# reshape into requisite shape for NN
136-
X_train = mx.nd.array(X_arrays.reshape(-1, 1, 28, 28))
137-
y_train = mx.nd.array(y_arrays.reshape(-1)).one_hot(10)
136+
X_train = X_arrays.reshape(-1, 1, 28, 28)
137+
y_train = y_arrays.reshape(-1)
138138
logging.info('X_train.shape: {}'.format(X_train.shape))
139139
logging.info('y_train.shape: {}'.format(y_train.shape))
140140

0 commit comments

Comments
 (0)