Skip to content

Commit 0ed00e3

Browse files
committed
Add inception v3 example
1 parent 36eef0d commit 0ed00e3

File tree

1 file changed

+290
-0
lines changed

1 file changed

+290
-0
lines changed

examples/inception_v3.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
'''This script demonstrates how to build the Inception v3 architecture
2+
using the Keras functional API.
3+
We are not actually training it here, for lack of appropriate data.
4+
5+
For more information about this architecture, see:
6+
7+
"Rethinking the Inception Architecture for Computer Vision"
8+
Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna
9+
http://arxiv.org/abs/1512.00567
10+
'''
11+
from keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
12+
from keras.layers import BatchNormalization, Flatten, Dense, Dropout
13+
from keras.layers import Input, merge
14+
from keras.models import Model
15+
from keras import regularizers
16+
17+
18+
# Global configuration for the example network.

# Number of output classes.
NB_CLASS = 1000
# Tensor layout: 'th' (channels, width, height) or 'tf' (width, height, channels).
DIM_ORDERING = 'th'
# L2 regularization factor; a falsy value (0.) disables weight decay in conv2D_bn.
WEIGHT_DECAY = 0.
# Whether conv2D_bn applies batch normalization by default.
USE_BN = False
23+
24+
25+
def conv2D_bn(x, nb_filter, nb_row, nb_col,
              border_mode='same', subsample=(1, 1),
              activation='relu', batch_norm=USE_BN,
              weight_decay=WEIGHT_DECAY, dim_ordering=DIM_ORDERING):
    '''Utility function to apply to a tensor a module conv + BN
    with optional weight decay (L2 weight regularization).

    # Arguments
        x: input tensor.
        nb_filter: number of convolution filters.
        nb_row: number of rows in the convolution kernel.
        nb_col: number of columns in the convolution kernel.
        border_mode: border mode passed to `Convolution2D`.
        subsample: convolution strides passed to `Convolution2D`.
        activation: activation applied by the convolution layer
            (i.e. before the optional batch normalization).
        batch_norm: if truthy, append a `BatchNormalization` layer.
        weight_decay: L2 factor applied to both the kernel and the bias;
            a falsy value disables regularization.
        dim_ordering: 'th' or 'tf', forwarded to `Convolution2D`.

    # Returns
        The output tensor of the conv (+ BN) module.
    '''
    if weight_decay:
        # Separate regularizer instances for the kernel and the bias.
        W_regularizer = regularizers.l2(weight_decay)
        b_regularizer = regularizers.l2(weight_decay)
    else:
        W_regularizer = None
        b_regularizer = None
    x = Convolution2D(nb_filter, nb_row, nb_col,
                      subsample=subsample,
                      activation=activation,
                      border_mode=border_mode,
                      W_regularizer=W_regularizer,
                      b_regularizer=b_regularizer,
                      dim_ordering=dim_ordering)(x)
    if batch_norm:
        # Normalize per feature map: channels sit on axis 1 in 'th' ordering,
        # and on the last axis otherwise (the layer's default).
        bn_axis = 1 if dim_ordering == 'th' else -1
        x = BatchNormalization(axis=bn_axis)(x)
    return x
48+
49+
# Define image input layer and the axis along which branches are concatenated

if DIM_ORDERING not in ('th', 'tf'):
    raise Exception('Invalid dim ordering: ' + str(DIM_ORDERING))

if DIM_ORDERING == 'th':
    # channels first: concatenate feature maps along axis 1
    img_input = Input(shape=(3, 299, 299))
    CONCAT_AXIS = 1
else:
    # channels last: concatenate feature maps along axis 3
    img_input = Input(shape=(299, 299, 3))
    CONCAT_AXIS = 3
59+
60+
# Entry module: plain convolutions and max pooling applied to the input
# image before the first mixed block.

x = conv2D_bn(img_input, 32, 3, 3, border_mode='valid', subsample=(2, 2))
x = conv2D_bn(x, 32, 3, 3, border_mode='valid')
x = conv2D_bn(x, 64, 3, 3)
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2),
                 dim_ordering=DIM_ORDERING)(x)

x = conv2D_bn(x, 80, 1, 1, border_mode='valid')
x = conv2D_bn(x, 192, 3, 3, border_mode='valid')
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2),
                 dim_ordering=DIM_ORDERING)(x)
70+
71+
# mixed:   35 x 35 x 256
# mixed_1: 35 x 35 x 288
# mixed2:  35 x 35 x 288
#
# The three blocks differ only in the width of the pooling branch
# projection: 32 filters give 64 + 64 + 96 + 32 = 256 output channels,
# 64 filters give 64 + 64 + 96 + 64 = 288. mixed_1 previously used 32,
# which produced 256 channels instead of the 288 stated above.

