33from . import backend as K
44
55
6- def get_fans (shape ):
7- fan_in = shape [0 ] if len (shape ) == 2 else np .prod (shape [1 :])
8- fan_out = shape [1 ] if len (shape ) == 2 else shape [0 ]
6+ def get_fans (shape , dim_ordering = 'th' ):
7+ if len (shape ) == 2 :
8+ fan_in = shape [0 ]
9+ fan_out = shape [1 ]
10+ elif len (shape ) == 4 or len (shape ) == 5 :
11+ # assuming convolution kernels (2D or 3D).
12+ # TH kernel shape: (depth, input_depth, ...)
13+ # TF kernel shape: (..., input_depth, depth)
14+ if dim_ordering == 'th' :
15+ fan_in = np .prod (shape [1 :])
16+ fan_out = shape [0 ]
17+ elif dim_ordering == 'tf' :
18+ fan_in = np .prod (shape [:- 1 ])
19+ fan_out = shape [- 1 ]
20+ else :
21+ raise Exception ('Invalid dim_ordering: ' + dim_ordering )
22+ else :
23+ # no specific assumptions
24+ fan_in = np .sqrt (np .prod (shape ))
25+ fan_out = np .sqrt (np .prod (shape ))
926 return fan_in , fan_out
1027
1128
@@ -19,39 +36,39 @@ def normal(shape, scale=0.05, name=None):
1936 name = name )
2037
2138
22- def lecun_uniform (shape , name = None ):
39+ def lecun_uniform (shape , name = None , dim_ordering = 'th' ):
2340 ''' Reference: LeCun 98, Efficient Backprop
2441 http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
2542 '''
26- fan_in , fan_out = get_fans (shape )
43+ fan_in , fan_out = get_fans (shape , dim_ordering = dim_ordering )
2744 scale = np .sqrt (3. / fan_in )
2845 return uniform (shape , scale , name = name )
2946
3047
31- def glorot_normal (shape , name = None ):
48+ def glorot_normal (shape , name = None , dim_ordering = 'th' ):
3249 ''' Reference: Glorot & Bengio, AISTATS 2010
3350 '''
34- fan_in , fan_out = get_fans (shape )
51+ fan_in , fan_out = get_fans (shape , dim_ordering = dim_ordering )
3552 s = np .sqrt (2. / (fan_in + fan_out ))
3653 return normal (shape , s , name = name )
3754
3855
39- def glorot_uniform (shape , name = None ):
40- fan_in , fan_out = get_fans (shape )
56+ def glorot_uniform (shape , name = None , dim_ordering = 'th' ):
57+ fan_in , fan_out = get_fans (shape , dim_ordering = dim_ordering )
4158 s = np .sqrt (6. / (fan_in + fan_out ))
4259 return uniform (shape , s , name = name )
4360
4461
45- def he_normal (shape , name = None ):
62+ def he_normal (shape , name = None , dim_ordering = 'th' ):
4663 ''' Reference: He et al., http://arxiv.org/abs/1502.01852
4764 '''
48- fan_in , fan_out = get_fans (shape )
65+ fan_in , fan_out = get_fans (shape , dim_ordering = dim_ordering )
4966 s = np .sqrt (2. / fan_in )
5067 return normal (shape , s , name = name )
5168
5269
53- def he_uniform (shape , name = None ):
54- fan_in , fan_out = get_fans (shape )
70+ def he_uniform (shape , name = None , dim_ordering = 'th' ):
71+ fan_in , fan_out = get_fans (shape , dim_ordering = dim_ordering )
5572 s = np .sqrt (6. / fan_in )
5673 return uniform (shape , s , name = name )
5774
@@ -85,5 +102,6 @@ def one(shape, name=None):
85102
86103
87104from .utils .generic_utils import get_from_module
88- def get (identifier ):
89- return get_from_module (identifier , globals (), 'initialization' )
105+ def get (identifier , ** kwargs ):
106+ return get_from_module (identifier , globals (),
107+ 'initialization' , kwargs = kwargs )
0 commit comments