Commit 652f2eb

author
Mikael Rousson
committed
add siamese example
use graph model that takes pairs of digits as input
1 parent 359f91f commit 652f2eb

File tree: 1 file changed (+123, -0)

examples/mnist_siamese_graph.py

Lines changed: 123 additions & 0 deletions
'''Train a Siamese MLP on pairs of digits from the MNIST dataset.

It follows Hadsell-et-al.'06 [1] by computing the Euclidean distance on the
output of the shared network and by optimizing the contrastive loss (see the
paper for more details).

[1] "Dimensionality Reduction by Learning an Invariant Mapping"
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_siamese_graph.py

Gets to 99.5% test accuracy after 20 epochs.
3 seconds per epoch on a Titan X GPU.
'''
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
np.random.seed(1337)  # for reproducibility

import random
from keras.datasets import mnist
from keras.models import Sequential, Graph
from keras.layers.core import Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K


def euclidean_distance(inputs):
    assert len(inputs) == 2, \
        'Euclidean distance needs 2 inputs, %d given' % len(inputs)
    # with merge_mode='join', the Lambda receives a dict holding the
    # shared network's output for each of the two inputs
    u, v = inputs.values()
    return K.sqrt(K.sum(K.square(u - v), axis=1, keepdims=True))

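# Quick numpy sanity check of the distance above (illustrative only):
#   u, v = np.array([[0., 0.]]), np.array([[3., 4.]])
#   np.sqrt(np.sum(np.square(u - v), axis=1, keepdims=True))  # -> [[5.]]
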
def contrastive_loss(y, d):
    """ Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """
    margin = 1
    # similar pairs (y=1) are pulled together quadratically; dissimilar
    # pairs (y=0) are pushed apart until they clear the margin
    return K.mean(y * K.square(d) + (1 - y) * K.square(K.maximum(margin - d, 0)))

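# Quick numpy check of the loss above with margin = 1 (illustrative only):
#   y = np.array([1., 1., 0., 0.])
#   d = np.array([0., 2., 0., 2.])
#   np.mean(y * d ** 2 + (1 - y) * np.maximum(1 - d, 0) ** 2)
#   # -> (0 + 4 + 1 + 0) / 4 = 1.25
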
def create_pairs(x, digit_indices):
    """ Positive and negative pair creation.
    Alternates between positive and negative pairs.
    """
    pairs = []
    labels = []
    # cap at the size of the smallest digit class so every class
    # contributes the same number of pairs
    n = min([len(digit_indices[d]) for d in range(10)]) - 1
    for d in range(10):
        for i in range(n):
            # positive pair: two samples of the same digit
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs += [[x[z1], x[z2]]]
            # negative pair: a sample of digit d and one of a random other digit
            inc = random.randrange(1, 10)
            dn = (d + inc) % 10
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs += [[x[z1], x[z2]]]
            labels += [1, 0]
    return np.array(pairs), np.array(labels)

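# Shape note (illustrative; assumes the standard MNIST split, whose
# smallest training class has 5421 samples): n = 5420, so the result is
# 2 * 10 * 5420 = 108400 pairs of shape (108400, 2, 784), with labels
# alternating [1, 0, 1, 0, ...].
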
def create_base_network(in_dim):
    """ Base network to be shared (eq. to feature extraction).
    """
    seq = Sequential()
    seq.add(Dense(128, input_shape=(in_dim,), activation='relu'))
    seq.add(Dropout(0.1))
    seq.add(Dense(128, activation='relu'))
    seq.add(Dropout(0.1))
    seq.add(Dense(128, activation='relu'))
    return seq


def compute_accuracy(predictions, labels):
    """ Compute classification accuracy with a fixed threshold on distances.
    """
    # a pair is predicted positive when its distance falls below the
    # threshold; compare that prediction against the true label
    return np.mean((predictions.ravel() < 0.5) == labels)

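# Quick numpy check of the threshold rule above (illustrative only):
#   predictions = np.array([[0.2], [0.8], [0.4]])
#   labels = np.array([1, 0, 0])
#   np.mean((predictions.ravel() < 0.5) == labels)  # -> 2/3
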
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
in_dim = 784
nb_epoch = 20

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(10)]
tr_pairs, tr_y = create_pairs(X_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(10)]
te_pairs, te_y = create_pairs(X_test, digit_indices)

# network definition
base_network = create_base_network(in_dim)

g = Graph()
g.add_input(name='input_a', input_shape=(in_dim,))
g.add_input(name='input_b', input_shape=(in_dim,))
# the same base network processes both inputs, so the two branches share
# weights; merge_mode='join' forwards both outputs as a dict
g.add_shared_node(base_network, name='shared', inputs=['input_a', 'input_b'],
                  merge_mode='join')
g.add_node(Lambda(euclidean_distance), name='d', input='shared')
g.add_output(name='output', input='d')

# train
# tr_pairs has shape (n_pairs, 2, 784), so tr_pairs[:, 0] and
# tr_pairs[:, 1] are the first and second members of every pair
rms = RMSprop()
g.compile(loss={'output': contrastive_loss}, optimizer=rms)
g.fit({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1], 'output': tr_y},
      validation_data={'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1], 'output': te_y},
      batch_size=128, nb_epoch=nb_epoch)

# compute final accuracy on training and test sets
pred = g.predict({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1]})['output']
tr_acc = compute_accuracy(pred, tr_y)
pred = g.predict({'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1]})['output']
te_acc = compute_accuracy(pred, te_y)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
