|
| 1 | +'''Train a Siamese MLP on pairs of digits from the MNIST dataset. |
| 2 | +
|
| 3 | +It follows Hadsell-et-al.'06 [1] by computing the Euclidean distance on the |
| 4 | +output of the shared network and by optimizing the contrastive loss (see paper |
| 5 | +for mode details). |
| 6 | +
|
| 7 | +[1] "Dimensionality Reduction by Learning an Invariant Mapping" |
| 8 | + http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf |
| 9 | +
|
| 10 | +Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_siamese_graph.py |
| 11 | +
|
| 12 | +Get to 99.5% test accuracy after 20 epochs. |
| 13 | +3 seconds per epoch on a Titan X GPU |
| 14 | +''' |
| 15 | +from __future__ import absolute_import |
| 16 | +from __future__ import print_function |
| 17 | +import numpy as np |
| 18 | +np.random.seed(1337) # for reproducibility |
| 19 | + |
| 20 | +import random |
| 21 | +from keras.datasets import mnist |
| 22 | +from keras.models import Sequential, Graph |
| 23 | +from keras.layers.core import * |
| 24 | +from keras.optimizers import SGD, RMSprop |
| 25 | +from keras import backend as K |
| 26 | + |
| 27 | + |
| 28 | +def euclidean_distance(inputs): |
| 29 | + assert len(inputs) == 2, \ |
| 30 | + 'Euclidean distance needs 2 inputs, %d given' % len(inputs) |
| 31 | + u, v = inputs.values() |
| 32 | + return K.sqrt((K.square(u - v)).sum(axis=1, keepdims=True)) |
| 33 | + |
| 34 | + |
| 35 | +def contrastive_loss(y, d): |
| 36 | + """ Contrastive loss from Hadsell-et-al.'06 |
| 37 | + http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf |
| 38 | + """ |
| 39 | + margin = 1 |
| 40 | + return K.mean(y * K.square(d) + (1 - y) * K.square(K.maximum(margin - d, 0))) |
| 41 | + |
| 42 | + |
| 43 | +def create_pairs(x, digit_indices): |
| 44 | + """ Positive and negative pair creation. |
| 45 | + Alternates between positive and negative pairs. |
| 46 | + """ |
| 47 | + pairs = [] |
| 48 | + labels = [] |
| 49 | + n = min([len(digit_indices[d]) for d in range(10)]) - 1 |
| 50 | + for d in range(10): |
| 51 | + for i in range(n): |
| 52 | + z1, z2 = digit_indices[d][i], digit_indices[d][i+1] |
| 53 | + pairs += [[x[z1], x[z2]]] |
| 54 | + inc = random.randrange(1, 10) |
| 55 | + dn = (d + inc) % 10 |
| 56 | + z1, z2 = digit_indices[d][i], digit_indices[dn][i] |
| 57 | + pairs += [[x[z1], x[z2]]] |
| 58 | + labels += [1, 0] |
| 59 | + return np.array(pairs), np.array(labels) |
| 60 | + |
| 61 | + |
| 62 | +def create_base_network(in_dim): |
| 63 | + """ Base network to be shared (eq. to feature extraction). |
| 64 | + """ |
| 65 | + seq = Sequential() |
| 66 | + seq.add(Dense(128, input_shape=(in_dim,), activation='relu')) |
| 67 | + seq.add(Dropout(0.1)) |
| 68 | + seq.add(Dense(128, activation='relu')) |
| 69 | + seq.add(Dropout(0.1)) |
| 70 | + seq.add(Dense(128, activation='relu')) |
| 71 | + return seq |
| 72 | + |
| 73 | + |
| 74 | +def compute_accuracy(predictions, labels): |
| 75 | + """ Compute classification accuracy with a fixed threshold on distances. |
| 76 | + """ |
| 77 | + return labels[predictions.ravel() < 0.5].mean() |
| 78 | + |
| 79 | + |
| 80 | +# the data, shuffled and split between tran and test sets |
| 81 | +(X_train, y_train), (X_test, y_test) = mnist.load_data() |
| 82 | +X_train = X_train.reshape(60000, 784) |
| 83 | +X_test = X_test.reshape(10000, 784) |
| 84 | +X_train = X_train.astype('float32') |
| 85 | +X_test = X_test.astype('float32') |
| 86 | +X_train /= 255 |
| 87 | +X_test /= 255 |
| 88 | +in_dim = 784 |
| 89 | +nb_epoch = 20 |
| 90 | + |
| 91 | +# create training+test positive and negative pairs |
| 92 | +digit_indices = [np.where(y_train == i)[0] for i in range(10)] |
| 93 | +tr_pairs, tr_y = create_pairs(X_train, digit_indices) |
| 94 | + |
| 95 | +digit_indices = [np.where(y_test == i)[0] for i in range(10)] |
| 96 | +te_pairs, te_y = create_pairs(X_test, digit_indices) |
| 97 | + |
| 98 | +# network definition |
| 99 | +base_network = create_base_network(in_dim) |
| 100 | + |
| 101 | +g = Graph() |
| 102 | +g.add_input(name='input_a', input_shape=(in_dim,)) |
| 103 | +g.add_input(name='input_b', input_shape=(in_dim,)) |
| 104 | +g.add_shared_node(base_network, name='shared', inputs=['input_a', 'input_b'], |
| 105 | + merge_mode='join') |
| 106 | +g.add_node(Lambda(euclidean_distance), name='d', input='shared') |
| 107 | +g.add_output(name='output', input='d') |
| 108 | + |
| 109 | +# train |
| 110 | +rms = RMSprop() |
| 111 | +g.compile(loss={'output': contrastive_loss}, optimizer=rms) |
| 112 | +g.fit({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1], 'output': tr_y}, |
| 113 | + validation_data={'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1], 'output': te_y}, |
| 114 | + batch_size=128, nb_epoch=nb_epoch) |
| 115 | + |
| 116 | +# compute final accuracy on training and test sets |
| 117 | +pred = g.predict({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1]})['output'] |
| 118 | +tr_acc = compute_accuracy(pred, tr_y) |
| 119 | +pred = g.predict({'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1]})['output'] |
| 120 | +te_acc = compute_accuracy(pred, te_y) |
| 121 | + |
| 122 | +print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc)) |
| 123 | +print('* Accuracy on test set: %0.2f%%' % (100 * te_acc)) |
0 commit comments