for pool_filters in (32, 64, 64):
    branch1x1 = conv2D_bn(x, 64, 1, 1)

    branch5x5 = conv2D_bn(x, 48, 1, 1)
    branch5x5 = conv2D_bn(branch5x5, 64, 5, 5)

    branch3x3dbl = conv2D_bn(x, 64, 1, 1)
    branch3x3dbl = conv2D_bn(branch3x3dbl, 96, 3, 3)
    branch3x3dbl = conv2D_bn(branch3x3dbl, 96, 3, 3)

    branch_pool = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same',
                                   dim_ordering=DIM_ORDERING)(x)
    branch_pool = conv2D_bn(branch_pool, pool_filters, 1, 1)

    x = merge([branch1x1, branch5x5, branch3x3dbl, branch_pool],
              mode='concat', concat_axis=CONCAT_AXIS)
115+
116+
# mixed3: 17 x 17 x 768
# Reduction block: every branch uses stride 2, and the pooling branch passes
# its input channels through unchanged (384 + 96 + 288 = 768).

branch3x3 = conv2D_bn(x, 384, 3, 3, border_mode='valid', subsample=(2, 2))

branch3x3dbl = conv2D_bn(x, 64, 1, 1)
branch3x3dbl = conv2D_bn(branch3x3dbl, 96, 3, 3)
branch3x3dbl = conv2D_bn(branch3x3dbl, 96, 3, 3,
                         border_mode='valid', subsample=(2, 2))

branch_pool = MaxPooling2D(pool_size=(3, 3), strides=(2, 2),
                           dim_ordering=DIM_ORDERING)(x)

x = merge([branch3x3, branch3x3dbl, branch_pool],
          mode='concat', concat_axis=CONCAT_AXIS)
126+
127+
# mixed4: 17 x 17 x 768
# Factorized 7x7 convolutions with 128-filter bottlenecks. Each branch must
# end with 192 filters so the concatenation yields 4 * 192 = 768 channels;
# the double-7x7 branch previously ended with 128 (704 channels in total).

branch1x1 = conv2D_bn(x, 192, 1, 1)

branch7x7 = conv2D_bn(x, 128, 1, 1)
branch7x7 = conv2D_bn(branch7x7, 128, 1, 7)
branch7x7 = conv2D_bn(branch7x7, 192, 7, 1)

branch7x7dbl = conv2D_bn(x, 128, 1, 1)
branch7x7dbl = conv2D_bn(branch7x7dbl, 128, 7, 1)
branch7x7dbl = conv2D_bn(branch7x7dbl, 128, 1, 7)
branch7x7dbl = conv2D_bn(branch7x7dbl, 128, 7, 1)
branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 1, 7)

branch_pool = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same',
                               dim_ordering=DIM_ORDERING)(x)
branch_pool = conv2D_bn(branch_pool, 192, 1, 1)

x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool],
          mode='concat', concat_axis=CONCAT_AXIS)
144+
145+
# mixed5 (applied twice) and mixed6: 17 x 17 x 768
# Three structurally identical blocks with 160-filter bottlenecks, stacked
# one after the other; each outputs 4 * 192 = 768 channels.

for _ in range(3):
    branch1x1 = conv2D_bn(x, 192, 1, 1)

    branch7x7 = conv2D_bn(x, 160, 1, 1)
    branch7x7 = conv2D_bn(branch7x7, 160, 1, 7)
    branch7x7 = conv2D_bn(branch7x7, 192, 7, 1)

    branch7x7dbl = conv2D_bn(x, 160, 1, 1)
    branch7x7dbl = conv2D_bn(branch7x7dbl, 160, 7, 1)
    branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 1, 7)
    branch7x7dbl = conv2D_bn(branch7x7dbl, 160, 7, 1)
    branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 1, 7)

    branch_pool = AveragePooling2D(pool_size=(3, 3), strides=(1, 1),
                                   border_mode='same',
                                   dim_ordering=DIM_ORDERING)(x)
    branch_pool = conv2D_bn(branch_pool, 192, 1, 1)

    x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool],
              mode='concat', concat_axis=CONCAT_AXIS)
198+
199+
# mixed7: 17 x 17 x 768
# Last 17 x 17 block; all four branches end with 192 filters.

branch1x1 = conv2D_bn(x, 192, 1, 1)

branch7x7 = conv2D_bn(x, 192, 1, 1)
branch7x7 = conv2D_bn(branch7x7, 192, 1, 7)
branch7x7 = conv2D_bn(branch7x7, 192, 7, 1)

