Skip to content

Commit becc5f3

Browse files
committed
Add support for dim_ordering in initializations
1 parent 48ce230 commit becc5f3

File tree

3 files changed: 37 additions and 19 deletions

keras/initializations.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,26 @@
33
from . import backend as K
44

55

6-
def get_fans(shape):
7-
fan_in = shape[0] if len(shape) == 2 else np.prod(shape[1:])
8-
fan_out = shape[1] if len(shape) == 2 else shape[0]
6+
def get_fans(shape, dim_ordering='th'):
7+
if len(shape) == 2:
8+
fan_in = shape[0]
9+
fan_out = shape[1]
10+
elif len(shape) == 4 or len(shape) == 5:
11+
# assuming convolution kernels (2D or 3D).
12+
# TH kernel shape: (depth, input_depth, ...)
13+
# TF kernel shape: (..., input_depth, depth)
14+
if dim_ordering == 'th':
15+
fan_in = np.prod(shape[1:])
16+
fan_out = shape[0]
17+
elif dim_ordering == 'tf':
18+
fan_in = np.prod(shape[:-1])
19+
fan_out = shape[-1]
20+
else:
21+
raise Exception('Invalid dim_ordering: ' + dim_ordering)
22+
else:
23+
# no specific assumptions
24+
fan_in = np.sqrt(np.prod(shape))
25+
fan_out = np.sqrt(np.prod(shape))
926
return fan_in, fan_out
1027

1128

@@ -19,39 +36,39 @@ def normal(shape, scale=0.05, name=None):
1936
name=name)
2037

2138

22-
def lecun_uniform(shape, name=None):
39+
def lecun_uniform(shape, name=None, dim_ordering='th'):
2340
''' Reference: LeCun 98, Efficient Backprop
2441
http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
2542
'''
26-
fan_in, fan_out = get_fans(shape)
43+
fan_in, fan_out = get_fans(shape, dim_ordering=dim_ordering)
2744
scale = np.sqrt(3. / fan_in)
2845
return uniform(shape, scale, name=name)
2946

3047

31-
def glorot_normal(shape, name=None):
48+
def glorot_normal(shape, name=None, dim_ordering='th'):
3249
''' Reference: Glorot & Bengio, AISTATS 2010
3350
'''
34-
fan_in, fan_out = get_fans(shape)
51+
fan_in, fan_out = get_fans(shape, dim_ordering=dim_ordering)
3552
s = np.sqrt(2. / (fan_in + fan_out))
3653
return normal(shape, s, name=name)
3754

3855

39-
def glorot_uniform(shape, name=None):
40-
fan_in, fan_out = get_fans(shape)
56+
def glorot_uniform(shape, name=None, dim_ordering='th'):
57+
fan_in, fan_out = get_fans(shape, dim_ordering=dim_ordering)
4158
s = np.sqrt(6. / (fan_in + fan_out))
4259
return uniform(shape, s, name=name)
4360

4461

45-
def he_normal(shape, name=None):
62+
def he_normal(shape, name=None, dim_ordering='th'):
4663
''' Reference: He et al., http://arxiv.org/abs/1502.01852
4764
'''
48-
fan_in, fan_out = get_fans(shape)
65+
fan_in, fan_out = get_fans(shape, dim_ordering=dim_ordering)
4966
s = np.sqrt(2. / fan_in)
5067
return normal(shape, s, name=name)
5168

5269

53-
def he_uniform(shape, name=None):
54-
fan_in, fan_out = get_fans(shape)
70+
def he_uniform(shape, name=None, dim_ordering='th'):
71+
fan_in, fan_out = get_fans(shape, dim_ordering=dim_ordering)
5572
s = np.sqrt(6. / fan_in)
5673
return uniform(shape, s, name=name)
5774

@@ -85,5 +102,6 @@ def one(shape, name=None):
85102

86103

87104
from .utils.generic_utils import get_from_module
88-
def get(identifier):
89-
return get_from_module(identifier, globals(), 'initialization')
105+
def get(identifier, **kwargs):
106+
return get_from_module(identifier, globals(),
107+
'initialization', kwargs=kwargs)

keras/layers/convolutional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, nb_filter, filter_length,
7979
raise Exception('Invalid border mode for Convolution1D:', border_mode)
8080
self.nb_filter = nb_filter
8181
self.filter_length = filter_length
82-
self.init = initializations.get(init)
82+
self.init = initializations.get(init, dim_ordering='th')
8383
self.activation = activations.get(activation)
8484
assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}'
8585
self.border_mode = border_mode
@@ -233,7 +233,7 @@ def __init__(self, nb_filter, nb_row, nb_col,
233233
self.nb_filter = nb_filter
234234
self.nb_row = nb_row
235235
self.nb_col = nb_col
236-
self.init = initializations.get(init)
236+
self.init = initializations.get(init, dim_ordering=dim_ordering)
237237
self.activation = activations.get(activation)
238238
assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}'
239239
self.border_mode = border_mode
@@ -412,7 +412,7 @@ def __init__(self, nb_filter, kernel_dim1, kernel_dim2, kernel_dim3,
412412
self.kernel_dim1 = kernel_dim1
413413
self.kernel_dim2 = kernel_dim2
414414
self.kernel_dim3 = kernel_dim3
415-
self.init = initializations.get(init)
415+
self.init = initializations.get(init, dim_ordering=dim_ordering)
416416
self.activation = activations.get(activation)
417417
assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}'
418418
self.border_mode = border_mode

tests/integration_tests/test_image_data_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_image_classification():
3535
Activation('relu'),
3636
Dense(y_test.shape[-1], activation='softmax')
3737
])
38-
model.compile(loss='categorical_crossentropy', optimizer='sgd')
38+
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
3939
history = model.fit(X_train, y_train, nb_epoch=10, batch_size=16,
4040
validation_data=(X_test, y_test),
4141
show_accuracy=True, verbose=0)

0 commit comments

Comments
 (0)