Skip to content

Commit 78fba3c

Browse files
authored
Merge pull request Azure#113 from rastala/master
notebook patches
2 parents ce7ca94 + 01dc3d0 commit 78fba3c

File tree

4 files changed

+838
-331
lines changed

4 files changed

+838
-331
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
# Licensed under the MIT license.
3+
# Adapted from:
4+
# https://github.com/Microsoft/CNTK/blob/master/Examples/Image/Classification/ConvNet/Python/ConvNet_MNIST.py
5+
# ====================================================================
6+
"""Train a CNN model on the MNIST dataset via distributed training."""
7+
8+
from __future__ import print_function
9+
import numpy as np
10+
import os
11+
import cntk as C
12+
import argparse
13+
from cntk.train.training_session import CheckpointConfig, TestConfig
14+
15+
16+
def create_reader(path, is_training, input_dim, label_dim, total_number_of_samples):
    """Build a CNTK minibatch source over a CTF-formatted text file.

    The file at ``path`` is read through a ``CTFDeserializer`` exposing two
    streams: ``features`` (shape ``input_dim``) and ``labels`` (shape
    ``label_dim``). ``is_training`` is forwarded as the ``randomize`` flag and
    ``total_number_of_samples`` as ``max_samples`` of the resulting
    ``MinibatchSource``. The same reader is used for training and evaluation.
    """
    stream_defs = C.io.StreamDefs(
        features=C.io.StreamDef(field='features', shape=input_dim),
        labels=C.io.StreamDef(field='labels', shape=label_dim),
    )
    deserializer = C.io.CTFDeserializer(path, stream_defs)
    return C.io.MinibatchSource(
        deserializer,
        randomize=is_training,
        max_samples=total_number_of_samples,
    )
22+
23+
24+
def convnet_mnist(max_epochs, output_dir, data_dir, debug_output=False, epoch_size=60000, minibatch_size=64):
    """Build a small convolutional MNIST classifier and train it data-parallel.

    Args:
        max_epochs: Number of epochs; sizes the training reader
            (``max_epochs * epoch_size`` samples) and is reported to the
            progress printer.
        output_dir: Directory that receives the ``ConvNet_MNIST`` checkpoints.
        data_dir: Directory holding ``Train-28x28_cntk_text.txt`` and
            ``Test-28x28_cntk_text.txt``.
        debug_output: Accepted for interface compatibility; not used here.
        epoch_size: Samples per epoch; also the progress and checkpoint
            frequency and the unit of the learning-rate/momentum schedules.
        minibatch_size: Minibatch size for both training and testing.

    Returns:
        None. Training, checkpointing and testing happen as side effects of
        the CNTK training session.
    """
    height, width, channels = 28, 28, 1
    flat_dim = height * width * channels
    num_classes = 10

    # Network inputs: image tensor (channels, height, width) and class labels.
    features = C.ops.input_variable((channels, height, width), np.float32)
    targets = C.ops.input_variable(num_classes, np.float32)

    # 0.00390625 == 1/256; presumably rescales 8-bit pixel values -- confirm
    # against the data files.
    net = C.ops.element_times(C.ops.constant(0.00390625), features)

    # Layers are created in the same order as before; each one consumes the
    # output of the previous one.
    with C.layers.default_options(activation=C.ops.relu, pad=False):
        net = C.layers.Convolution2D((5, 5), 32, pad=True)(net)
        net = C.layers.MaxPooling((3, 3), (2, 2))(net)
        net = C.layers.Convolution2D((3, 3), 48)(net)
        net = C.layers.MaxPooling((3, 3), (2, 2))(net)
        net = C.layers.Convolution2D((3, 3), 64)(net)
        net = C.layers.Dense(96)(net)
        net = C.layers.Dropout(0.5)(net)
        z = C.layers.Dense(num_classes, activation=None)(net)

    loss = C.losses.cross_entropy_with_softmax(z, targets)
    error_rate = C.metrics.classification_error(z, targets)

    # Training reader is capped at max_epochs * epoch_size samples.
    train_source = create_reader(
        os.path.join(data_dir, 'Train-28x28_cntk_text.txt'),
        True,
        flat_dim,
        num_classes,
        max_epochs * epoch_size,
    )
    # Test reader makes one full sweep over the data.
    test_source = create_reader(
        os.path.join(data_dir, 'Test-28x28_cntk_text.txt'),
        False,
        flat_dim,
        num_classes,
        C.io.FULL_DATA_SWEEP,
    )

    # Per-sample learning-rate and momentum schedules, stepped per epoch.
    lr_schedule = C.learning_parameter_schedule_per_sample(
        [0.001] * 10 + [0.0005] * 10 + [0.0001],
        epoch_size=epoch_size,
    )
    mm_schedule = C.learners.momentum_schedule_per_sample(
        [0] * 5 + [0.9990239141819757],
        epoch_size=epoch_size,
    )

    # Local learner, then progress printer, then the distributed wrapper.
    base_learner = C.learners.momentum_sgd(z.parameters, lr_schedule, mm_schedule)
    printer = C.logging.ProgressPrinter(
        tag='Training',
        rank=C.train.distributed.Communicator.rank(),
        num_epochs=max_epochs,
    )
    distributed_learner = C.train.distributed.data_parallel_distributed_learner(base_learner)
    trainer = C.Trainer(z, (loss, error_rate), distributed_learner, printer)

    def stream_map(source):
        # Map the network inputs onto the given reader's streams.
        return {
            features: source.streams.features,
            targets: source.streams.labels,
        }

    train_map = stream_map(train_source)
    test_map = stream_map(test_source)

    C.logging.log_number_of_parameters(z)
    print()

    checkpointing = CheckpointConfig(
        frequency=epoch_size,
        filename=os.path.join(output_dir, "ConvNet_MNIST"),
    )
    testing = TestConfig(
        test_source,
        minibatch_size=minibatch_size,
        model_inputs_to_streams=test_map,
    )
    session = C.train.training_session(
        trainer=trainer,
        mb_source=train_source,
        model_inputs_to_streams=train_map,
        mb_size=minibatch_size,
        progress_frequency=epoch_size,
        checkpoint_config=checkpointing,
        test_config=testing,
    )
    session.train()
103+
104+
105+
if __name__ == '__main__':
    # Command-line entry point: parse arguments, make sure the output
    # directory exists, run training, then shut down the MPI communicator.
    parser = argparse.ArgumentParser()
    # Integer default so the declared type and the default agree.
    parser.add_argument('--num_epochs', help='Total number of epochs to train', type=int, default=40)
    parser.add_argument('--output_dir', help='Output directory', required=False, default='outputs')
    # Required: without it data_dir would be None and the os.path.join calls
    # in convnet_mnist would fail with an opaque TypeError.
    parser.add_argument('--data_dir', help='Directory with training data', required=True)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    convnet_mnist(args.num_epochs, args.output_dir, args.data_dir)

    # MPI finalize must be called when the process exits without exceptions.
    C.train.distributed.Communicator.finalize()

0 commit comments

Comments
 (0)