|
1 | 1 | # |
2 | 2 | # |
3 | | -# Templates for this file can be found here: |
| 3 | +# Templates for required and optional functions for this file can be found here: |
4 | 4 | # http://docs.aws.amazon.com/sagemaker/latest/dg/mxnet-training-inference-code-template.html |
5 | 5 | # |
6 | 6 | # More information can be found here: |
|
17 | 17 | # ---------------------------------------------------------------------------- # |
18 | 18 | # Training functions # |
19 | 19 | # ---------------------------------------------------------------------------- # |
20 | | -def train(channel_input_dirs, **kwargs): |
| 20 | +def train_working(channel_input_dirs, **kwargs): |
21 | 21 | import mxnet as mx |
22 | 22 | mnist = mx.test_utils.get_mnist() |
23 | 23 |
|
@@ -57,16 +57,16 @@ def train(channel_input_dirs, **kwargs): |
57 | 57 | return lenet_model |
58 | 58 |
|
59 | 59 |
|
60 | | -def train_old( |
61 | | -# hyperparameters, |
62 | | -# input_data_config, |
| 60 | +def train( |
| 61 | +# hyperparameters, # not used in tutorial |
| 62 | +# input_data_config, # not used in tutorial |
63 | 63 | channel_input_dirs, |
64 | | -# output_data_dir, |
65 | | -# model_dir, |
66 | | -# num_gpus, |
67 | | -# num_cpus, |
68 | | -# hosts, |
69 | | -# current_host, |
| 64 | +# output_data_dir, # not used in tutorial |
| 65 | +# model_dir, # not used in tutorial |
| 66 | +# num_gpus, # not used in tutorial |
| 67 | +# num_cpus, # not used in tutorial |
| 68 | +# hosts, # not used in tutorial |
| 69 | +# current_host, # not used in tutorial |
70 | 70 | **kwargs): |
71 | 71 |
|
72 | 72 | """ |
@@ -133,8 +133,8 @@ def train_old( |
133 | 133 | y_arrays = np.loadtxt(f_name, delimiter=',') |
134 | 134 |
|
135 | 135 | # reshape into requisite shape for NN |
136 | | - X_train = mx.nd.array(X_arrays.reshape(-1, 1, 28, 28)) |
137 | | - y_train = mx.nd.array(y_arrays.reshape(-1)).one_hot(10) |
| 136 | + X_train = X_arrays.reshape(-1, 1, 28, 28) |
| 137 | + y_train = y_arrays.reshape(-1) |
138 | 138 | logging.info('X_train.shape: {}'.format(X_train.shape)) |
139 | 139 | logging.info('y_train.shape: {}'.format(y_train.shape)) |
140 | 140 |
|
|
0 commit comments