Skip to content

Commit 943d2d4

Browse files
committed
Add native names to weight tensors
1 parent c4361d2 commit 943d2d4

File tree

6 files changed

+116
-69
lines changed

6 files changed

+116
-69
lines changed

keras/layers/advanced_activations.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def __init__(self, init='zero', weights=None, **kwargs):
5959

6060
def build(self):
6161
input_shape = self.input_shape[1:]
62-
self.alphas = self.init(input_shape)
62+
self.alphas = self.init(input_shape,
63+
name='{}_alphas'.format(self.name))
6364
self.trainable_weights = [self.alphas]
6465

6566
if self.initial_weights is not None:
@@ -140,8 +141,10 @@ def __init__(self, alpha_init=0.2, beta_init=5.0,
140141

141142
def build(self):
142143
input_shape = self.input_shape[1:]
143-
self.alphas = K.variable(self.alpha_init * np.ones(input_shape))
144-
self.betas = K.variable(self.beta_init * np.ones(input_shape))
144+
self.alphas = K.variable(self.alpha_init * np.ones(input_shape),
145+
name='{}_alphas'.format(self.name))
146+
self.betas = K.variable(self.beta_init * np.ones(input_shape),
147+
name='{}_betas'.format(self.name))
145148
self.trainable_weights = [self.alphas, self.betas]
146149

147150
if self.initial_weights is not None:
@@ -246,25 +249,32 @@ class SReLU(MaskedLayer):
246249
'''
247250
def __init__(self, t_left_init='zero', a_left_init='glorot_uniform',
248251
t_right_init='glorot_uniform', a_right_init='one', **kwargs):
249-
self.t_left_init = t_left_init
250-
self.a_left_init = a_left_init
251-
self.t_right_init = t_right_init
252-
self.a_right_init = a_right_init
252+
self.t_left_init = initializations.get(t_left_init)
253+
self.a_left_init = initializations.get(a_left_init)
254+
self.t_right_init = initializations.get(t_right_init)
255+
self.a_right_init = initializations.get(a_right_init)
253256
super(SReLU, self).__init__(**kwargs)
254257

255258
def build(self):
256-
params_shape = self.input_shape[1:]
257-
self.t_left = initializations.get(self.t_left_init)(params_shape)
258-
self.a_left = initializations.get(self.a_left_init)(params_shape)
259-
self.t_right = initializations.get(self.t_right_init)(params_shape)
260-
self.a_right = initializations.get(self.a_right_init)(params_shape)
259+
input_shape = self.input_shape[1:]
260+
self.t_left = self.t_left_init(input_shape,
261+
name='{}_t_left'.format(self.name))
262+
self.a_left = self.a_left_init(input_shape,
263+
name='{}_a_left'.format(self.name))
264+
self.t_right = self.t_right_init(input_shape,
265+
name='{}_t_right'.format(self.name))
266+
self.a_right = self.a_right_init(input_shape,
267+
name='{}_a_right'.format(self.name))
261268
# ensure the right part is always to the right of the left
262269
self.t_right_actual = self.t_left + abs(self.t_right)
263-
self.trainable_weights = [self.t_left, self.a_left, self.t_right, self.a_right]
270+
self.trainable_weights = [self.t_left, self.a_left,
271+
self.t_right, self.a_right]
264272

265273
def get_output(self, train=False):
266274
X = self.get_input(train)
267-
Y_left_and_center = self.t_left + K.relu(X - self.t_left, self.a_left, self.t_right_actual - self.t_left)
275+
Y_left_and_center = self.t_left + K.relu(X - self.t_left,
276+
self.a_left,
277+
self.t_right_actual - self.t_left)
268278
Y_right = K.relu(X - self.t_right_actual) * self.a_right
269279
return Y_left_and_center + Y_right
270280

keras/layers/convolutional.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def __init__(self, nb_filter, filter_length,
106106
def build(self):
107107
input_dim = self.input_shape[2]
108108
self.W_shape = (self.nb_filter, input_dim, self.filter_length, 1)
109-
self.W = self.init(self.W_shape)
110-
self.b = K.zeros((self.nb_filter,))
109+
self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
110+
self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
111111
self.trainable_weights = [self.W, self.b]
112112
self.regularizers = []
113113

@@ -261,8 +261,8 @@ def build(self):
261261
self.W_shape = (self.nb_row, self.nb_col, stack_size, self.nb_filter)
262262
else:
263263
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
264-
self.W = self.init(self.W_shape)
265-
self.b = K.zeros((self.nb_filter,))
264+
self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
265+
self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
266266
self.trainable_weights = [self.W, self.b]
267267
self.regularizers = []
268268

@@ -444,8 +444,8 @@ def build(self):
444444
else:
445445
raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
446446

447-
self.W = self.init(self.W_shape)
448-
self.b = K.zeros((self.nb_filter,))
447+
self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
448+
self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
449449
self.trainable_weights = [self.W, self.b]
450450
self.regularizers = []
451451

keras/layers/core.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,26 @@ def __init__(self, **kwargs):
4545
'name'}
4646
for kwarg in kwargs:
4747
assert kwarg in allowed_kwargs, 'Keyword argument not understood: ' + kwarg
48+
49+
if 'name' in kwargs:
50+
self.name = kwargs['name']
51+
else:
52+
self.name = self.__class__.__name__.lower()
53+
54+
if 'cache_enabled' in kwargs:
55+
self.cache_enabled = kwargs['cache_enabled']
56+
else:
57+
self.cache_enabled = True
58+
4859
if 'batch_input_shape' in kwargs:
4960
self.set_input_shape(tuple(kwargs['batch_input_shape']))
5061
elif 'input_shape' in kwargs:
5162
self.set_input_shape((None,) + tuple(kwargs['input_shape']))
52-
self.trainable = True
63+
5364
if 'trainable' in kwargs:
5465
self.trainable = kwargs['trainable']
55-
self.name = self.__class__.__name__.lower()
56-
if 'name' in kwargs:
57-
self.name = kwargs['name']
58-
self.cache_enabled = True
59-
if 'cache_enabled' in kwargs:
60-
self.cache_enabled = kwargs['cache_enabled']
66+
else:
67+
self.trainable = True
6168

6269
@property
6370
def name(self):
@@ -1003,8 +1010,10 @@ def __init__(self, output_dim, init='glorot_uniform', activation='linear', weigh
10031010
def build(self):
10041011
input_dim = self.input_shape[1]
10051012

1006-
self.W = self.init((input_dim, self.output_dim))
1007-
self.b = K.zeros((self.output_dim,))
1013+
self.W = self.init((input_dim, self.output_dim),
1014+
name='{}_W'.format(self.name))
1015+
self.b = K.zeros((self.output_dim,),
1016+
name='{}_b'.format(self.name))
10081017

10091018
self.trainable_weights = [self.W, self.b]
10101019

@@ -1117,8 +1126,10 @@ def __init__(self, output_dim,
11171126
def build(self):
11181127
input_dim = self.input_shape[2]
11191128

1120-
self.W = self.init((input_dim, self.output_dim))
1121-
self.b = K.zeros((self.output_dim,))
1129+
self.W = self.init((input_dim, self.output_dim),
1130+
name='{}_W'.format(self.name))
1131+
self.b = K.zeros((self.output_dim,),
1132+
name='{}_b'.format(self.name))
11221133

11231134
self.trainable_weights = [self.W, self.b]
11241135
self.regularizers = []
@@ -1422,8 +1433,10 @@ def __init__(self, output_dim, nb_feature=4,
14221433
def build(self):
14231434
input_dim = self.input_shape[1]
14241435

1425-
self.W = self.init((self.nb_feature, input_dim, self.output_dim))
1426-
self.b = K.zeros((self.nb_feature, self.output_dim))
1436+
self.W = self.init((self.nb_feature, input_dim, self.output_dim),
1437+
name='{}_W'.format(self.name))
1438+
self.b = K.zeros((self.nb_feature, self.output_dim),
1439+
name='{}_b'.format(self.name))
14271440

14281441
self.trainable_weights = [self.W, self.b]
14291442
self.regularizers = []
@@ -1995,12 +2008,15 @@ def __init__(self, init='glorot_uniform', transform_bias=-2,
19952008
def build(self):
19962009
input_dim = self.input_shape[1]
19972010

1998-
self.W = self.init((input_dim, input_dim))
1999-
self.W_carry = self.init((input_dim, input_dim))
2011+
self.W = self.init((input_dim, input_dim),
2012+
name='{}_W'.format(self.name))
2013+
self.W_carry = self.init((input_dim, input_dim),
2014+
name='{}_W_carry'.format(self.name))
20002015

2001-
self.b = K.zeros((input_dim,))
2016+
self.b = K.zeros((input_dim,), name='{}_b'.format(self.name))
20022017
# initialize with a vector of values `transform_bias`
2003-
self.b_carry = K.variable(np.ones((input_dim,)) * self.transform_bias)
2018+
self.b_carry = K.variable(np.ones((input_dim,)) * self.transform_bias,
2019+
name='{}_b_carry'.format(self.name))
20042020

20052021
self.trainable_weights = [self.W, self.b, self.W_carry, self.b_carry]
20062022

keras/layers/embeddings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def __init__(self, input_dim, output_dim,
7373
def build(self):
7474
self.input = K.placeholder(shape=(self.input_shape[0], self.input_length),
7575
dtype='int32')
76-
self.W = self.init((self.input_dim, self.output_dim))
76+
self.W = self.init((self.input_dim, self.output_dim),
77+
name='{}_W'.format(self.name))
7778
self.trainable_weights = [self.W]
7879
self.regularizers = []
7980
if self.W_regularizer:

keras/layers/normalization.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,15 @@ def build(self):
5555
input_shape = self.input_shape # starts with samples axis
5656
shape = (input_shape[self.axis],)
5757

58-
self.gamma = self.init(shape)
59-
self.beta = K.zeros(shape)
58+
self.gamma = self.init(shape,
59+
name='{}_gamma'.format(self.name))
60+
self.beta = K.zeros(shape, name='{}_beta'.format(self.name))
6061
self.trainable_weights = [self.gamma, self.beta]
6162

62-
self.running_mean = K.zeros(shape)
63-
self.running_std = K.ones(shape)
63+
self.running_mean = K.zeros(shape,
64+
name='{}_running_mean'.format(self.name))
65+
self.running_std = K.ones(shape,
66+
name='{}_running_std'.format(self.name))
6467
self.non_trainable_weights = [self.running_mean, self.running_std]
6568

6669
if self.initial_weights is not None:

keras/layers/recurrent.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,11 @@ def build(self):
261261
input_dim = input_shape[2]
262262
self.input_dim = input_dim
263263

264-
self.W = self.init((input_dim, self.output_dim))
265-
self.U = self.inner_init((self.output_dim, self.output_dim))
266-
self.b = K.zeros((self.output_dim,))
264+
self.W = self.init((input_dim, self.output_dim),
265+
name='{}_W'.format(self.name))
266+
self.U = self.inner_init((self.output_dim, self.output_dim),
267+
name='{}_U'.format(self.name))
268+
self.b = K.zeros((self.output_dim,), name='{}_b'.format(self.name))
267269

268270
def append_regulariser(input_regulariser, param, regularizers_list):
269271
regulariser = regularizers.get(input_regulariser)
@@ -384,17 +386,23 @@ def build(self):
384386
input_dim = input_shape[2]
385387
self.input_dim = input_dim
386388

387-
self.W_z = self.init((input_dim, self.output_dim))
388-
self.U_z = self.inner_init((self.output_dim, self.output_dim))
389-
self.b_z = K.zeros((self.output_dim,))
389+
self.W_z = self.init((input_dim, self.output_dim),
390+
name='{}_W_z'.format(self.name))
391+
self.U_z = self.inner_init((self.output_dim, self.output_dim),
392+
name='{}_U_z'.format(self.name))
393+
self.b_z = K.zeros((self.output_dim,), name='{}_b_z'.format(self.name))
390394

391-
self.W_r = self.init((input_dim, self.output_dim))
392-
self.U_r = self.inner_init((self.output_dim, self.output_dim))
393-
self.b_r = K.zeros((self.output_dim,))
395+
self.W_r = self.init((input_dim, self.output_dim),
396+
name='{}_W_r'.format(self.name))
397+
self.U_r = self.inner_init((self.output_dim, self.output_dim),
398+
name='{}_U_r'.format(self.name))
399+
self.b_r = K.zeros((self.output_dim,), name='{}_b_r'.format(self.name))
394400

395-
self.W_h = self.init((input_dim, self.output_dim))
396-
self.U_h = self.inner_init((self.output_dim, self.output_dim))
397-
self.b_h = K.zeros((self.output_dim,))
401+
self.W_h = self.init((input_dim, self.output_dim),
402+
name='{}_W_h'.format(self.name))
403+
self.U_h = self.inner_init((self.output_dim, self.output_dim),
404+
name='{}_U_h'.format(self.name))
405+
self.b_h = K.zeros((self.output_dim,), name='{}_b_h'.format(self.name))
398406

399407
def append_regulariser(input_regulariser, param, regularizers_list):
400408
regulariser = regularizers.get(input_regulariser)
@@ -556,21 +564,30 @@ def build(self):
556564
# initial states: 2 all-zero tensors of shape (output_dim)
557565
self.states = [None, None]
558566

559-
self.W_i = self.init((input_dim, self.output_dim))
560-
self.U_i = self.inner_init((self.output_dim, self.output_dim))
561-
self.b_i = K.zeros((self.output_dim,))
562-
563-
self.W_f = self.init((input_dim, self.output_dim))
564-
self.U_f = self.inner_init((self.output_dim, self.output_dim))
565-
self.b_f = self.forget_bias_init((self.output_dim,))
566-
567-
self.W_c = self.init((input_dim, self.output_dim))
568-
self.U_c = self.inner_init((self.output_dim, self.output_dim))
569-
self.b_c = K.zeros((self.output_dim,))
570-
571-
self.W_o = self.init((input_dim, self.output_dim))
572-
self.U_o = self.inner_init((self.output_dim, self.output_dim))
573-
self.b_o = K.zeros((self.output_dim,))
567+
self.W_i = self.init((input_dim, self.output_dim),
568+
name='{}_W_i'.format(self.name))
569+
self.U_i = self.inner_init((self.output_dim, self.output_dim),
570+
name='{}_U_i'.format(self.name))
571+
self.b_i = K.zeros((self.output_dim,), name='{}_b_i'.format(self.name))
572+
573+
self.W_f = self.init((input_dim, self.output_dim),
574+
name='{}_W_f'.format(self.name))
575+
self.U_f = self.inner_init((self.output_dim, self.output_dim),
576+
name='{}_U_f'.format(self.name))
577+
self.b_f = self.forget_bias_init((self.output_dim,),
578+
name='{}_b_f'.format(self.name))
579+
580+
self.W_c = self.init((input_dim, self.output_dim),
581+
name='{}_W_c'.format(self.name))
582+
self.U_c = self.inner_init((self.output_dim, self.output_dim),
583+
name='{}_U_c'.format(self.name))
584+
self.b_c = K.zeros((self.output_dim,), name='{}_b_c'.format(self.name))
585+
586+
self.W_o = self.init((input_dim, self.output_dim),
587+
name='{}_W_o'.format(self.name))
588+
self.U_o = self.inner_init((self.output_dim, self.output_dim),
589+
name='{}_U_o'.format(self.name))
590+
self.b_o = K.zeros((self.output_dim,), name='{}_b_o'.format(self.name))
574591

575592
def append_regulariser(input_regulariser, param, regularizers_list):
576593
regulariser = regularizers.get(input_regulariser)

0 commit comments

Comments
 (0)