| 
 | 1 | +# Copyright (c) Microsoft. All rights reserved.  | 
 | 2 | +# Licensed under the MIT license.  | 
 | 3 | +# Adapted from:  | 
 | 4 | +# https://github.com/Microsoft/CNTK/blob/master/Examples/Image/Classification/ConvNet/Python/ConvNet_MNIST.py  | 
 | 5 | +# ====================================================================  | 
 | 6 | +"""Train a CNN model on the MNIST dataset via distributed training."""  | 
 | 7 | + | 
 | 8 | +from __future__ import print_function  | 
 | 9 | +import numpy as np  | 
 | 10 | +import os  | 
 | 11 | +import cntk as C  | 
 | 12 | +import argparse  | 
 | 13 | +from cntk.train.training_session import CheckpointConfig, TestConfig  | 
 | 14 | + | 
 | 15 | + | 
 | 16 | +def create_reader(path, is_training, input_dim, label_dim, total_number_of_samples):  | 
 | 17 | +    """Define the reader for both training and evaluation action."""  | 
 | 18 | +    return C.io.MinibatchSource(C.io.CTFDeserializer(path, C.io.StreamDefs(  | 
 | 19 | +        features=C.io.StreamDef(field='features', shape=input_dim),  | 
 | 20 | +        labels=C.io.StreamDef(field='labels', shape=label_dim)  | 
 | 21 | +    )), randomize=is_training, max_samples=total_number_of_samples)  | 
 | 22 | + | 
 | 23 | + | 
 | 24 | +def convnet_mnist(max_epochs, output_dir, data_dir, debug_output=False, epoch_size=60000, minibatch_size=64):  | 
 | 25 | +    """Creates and trains a feedforward classification model for MNIST images."""  | 
 | 26 | +    image_height = 28  | 
 | 27 | +    image_width = 28  | 
 | 28 | +    num_channels = 1  | 
 | 29 | +    input_dim = image_height * image_width * num_channels  | 
 | 30 | +    num_output_classes = 10  | 
 | 31 | + | 
 | 32 | +    # Input variables denoting the features and label data  | 
 | 33 | +    input_var = C.ops.input_variable((num_channels, image_height, image_width), np.float32)  | 
 | 34 | +    label_var = C.ops.input_variable(num_output_classes, np.float32)  | 
 | 35 | + | 
 | 36 | +    # Instantiate the feedforward classification model  | 
 | 37 | +    scaled_input = C.ops.element_times(C.ops.constant(0.00390625), input_var)  | 
 | 38 | + | 
 | 39 | +    with C.layers.default_options(activation=C.ops.relu, pad=False):  | 
 | 40 | +        conv1 = C.layers.Convolution2D((5, 5), 32, pad=True)(scaled_input)  | 
 | 41 | +        pool1 = C.layers.MaxPooling((3, 3), (2, 2))(conv1)  | 
 | 42 | +        conv2 = C.layers.Convolution2D((3, 3), 48)(pool1)  | 
 | 43 | +        pool2 = C.layers.MaxPooling((3, 3), (2, 2))(conv2)  | 
 | 44 | +        conv3 = C.layers.Convolution2D((3, 3), 64)(pool2)  | 
 | 45 | +        f4 = C.layers.Dense(96)(conv3)  | 
 | 46 | +        drop4 = C.layers.Dropout(0.5)(f4)  | 
 | 47 | +        z = C.layers.Dense(num_output_classes, activation=None)(drop4)  | 
 | 48 | + | 
 | 49 | +    ce = C.losses.cross_entropy_with_softmax(z, label_var)  | 
 | 50 | +    pe = C.metrics.classification_error(z, label_var)  | 
 | 51 | + | 
 | 52 | +    # Load train data  | 
 | 53 | +    reader_train = create_reader(os.path.join(data_dir, 'Train-28x28_cntk_text.txt'), True,  | 
 | 54 | +                                 input_dim, num_output_classes, max_epochs * epoch_size)  | 
 | 55 | +    # Load test data  | 
 | 56 | +    reader_test = create_reader(os.path.join(data_dir, 'Test-28x28_cntk_text.txt'), False,  | 
 | 57 | +                                input_dim, num_output_classes, C.io.FULL_DATA_SWEEP)  | 
 | 58 | + | 
 | 59 | +    # Set learning parameters  | 
 | 60 | +    lr_per_sample = [0.001] * 10 + [0.0005] * 10 + [0.0001]  | 
 | 61 | +    lr_schedule = C.learning_parameter_schedule_per_sample(lr_per_sample, epoch_size=epoch_size)  | 
 | 62 | +    mms = [0] * 5 + [0.9990239141819757]  | 
 | 63 | +    mm_schedule = C.learners.momentum_schedule_per_sample(mms, epoch_size=epoch_size)  | 
 | 64 | + | 
 | 65 | +    # Instantiate the trainer object to drive the model training  | 
 | 66 | +    local_learner = C.learners.momentum_sgd(z.parameters, lr_schedule, mm_schedule)  | 
 | 67 | +    progress_printer = C.logging.ProgressPrinter(  | 
 | 68 | +        tag='Training',  | 
 | 69 | +        rank=C.train.distributed.Communicator.rank(),  | 
 | 70 | +        num_epochs=max_epochs,  | 
 | 71 | +    )  | 
 | 72 | + | 
 | 73 | +    learner = C.train.distributed.data_parallel_distributed_learner(local_learner)  | 
 | 74 | +    trainer = C.Trainer(z, (ce, pe), learner, progress_printer)  | 
 | 75 | + | 
 | 76 | +    # define mapping from reader streams to network inputs  | 
 | 77 | +    input_map_train = {  | 
 | 78 | +        input_var: reader_train.streams.features,  | 
 | 79 | +        label_var: reader_train.streams.labels  | 
 | 80 | +    }  | 
 | 81 | + | 
 | 82 | +    input_map_test = {  | 
 | 83 | +        input_var: reader_test.streams.features,  | 
 | 84 | +        label_var: reader_test.streams.labels  | 
 | 85 | +    }  | 
 | 86 | + | 
 | 87 | +    C.logging.log_number_of_parameters(z)  | 
 | 88 | +    print()  | 
 | 89 | + | 
 | 90 | +    C.train.training_session(  | 
 | 91 | +        trainer=trainer,  | 
 | 92 | +        mb_source=reader_train,  | 
 | 93 | +        model_inputs_to_streams=input_map_train,  | 
 | 94 | +        mb_size=minibatch_size,  | 
 | 95 | +        progress_frequency=epoch_size,  | 
 | 96 | +        checkpoint_config=CheckpointConfig(frequency=epoch_size,  | 
 | 97 | +                                           filename=os.path.join(output_dir, "ConvNet_MNIST")),  | 
 | 98 | +        test_config=TestConfig(reader_test, minibatch_size=minibatch_size,  | 
 | 99 | +                               model_inputs_to_streams=input_map_test)  | 
 | 100 | +    ).train()  | 
 | 101 | + | 
 | 102 | +    return  | 
 | 103 | + | 
 | 104 | + | 
 | 105 | +if __name__ == '__main__':  | 
 | 106 | +    parser = argparse.ArgumentParser()  | 
 | 107 | +    parser.add_argument('--num_epochs', help='Total number of epochs to train', type=int, default='40')  | 
 | 108 | +    parser.add_argument('--output_dir', help='Output directory', required=False, default='outputs')  | 
 | 109 | +    parser.add_argument('--data_dir', help='Directory with training data')  | 
 | 110 | +    args = parser.parse_args()  | 
 | 111 | + | 
 | 112 | +    os.makedirs(args.output_dir, exist_ok=True)  | 
 | 113 | + | 
 | 114 | +    convnet_mnist(args.num_epochs, args.output_dir, args.data_dir)  | 
 | 115 | + | 
 | 116 | +    # Must call MPI finalize when process exit without exceptions  | 
 | 117 | +    C.train.distributed.Communicator.finalize()  | 
0 commit comments