branch7x7dbl = conv2D_bn(x, 160, 1, 1)
branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 7, 1)
branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 1, 7)
branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 7, 1)
branch7x7dbl = conv2D_bn(branch7x7dbl, 192, 1, 7)

branch_pool = AveragePooling2D(pool_size=(3, 3), strides=(1, 1),
                               border_mode='same',
                               dim_ordering=DIM_ORDERING)(x)
branch_pool = conv2D_bn(branch_pool, 192, 1, 1)

x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool],
          mode='concat', concat_axis=CONCAT_AXIS)
216+
217+
# Auxiliary head
# Secondary classifier branching off the last 17 x 17 x 768 feature map;
# its predictions are the second output of the model.

aux_logits = AveragePooling2D((5, 5), strides=(3, 3), dim_ordering=DIM_ORDERING)(x)
aux_logits = conv2D_bn(aux_logits, 128, 1, 1)
# 768 filters (the original 728 was a typo for the reference width).
aux_logits = conv2D_bn(aux_logits, 768, 5, 5, border_mode='valid')
aux_logits = Flatten()(aux_logits)
aux_preds = Dense(NB_CLASS, activation='softmax')(aux_logits)
224+
225+
# mixed8: 8 x 8 x 1280
# Reduction block: 320 (3x3 branch) + 192 (7x7x3 branch) + 768 (pooled
# input) = 1280 channels. The strided 3x3 convolution previously used 192
# filters, which produced 1152 channels instead of the 1280 stated above.

branch3x3 = conv2D_bn(x, 192, 1, 1)
branch3x3 = conv2D_bn(branch3x3, 320, 3, 3, subsample=(2, 2), border_mode='valid')

branch7x7x3 = conv2D_bn(x, 192, 1, 1)
branch7x7x3 = conv2D_bn(branch7x7x3, 192, 1, 7)
branch7x7x3 = conv2D_bn(branch7x7x3, 192, 7, 1)
branch7x7x3 = conv2D_bn(branch7x7x3, 192, 3, 3, subsample=(2, 2), border_mode='valid')

branch_pool = AveragePooling2D((3, 3), strides=(2, 2), dim_ordering=DIM_ORDERING)(x)

x = merge([branch3x3, branch7x7x3, branch_pool],
          mode='concat', concat_axis=CONCAT_AXIS)
237+
238+
# mixed9, mixed10: 8 x 8 x 2048
# Two structurally identical blocks stacked one after the other. Each 3x3
# branch splits into parallel 1x3 and 3x1 convolutions whose outputs are
# concatenated: 320 + 768 + 768 + 192 = 2048 channels.

for _ in range(2):
    branch1x1 = conv2D_bn(x, 320, 1, 1)

    branch3x3 = conv2D_bn(x, 384, 1, 1)
    branch3x3_1 = conv2D_bn(branch3x3, 384, 1, 3)
    branch3x3_2 = conv2D_bn(branch3x3, 384, 3, 1)
    branch3x3 = merge([branch3x3_1, branch3x3_2],
                      mode='concat', concat_axis=CONCAT_AXIS)

    branch3x3dbl = conv2D_bn(x, 448, 1, 1)
    branch3x3dbl = conv2D_bn(branch3x3dbl, 384, 3, 3)
    branch3x3dbl_1 = conv2D_bn(branch3x3dbl, 384, 1, 3)
    branch3x3dbl_2 = conv2D_bn(branch3x3dbl, 384, 3, 1)
    branch3x3dbl = merge([branch3x3dbl_1, branch3x3dbl_2],
                         mode='concat', concat_axis=CONCAT_AXIS)

    branch_pool = AveragePooling2D(pool_size=(3, 3), strides=(1, 1),
                                   border_mode='same',
                                   dim_ordering=DIM_ORDERING)(x)
    branch_pool = conv2D_bn(branch_pool, 192, 1, 1)

    x = merge([branch1x1, branch3x3, branch3x3dbl, branch_pool],
              mode='concat', concat_axis=CONCAT_AXIS)
275+
276+
# Final pooling and prediction
# Average over the 8 x 8 feature map, apply dropout, then classify.

x = AveragePooling2D(pool_size=(8, 8), strides=(1, 1),
                     dim_ordering=DIM_ORDERING)(x)
x = Dropout(0.5)(x)
x = Flatten()(x)
preds = Dense(NB_CLASS, activation='softmax')(x)
282+
283+
# Define model
# Two outputs: the main predictions and the auxiliary head predictions.

model = Model(input=img_input, output=[preds, aux_preds])
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

# train via e.g. `model.fit(x_train, [y_train] * 2, batch_size=32, nb_epoch=100)`
# Note that for a large dataset it would be preferable
# to train using `fit_generator` (see Keras docs).

0 commit comments

Comments
 (0)