Skip to content

Commit 9999177

Browse files
committed
TF fixes and style fixes
1 parent 652f2eb commit 9999177

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

examples/mnist_siamese_graph.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
1010
Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_siamese_graph.py
1111
12-
Get to 99.5% test accuracy after 20 epochs.
12+
Gets to 99.5% test accuracy after 20 epochs.
1313
3 seconds per epoch on a Titan X GPU
1414
'''
1515
from __future__ import absolute_import
@@ -20,30 +20,30 @@
2020
import random
2121
from keras.datasets import mnist
2222
from keras.models import Sequential, Graph
23-
from keras.layers.core import *
23+
from keras.layers.core import Dense, Dropout, Lambda
2424
from keras.optimizers import SGD, RMSprop
2525
from keras import backend as K
2626

2727

2828
def euclidean_distance(inputs):
29-
assert len(inputs) == 2, \
30-
'Euclidean distance needs 2 inputs, %d given' % len(inputs)
29+
assert len(inputs) == 2, ('Euclidean distance needs '
30+
'2 inputs, %d given' % len(inputs))
3131
u, v = inputs.values()
32-
return K.sqrt((K.square(u - v)).sum(axis=1, keepdims=True))
32+
return K.sqrt(K.sum(K.square(u - v), axis=1, keepdims=True))
3333

3434

3535
def contrastive_loss(y, d):
36-
""" Contrastive loss from Hadsell-et-al.'06
37-
http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
38-
"""
36+
'''Contrastive loss from Hadsell-et-al.'06
37+
http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
38+
'''
3939
margin = 1
4040
return K.mean(y * K.square(d) + (1 - y) * K.square(K.maximum(margin - d, 0)))
4141

4242

4343
def create_pairs(x, digit_indices):
44-
""" Positive and negative pair creation.
45-
Alternates between positive and negative pairs.
46-
"""
44+
'''Positive and negative pair creation.
45+
Alternates between positive and negative pairs.
46+
'''
4747
pairs = []
4848
labels = []
4949
n = min([len(digit_indices[d]) for d in range(10)]) - 1
@@ -59,11 +59,11 @@ def create_pairs(x, digit_indices):
5959
return np.array(pairs), np.array(labels)
6060

6161

62-
def create_base_network(in_dim):
63-
""" Base network to be shared (eq. to feature extraction).
64-
"""
62+
def create_base_network(input_dim):
63+
'''Base network to be shared (eq. to feature extraction).
64+
'''
6565
seq = Sequential()
66-
seq.add(Dense(128, input_shape=(in_dim,), activation='relu'))
66+
seq.add(Dense(128, input_shape=(input_dim,), activation='relu'))
6767
seq.add(Dropout(0.1))
6868
seq.add(Dense(128, activation='relu'))
6969
seq.add(Dropout(0.1))
@@ -72,8 +72,8 @@ def create_base_network(in_dim):
7272

7373

7474
def compute_accuracy(predictions, labels):
75-
""" Compute classification accuracy with a fixed threshold on distances.
76-
"""
75+
'''Compute classification accuracy with a fixed threshold on distances.
76+
'''
7777
return labels[predictions.ravel() < 0.5].mean()
7878

7979

@@ -85,7 +85,7 @@ def compute_accuracy(predictions, labels):
8585
X_test = X_test.astype('float32')
8686
X_train /= 255
8787
X_test /= 255
88-
in_dim = 784
88+
input_dim = 784
8989
nb_epoch = 20
9090

9191
# create training+test positive and negative pairs
@@ -96,11 +96,11 @@ def compute_accuracy(predictions, labels):
9696
te_pairs, te_y = create_pairs(X_test, digit_indices)
9797

9898
# network definition
99-
base_network = create_base_network(in_dim)
99+
base_network = create_base_network(input_dim)
100100

101101
g = Graph()
102-
g.add_input(name='input_a', input_shape=(in_dim,))
103-
g.add_input(name='input_b', input_shape=(in_dim,))
102+
g.add_input(name='input_a', input_shape=(input_dim,))
103+
g.add_input(name='input_b', input_shape=(input_dim,))
104104
g.add_shared_node(base_network, name='shared', inputs=['input_a', 'input_b'],
105105
merge_mode='join')
106106
g.add_node(Lambda(euclidean_distance), name='d', input='shared')
@@ -111,7 +111,8 @@ def compute_accuracy(predictions, labels):
111111
g.compile(loss={'output': contrastive_loss}, optimizer=rms)
112112
g.fit({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1], 'output': tr_y},
113113
validation_data={'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1], 'output': te_y},
114-
batch_size=128, nb_epoch=nb_epoch)
114+
batch_size=128,
115+
nb_epoch=nb_epoch)
115116

116117
# compute final accuracy on training and test sets
117118
pred = g.predict({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1]})['output']

0 commit comments

Comments
 (0)