140 changes: 74 additions & 66 deletions keras/optimizers.py
@@ -6,6 +6,7 @@

import six
import copy
import numpy as np
from six.moves import zip

from . import backend as K
@@ -170,20 +171,19 @@ class SGD(Optimizer):
learning_rate: float >= 0. Learning rate.
momentum: float >= 0. Parameter that accelerates SGD
in the relevant direction and dampens oscillations.
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
"""

def __init__(self, learning_rate=0.01, momentum=0., decay=0.,
def __init__(self, learning_rate=0.01, momentum=0.,
nesterov=False, **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
learning_rate = kwargs.pop('lr', learning_rate)
self.initial_decay = kwargs.pop('decay', 0.0)
super(SGD, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.momentum = K.variable(momentum, name='momentum')
self.decay = K.variable(decay, name='decay')
self.initial_decay = decay
self.decay = K.variable(self.initial_decay, name='decay')
self.nesterov = nesterov

@interfaces.legacy_get_updates_support
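Note on the `lr` handling above: the old form `kwargs.pop('lr', None) or learning_rate` silently discarded a falsy value such as `lr=0.0`, because `0.0 or default` evaluates to the default, while the new form `kwargs.pop('lr', learning_rate)` keeps whatever the caller passed. A minimal, framework-free sketch of the difference (function names are illustrative only):

def old_lr_handling(learning_rate=0.01, **kwargs):
    # Old behaviour: a falsy lr (0, 0.0) falls back to the default.
    return kwargs.pop('lr', None) or learning_rate

def new_lr_handling(learning_rate=0.01, **kwargs):
    # New behaviour: the legacy `lr` kwarg wins whenever it is given, even 0.0.
    return kwargs.pop('lr', learning_rate)

print(old_lr_handling(lr=0.0))  # 0.01 -- explicit value lost
print(new_lr_handling(lr=0.0))  # 0.0  -- explicit value kept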
@@ -236,27 +236,22 @@ class RMSprop(Optimizer):
# Arguments
learning_rate: float >= 0. Learning rate.
rho: float >= 0.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.

# References
- [rmsprop: Divide the gradient by a running average of its recent magnitude
](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
"""

def __init__(self, learning_rate=0.001, rho=0.9, epsilon=None, decay=0.,
**kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.001, rho=0.9, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(RMSprop, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.rho = K.variable(rho, name='rho')
self.decay = K.variable(decay, name='decay')
self.decay = K.variable(self.initial_decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay

@interfaces.legacy_get_updates_support
@K.symbolic
@@ -266,7 +261,7 @@ def get_updates(self, loss, params):
dtype=K.dtype(p),
name='accumulator_' + str(i))
for (i, p) in enumerate(params)]
self.weights = accumulators
self.weights = [self.iterations] + accumulators
self.updates = [K.update_add(self.iterations, 1)]

lr = self.learning_rate
@@ -287,6 +282,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(p, new_p))
return self.updates

def set_weights(self, weights):
params = self.weights
# Override set_weights for backward compatibility with the Keras 2.2.4
# optimizer, whose saved weight list did not include `iterations` at the
# head. Reset the iteration count to 0 in that case.
if len(params) == len(weights) + 1:
weights = [np.array(0)] + weights
super(RMSprop, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'rho': float(K.get_value(self.rho)),
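Because `self.weights` now starts with the `iterations` variable, a weight list saved by the Keras 2.2.4 version of this optimizer (accumulators only) is one element short when it is restored, and the override above pads it with a fresh zero counter. A minimal sketch of just that padding logic, detached from the optimizer class (the function name and placeholder values are made up for illustration):

import numpy as np

def pad_legacy_optimizer_weights(current_params, saved_weights):
    # `current_params` mirrors the new layout: [iterations] + accumulators.
    # `saved_weights` may come from a 2.2.4 checkpoint with accumulators only.
    if len(current_params) == len(saved_weights) + 1:
        saved_weights = [np.array(0)] + saved_weights
    return saved_weights

current = ['iterations', 'acc_0', 'acc_1']             # stand-ins for K.variables
legacy = [np.zeros(3), np.zeros(3)]                    # two saved accumulators
restored = pad_legacy_optimizer_weights(current, legacy)
print(len(restored), int(restored[0]))                 # 3 0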
@@ -309,25 +313,21 @@ class Adagrad(Optimizer):

# Arguments
learning_rate: float >= 0. Initial learning rate.
epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.

# References
- [Adaptive Subgradient Methods for Online Learning and Stochastic
Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
"""

def __init__(self, learning_rate=0.01, epsilon=None, decay=0., **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.01, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adagrad, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.decay = K.variable(decay, name='decay')
self.decay = K.variable(self.initial_decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay

@interfaces.legacy_get_updates_support
@K.symbolic
@@ -336,7 +336,7 @@ def get_updates(self, loss, params):
shapes = [K.int_shape(p) for p in params]
accumulators = [K.zeros(shape, name='accumulator_' + str(i))
for (i, shape) in enumerate(shapes)]
self.weights = accumulators
self.weights = [self.iterations] + accumulators
self.updates = [K.update_add(self.iterations, 1)]

lr = self.learning_rate
@@ -356,6 +356,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(p, new_p))
return self.updates

def set_weights(self, weights):
params = self.weights
# Override set_weights for backward compatibility with the Keras 2.2.4
# optimizer, whose saved weight list did not include `iterations` at the
# head. Reset the iteration count to 0 in that case.
if len(params) == len(weights) + 1:
weights = [np.array(0)] + weights
super(Adagrad, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'decay': float(K.get_value(self.decay)),
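Keeping `iterations` in the saved state is not just bookkeeping: if the counter were reset to zero on reload, any learning-rate decay keyed to it would jump back to the base rate. A toy illustration, assuming the usual Keras inverse-time schedule lr_t = lr / (1 + decay * iterations) that these optimizers apply when `decay > 0` (the schedule itself sits in lines elided from this diff excerpt):

def inverse_time_lr(base_lr, decay, iterations):
    # Assumed schedule: lr_t = base_lr / (1 + decay * iterations).
    return base_lr * (1.0 / (1.0 + decay * iterations))

print(inverse_time_lr(0.01, 1e-3, 10000))  # ~0.000909, the rate before saving
print(inverse_time_lr(0.01, 1e-3, 0))      # 0.01, what a lost counter would restore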
@@ -383,27 +392,22 @@ class Adadelta(Optimizer):
It is recommended to leave it at the default value.
rho: float >= 0. Adadelta decay factor, corresponding to fraction of
gradient to keep at each time step.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Initial learning rate decay.

# References
- [Adadelta - an adaptive learning rate method](
https://arxiv.org/abs/1212.5701)
"""

def __init__(self, learning_rate=1.0, rho=0.95, epsilon=None, decay=0.,
**kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=1.0, rho=0.95, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adadelta, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.decay = K.variable(decay, name='decay')
self.decay = K.variable(self.initial_decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
if epsilon is None:
epsilon = K.epsilon()
self.rho = rho
self.epsilon = epsilon
self.initial_decay = decay

@interfaces.legacy_get_updates_support
@K.symbolic
@@ -414,7 +418,7 @@ def get_updates(self, loss, params):
for (i, shape) in enumerate(shapes)]
delta_accumulators = [K.zeros(shape, name='delta_accumulator_' + str(i))
for (i, shape) in enumerate(shapes)]
self.weights = accumulators + delta_accumulators
self.weights = [self.iterations] + accumulators + delta_accumulators
self.updates = [K.update_add(self.iterations, 1)]

lr = self.learning_rate
@@ -442,6 +446,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(d_a, new_d_a))
return self.updates

def set_weights(self, weights):
params = self.weights
# Override set_weights for backward compatibility with the Keras 2.2.4
# optimizer, whose saved weight list did not include `iterations` at the
# head. Reset the iteration count to 0 in that case.
if len(params) == len(weights) + 1:
weights = [np.array(0)] + weights
super(Adadelta, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'rho': self.rho,
@@ -460,8 +473,6 @@ class Adam(Optimizer):
learning_rate: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
amsgrad: boolean. Whether to apply the AMSGrad variant of this
algorithm from the paper "On the Convergence of Adam and
Beyond".
@@ -474,19 +485,17 @@
"""

def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., amsgrad=False, **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
amsgrad=False, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.decay = K.variable(self.initial_decay, name='decay')
self.amsgrad = amsgrad

@interfaces.legacy_get_updates_support
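The `decay` argument follows the same pattern as `lr`: it is no longer part of the documented signature but is still accepted through `**kwargs`, landing in `self.initial_decay` before the remaining kwargs reach the base `Optimizer`. A small sketch of that capture step with the surrounding class machinery omitted (names are illustrative; `clipnorm` appears only as an example of a kwarg that should pass through untouched, and 1e-7 stands in for `K.epsilon()`):

def capture_legacy_kwargs(**kwargs):
    # Mirrors the constructor above: pull out the deprecated kwargs first,
    # leave everything else for super().__init__(**kwargs).
    initial_decay = kwargs.pop('decay', 0.0)
    epsilon = kwargs.pop('epsilon', 1e-7)    # stand-in for K.epsilon()
    learning_rate = kwargs.pop('lr', 0.001)
    return initial_decay, epsilon, learning_rate, kwargs

decay, eps, lr, rest = capture_legacy_kwargs(decay=1e-4, clipnorm=1.0)
print(decay, lr, sorted(rest))  # 0.0001 0.001 ['clipnorm']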
@@ -565,28 +574,23 @@ class Adamax(Optimizer):
learning_rate: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.

# References
- [Adam - A Method for Stochastic Optimization](
https://arxiv.org/abs/1412.6980v8)
"""

def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999, **kwargs):
self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Adamax, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
self.decay = K.variable(self.initial_decay, name='decay')

@interfaces.legacy_get_updates_support
@K.symbolic
@@ -652,29 +656,24 @@ class Nadam(Optimizer):
learning_rate: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
schedule_decay: float, 0 < schedule_decay < 1.

# References
- [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
- [On the importance of initialization and momentum in deep learning](
http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
"""

def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
epsilon=None, schedule_decay=0.004, **kwargs):
learning_rate = kwargs.pop('lr', None) or learning_rate
def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999, **kwargs):
self.schedule_decay = kwargs.pop('schedule_decay', 0.004)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
learning_rate = kwargs.pop('lr', learning_rate)
super(Nadam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.m_schedule = K.variable(1., name='m_schedule')
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
if epsilon is None:
epsilon = K.epsilon()
self.epsilon = epsilon
self.schedule_decay = schedule_decay

@interfaces.legacy_get_updates_support
@K.symbolic
@@ -699,7 +698,7 @@ def get_updates(self, loss, params):
vs = [K.zeros(shape, name='v_' + str(i))
for (i, shape) in enumerate(shapes)]

self.weights = [self.iterations] + ms + vs
self.weights = [self.iterations, self.m_schedule] + ms + vs

for p, g, m, v in zip(params, grads, ms, vs):
# the following equations given in [1]
Expand All @@ -725,6 +724,15 @@ def get_updates(self, loss, params):
self.updates.append(K.update(p, new_p))
return self.updates

def set_weights(self, weights):
params = self.weights
# Override set_weights for backward compatibility with the Keras 2.2.4
# optimizer, whose saved weight list did not include `m_schedule` after
# `iterations`. Reset m_schedule to 1 in that case.
if len(params) == len(weights) + 1:
weights = [weights[0]] + [np.array(1.)] + weights[1:]
super(Nadam, self).set_weights(weights)

def get_config(self):
config = {'learning_rate': float(K.get_value(self.learning_rate)),
'beta_1': float(K.get_value(self.beta_1)),
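Nadam's compatibility shim differs slightly from the others: the 2.2.4 weight list already started with `iterations` but did not store `m_schedule`, so the missing entry is spliced in at position 1 and reset to 1. A standalone sketch of that list surgery (the helper name and placeholder values are invented for illustration):

import numpy as np

def pad_legacy_nadam_weights(current_params, saved_weights):
    # New layout: [iterations, m_schedule] + ms + vs.
    # Old layout: [iterations] + ms + vs, so m_schedule is missing.
    if len(current_params) == len(saved_weights) + 1:
        saved_weights = [saved_weights[0]] + [np.array(1.)] + saved_weights[1:]
    return saved_weights

old = [np.array(42)] + [np.zeros(2)] * 4               # iterations + 2 m's + 2 v's
new_layout = ['iterations', 'm_schedule', 'm_0', 'm_1', 'v_0', 'v_1']
restored = pad_legacy_nadam_weights(new_layout, old)
print(len(restored), float(restored[1]))               # 6 1.0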