
Commit 6cde6df

Add regularization and constraints to layers
1 parent 22012fd commit 6cde6df

3 files changed: 74 additions (+), 29 deletions (-)

keras/layers/core.py

Lines changed: 46 additions & 2 deletions
@@ -12,6 +12,34 @@
 from six.moves import zip
 srng = RandomStreams()
 
+def l1(lam):
+    def l1wrap(g,p):
+        g += T.sgn(p) * lam
+        return g
+    return l1wrap
+
+def l2(lam):
+    def l2wrap(g,p):
+        g += p * lam
+        return g
+    return l2wrap
+
+def maxnorm(m):
+    def maxnorm_wrap(p):
+        norms = T.sqrt(T.sum(T.sqr(p), axis=0))
+        desired = T.clip(norms, 0, m)
+        p = p * (desired / (1e-7 + norms))
+        return p
+    return maxnorm_wrap
+
+def nonneg(p):
+    p *= T.ge(p,0)
+    return p
+
+def ident(g,*l):
+    return g
+
+
 class Layer(object):
     def connect(self, previous_layer):
         self.previous_layer = previous_layer
@@ -46,6 +74,8 @@ class Dropout(Layer):
     def __init__(self, p):
         self.p = p
         self.params = []
+        self.regularizer = []
+        self.constraint = []
 
     def output(self, train):
         X = self.get_input(train)
@@ -69,6 +99,8 @@ class Activation(Layer):
     def __init__(self, activation):
         self.activation = activations.get(activation)
         self.params = []
+        self.regularizer = []
+        self.constraint = []
 
     def output(self, train):
         X = self.get_input(train)
@@ -88,6 +120,8 @@ class Reshape(Layer):
     def __init__(self, *dims):
         self.dims = dims
         self.params = []
+        self.regularizer = []
+        self.constraint = []
 
     def output(self, train):
         X = self.get_input(train)
@@ -106,6 +140,8 @@ class Flatten(Layer):
     '''
     def __init__(self):
         self.params = []
+        self.regularizer = []
+        self.constraint = []
 
     def output(self, train):
         X = self.get_input(train)
@@ -124,6 +160,8 @@ class RepeatVector(Layer):
     def __init__(self, n):
         self.n = n
         self.params = []
+        self.regularizer = []
+        self.constraint = []
 
     def output(self, train):
         X = self.get_input(train)
@@ -140,7 +178,7 @@ class Dense(Layer):
     '''
         Just your regular fully connected NN layer.
     '''
-    def __init__(self, input_dim, output_dim, init='uniform', activation='linear', weights=None):
+    def __init__(self, input_dim, output_dim, init='uniform', activation='linear', weights=None, regularizer=[ident, ident], constraint=[ident, ident]):
         self.init = initializations.get(init)
         self.activation = activations.get(activation)
         self.input_dim = input_dim
@@ -152,6 +190,9 @@ def __init__(self, input_dim, output_dim, init='uniform', activation='linear', w
 
         self.params = [self.W, self.b]
 
+        self.regularizer = regularizer
+        self.constraint = constraint
+
         if weights is not None:
             self.set_weights(weights)
 
@@ -176,7 +217,7 @@ class TimeDistributedDense(Layer):
         Tensor output dimensions: (nb_sample, shared_dimension, output_dim)
 
     '''
-    def __init__(self, input_dim, output_dim, init='uniform', activation='linear', weights=None):
+    def __init__(self, input_dim, output_dim, init='uniform', activation='linear', weights=None, regularizer=[ident, ident], constraint=[ident, ident]):
         self.init = initializations.get(init)
         self.activation = activations.get(activation)
         self.input_dim = input_dim
@@ -188,6 +229,9 @@ def __init__(self, input_dim, output_dim, init='uniform', activation='linear', w
 
         self.params = [self.W, self.b]
 
+        self.regularizer = regularizer
+        self.constraint = constraint
+
         if weights is not None:
             self.set_weights(weights)
 
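For context, a minimal usage sketch (not part of this commit; layer sizes and activation are made up): the new regularizer and constraint arguments take one callable per entry of the layer's params list, i.e. one for W and one for b, with ident as the no-op default.

# Hypothetical example built on the helpers added above.
from keras.layers.core import Dense, l2, maxnorm, ident

# L2-penalize the gradient of W and clip its column norms to 2,
# while leaving the bias b untouched via the identity callables.
layer = Dense(784, 128, init='uniform', activation='tanh',
              regularizer=[l2(0.01), ident],   # [for W, for b]
              constraint=[maxnorm(2), ident])  # [for W, for b]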

keras/models.py

Lines changed: 5 additions & 1 deletion
@@ -25,12 +25,16 @@ class Sequential(object):
     def __init__(self):
         self.layers = []
         self.params = []
+        self.regularizer = []
+        self.constraint = []
 
     def add(self, layer):
         self.layers.append(layer)
         if len(self.layers) > 1:
             self.layers[-1].connect(self.layers[-2])
         self.params += [p for p in layer.params]
+        self.regularizer += [r for r in layer.regularizer]
+        self.constraint += [c for c in layer.constraint]
 
     def compile(self, optimizer, loss, class_mode="categorical"):
         self.optimizer = optimizers.get(optimizer)
@@ -58,7 +62,7 @@ def compile(self, optimizer, loss, class_mode="categorical"):
             raise Exception("Invalid class mode:" + str(class_mode))
         self.class_mode = class_mode
 
-        updates = self.optimizer.get_updates(self.params, train_loss)
+        updates = self.optimizer.get_updates(self.params, self.regularizer, self.constraint, train_loss)
 
         self._train = theano.function([self.X, self.y], train_loss,
             updates=updates, allow_input_downcast=True)
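A sketch of how the pieces fit together (hypothetical layer sizes, not part of the commit): because add() extends params, regularizer, and constraint in the same order, the three lists stay positionally aligned, which is what lets compile() pass them to the optimizer as parallel arguments.

from keras.models import Sequential
from keras.layers.core import Dense, Activation, l1, l2, maxnorm, ident

model = Sequential()
model.add(Dense(20, 64, regularizer=[l2(0.001), ident], constraint=[maxnorm(3), ident]))
model.add(Activation('tanh'))   # contributes empty params/regularizer/constraint lists
model.add(Dense(64, 1, regularizer=[l1(0.001), ident], constraint=[ident, ident]))

# Four parameters (W1, b1, W2, b2) and four matching entries in each list,
# which compile() then forwards as
#   self.optimizer.get_updates(self.params, self.regularizer, self.constraint, train_loss)
assert len(model.params) == len(model.regularizer) == len(model.constraint) == 4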

keras/optimizers.py

Lines changed: 23 additions & 26 deletions
@@ -22,21 +22,7 @@ def get_gradients(self, cost, params):
             norm = T.sqrt(sum([T.sum(g**2) for g in grads]))
             grads = [clip_norm(g, c, norm) for g in grads]
 
-        new_grads = []
-        for p, g in zip(params, grads):
-            if hasattr(self, 'l1') and self.l1 > 0:
-                g += T.sgn(p) * self.l1
-
-            if hasattr(self, 'l2') and self.l2 > 0:
-                g += p * self.l2
-
-            if hasattr(self, 'maxnorm') and self.maxnorm > 0:
-                norms = T.sqrt(T.sum(T.sqr(p), axis=0))
-                desired = T.clip(norms, 0, self.maxnorm)
-                p = p * (desired / (1e-7 + norms))
-
-            new_grads.append(g)
-        return new_grads
+        return grads
 
 
 class SGD(Optimizer):
@@ -45,12 +31,13 @@ def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, *args, **kwar
         self.__dict__.update(locals())
         self.iterations = shared_scalar(0)
 
-    def get_updates(self, params, cost):
+    def get_updates(self, params, regularizers, constraints, cost):
         grads = self.get_gradients(cost, params)
         lr = self.lr * (1.0 / (1.0 + self.decay * self.iterations))
         updates = [(self.iterations, self.iterations+1.)]
 
-        for p, g in zip(params, grads):
+        for p, g, r, c in zip(params, grads, regularizers, constraints):
+            g = r(g,p)
             m = shared_zeros(p.get_value().shape) # momentum
             v = self.momentum * m - lr * g # velocity
             updates.append((m, v))
@@ -59,6 +46,8 @@ def get_updates(self, params, cost):
                 new_p = p + self.momentum * v - lr * g
             else:
                 new_p = p + v
+
+            new_p = c(new_p)
             updates.append((p, new_p))
         return updates
 
@@ -68,16 +57,18 @@ class RMSprop(Optimizer):
     def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs):
         self.__dict__.update(locals())
 
-    def get_updates(self, params, cost):
+    def get_updates(self, params, regularizers, constraints, cost):
         grads = self.get_gradients(cost, params)
         accumulators = [shared_zeros(p.get_value().shape) for p in params]
         updates = []
 
-        for p, g, a in zip(params, grads, accumulators):
+        for p, g, a, r, c in zip(params, grads, accumulators, regularizers, constraints):
+            g = r(g,p)
             new_a = self.rho * a + (1 - self.rho) * g ** 2 # update accumulator
             updates.append((a, new_a))
 
             new_p = p - self.lr * g / T.sqrt(new_a + self.epsilon)
+            new_p = c(new_p)
             updates.append((p, new_p))
         return updates
 
@@ -87,16 +78,18 @@ class Adagrad(Optimizer):
     def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs):
         self.__dict__.update(locals())
 
-    def get_updates(self, params, cost):
+    def get_updates(self, params, regularizers, constraints, cost):
         grads = self.get_gradients(cost, params)
         accumulators = [shared_zeros(p.get_value().shape) for p in params]
         updates = []
 
-        for p, g, a in zip(params, grads, accumulators):
+        for p, g, a, r, c in zip(params, grads, accumulators, regularizers, constraints):
+            g = r(g,p)
             new_a = a + g ** 2 # update accumulator
             updates.append((a, new_a))
 
             new_p = p - self.lr * g / T.sqrt(new_a + self.epsilon)
+            new_p = c(new_p)
             updates.append((p, new_p))
         return updates
 
@@ -108,20 +101,22 @@ class Adadelta(Optimizer):
     def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs):
         self.__dict__.update(locals())
 
-    def get_updates(self, params, cost):
+    def get_updates(self, params, regularizers, constraints, cost):
         grads = self.get_gradients(cost, params)
         accumulators = [shared_zeros(p.get_value().shape) for p in params]
         delta_accumulators = [shared_zeros(p.get_value().shape) for p in params]
         updates = []
 
-        for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
+        for p, g, a, d_a, r, c in zip(params, grads, accumulators, delta_accumulators, regularizers, constraints):
+            g = r(g,p)
             new_a = self.rho * a + (1 - self.rho) * g ** 2 # update accumulator
             updates.append((a, new_a))
 
             # use the new accumulator and the *old* delta_accumulator
             update = g * T.sqrt(d_a + self.epsilon) / T.sqrt(new_a + self.epsilon)
 
             new_p = p - self.lr * update
+            new_p = c(new_p)
             updates.append((p, new_p))
 
             # update delta_accumulator
@@ -142,7 +137,7 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8, kappa=1-1e-
         self.__dict__.update(locals())
         self.iterations = shared_scalar(0)
 
-    def get_updates(self, params, cost):
+    def get_updates(self, params, regularizers, constraints, cost):
         grads = self.get_gradients(cost, params)
         updates = [(self.iterations, self.iterations+1.)]
 
@@ -152,7 +147,8 @@ def get_updates(self, params, cost):
         # the update below seems missing from the paper, but is obviously required
        beta_2_t = self.beta_2 * (self.kappa**i)
 
-        for p, g in zip(params, grads):
+        for p, g, r, c in zip(params, grads, regularizers, constraints):
+            g = r(g,p)
             m = theano.shared(p.get_value() * 0.) # zero init of moment
             v = theano.shared(p.get_value() * 0.) # zero init of velocity
 
@@ -163,7 +159,8 @@ def get_updates(self, params, cost):
             v_b_t = v_t / (1 - beta_2_t)
 
             p_t = p - self.lr * m_b_t / (T.sqrt(v_b_t) + self.epsilon)
-
+
+            p_t = c(p_t)
             updates.append((m, m_t))
             updates.append((v, v_t))
             updates.append((p, p_t))
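To make the new contract explicit, here is a plain numpy sketch (hypothetical names, no Theano) of the loop every optimizer now runs: a regularizer is a callable r(g, p) returning a modified gradient, a constraint is a callable c(new_p) returning a projected parameter, and both lists are positionally aligned with params.

import numpy as np

def sgd_step(params, grads, regularizers, constraints, lr=0.01):
    # Mirrors the updated get_updates() pattern: regularize the gradient,
    # take the step, then project the result through the constraint.
    new_params = []
    for p, g, r, c in zip(params, grads, regularizers, constraints):
        g = r(g, p)          # e.g. g + lam * p for an L2 penalty
        new_p = p - lr * g
        new_p = c(new_p)     # e.g. max-norm clipping, or the identity
        new_params.append(new_p)
    return new_params

# numpy stand-ins for the helpers defined in keras/layers/core.py
l2_np = lambda lam: (lambda g, p: g + lam * p)
ident_np = lambda x, *rest: x

W, b = np.random.randn(3, 2), np.zeros(2)
gW, gb = np.ones((3, 2)), np.ones(2)
W, b = sgd_step([W, b], [gW, gb], [l2_np(0.01), ident_np], [ident_np, ident_np])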
