99
1010Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_siamese_graph.py
1111
12- Get to 99.5% test accuracy after 20 epochs.
12+ Gets to 99.5% test accuracy after 20 epochs.
13133 seconds per epoch on a Titan X GPU
1414'''
1515from __future__ import absolute_import
2020import random
2121from keras .datasets import mnist
2222from keras .models import Sequential , Graph
23- from keras .layers .core import *
23+ from keras .layers .core import Dense , Dropout , Lambda
2424from keras .optimizers import SGD , RMSprop
2525from keras import backend as K
2626
2727
2828def euclidean_distance (inputs ):
29- assert len (inputs ) == 2 , \
30- 'Euclidean distance needs 2 inputs, %d given' % len (inputs )
29+ assert len (inputs ) == 2 , ( 'Euclidean distance needs '
30+ ' 2 inputs, %d given' % len (inputs ) )
3131 u , v = inputs .values ()
32- return K .sqrt ((K .square (u - v )). sum ( axis = 1 , keepdims = True ))
32+ return K .sqrt (K . sum (K .square (u - v ), axis = 1 , keepdims = True ))
3333
3434
3535def contrastive_loss (y , d ):
36- """ Contrastive loss from Hadsell-et-al.'06
37- http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
38- """
36+ ''' Contrastive loss from Hadsell-et-al.'06
37+ http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
38+ '''
3939 margin = 1
4040 return K .mean (y * K .square (d ) + (1 - y ) * K .square (K .maximum (margin - d , 0 )))
4141
4242
4343def create_pairs (x , digit_indices ):
44- """ Positive and negative pair creation.
45- Alternates between positive and negative pairs.
46- """
44+ ''' Positive and negative pair creation.
45+ Alternates between positive and negative pairs.
46+ '''
4747 pairs = []
4848 labels = []
4949 n = min ([len (digit_indices [d ]) for d in range (10 )]) - 1
@@ -59,11 +59,11 @@ def create_pairs(x, digit_indices):
5959 return np .array (pairs ), np .array (labels )
6060
6161
62- def create_base_network (in_dim ):
63- """ Base network to be shared (eq. to feature extraction).
64- """
62+ def create_base_network (input_dim ):
63+ ''' Base network to be shared (eq. to feature extraction).
64+ '''
6565 seq = Sequential ()
66- seq .add (Dense (128 , input_shape = (in_dim ,), activation = 'relu' ))
66+ seq .add (Dense (128 , input_shape = (input_dim ,), activation = 'relu' ))
6767 seq .add (Dropout (0.1 ))
6868 seq .add (Dense (128 , activation = 'relu' ))
6969 seq .add (Dropout (0.1 ))
@@ -72,8 +72,8 @@ def create_base_network(in_dim):
7272
7373
7474def compute_accuracy (predictions , labels ):
75- """ Compute classification accuracy with a fixed threshold on distances.
76- """
75+ ''' Compute classification accuracy with a fixed threshold on distances.
76+ '''
7777 return labels [predictions .ravel () < 0.5 ].mean ()
7878
7979
@@ -85,7 +85,7 @@ def compute_accuracy(predictions, labels):
8585X_test = X_test .astype ('float32' )
8686X_train /= 255
8787X_test /= 255
88- in_dim = 784
88+ input_dim = 784
8989nb_epoch = 20
9090
9191# create training+test positive and negative pairs
@@ -96,11 +96,11 @@ def compute_accuracy(predictions, labels):
9696te_pairs , te_y = create_pairs (X_test , digit_indices )
9797
9898# network definition
99- base_network = create_base_network (in_dim )
99+ base_network = create_base_network (input_dim )
100100
101101g = Graph ()
102- g .add_input (name = 'input_a' , input_shape = (in_dim ,))
103- g .add_input (name = 'input_b' , input_shape = (in_dim ,))
102+ g .add_input (name = 'input_a' , input_shape = (input_dim ,))
103+ g .add_input (name = 'input_b' , input_shape = (input_dim ,))
104104g .add_shared_node (base_network , name = 'shared' , inputs = ['input_a' , 'input_b' ],
105105 merge_mode = 'join' )
106106g .add_node (Lambda (euclidean_distance ), name = 'd' , input = 'shared' )
@@ -111,7 +111,8 @@ def compute_accuracy(predictions, labels):
111111g .compile (loss = {'output' : contrastive_loss }, optimizer = rms )
112112g .fit ({'input_a' : tr_pairs [:, 0 ], 'input_b' : tr_pairs [:, 1 ], 'output' : tr_y },
113113 validation_data = {'input_a' : te_pairs [:, 0 ], 'input_b' : te_pairs [:, 1 ], 'output' : te_y },
114- batch_size = 128 , nb_epoch = nb_epoch )
114+ batch_size = 128 ,
115+ nb_epoch = nb_epoch )
115116
116117# compute final accuracy on training and test sets
117118pred = g .predict ({'input_a' : tr_pairs [:, 0 ], 'input_b' : tr_pairs [:, 1 ]})['output' ]
0 commit comments