Skip to content

Commit 8606edf

Browse files
author
Matthew Peters
committed
Allow passing kwargs down to K.function
1 parent 71d46b7 commit 8606edf

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

keras/backend/tensorflow_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def __call__(self, inputs):
402402
return updated[:len(self.outputs)]
403403

404404

405-
def function(inputs, outputs, updates=[]):
405+
def function(inputs, outputs, updates=[], **kwargs):
406406
return Function(inputs, outputs, updates=updates)
407407

408408

keras/backend/theano_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,8 @@ def __call__(self, inputs):
449449
return self.function(*inputs)
450450

451451

452-
def function(inputs, outputs, updates=[]):
453-
return Function(inputs, outputs, updates=updates)
452+
def function(inputs, outputs, updates=[], **kwargs):
453+
return Function(inputs, outputs, updates=updates, **kwargs)
454454

455455

456456
def gradients(loss, variables):

keras/models.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,8 @@ class Sequential(Model, containers.Sequential):
471471
'''
472472
def compile(self, optimizer, loss,
473473
class_mode="categorical",
474-
sample_weight_mode=None):
474+
sample_weight_mode=None,
475+
**kwargs):
475476
'''Configure the learning process.
476477
477478
# Arguments
@@ -485,6 +486,8 @@ def compile(self, optimizer, loss,
485486
sample_weight_mode: if you need to do timestep-wise
486487
sample weighting (2D weights), set this to "temporal".
487488
"None" defaults to sample-wise weights (1D).
489+
kwargs: for Theano backend, these are passed into K.function.
490+
Ignored for Tensorflow backend.
488491
'''
489492
self.optimizer = optimizers.get(optimizer)
490493
self.sample_weight_mode = sample_weight_mode
@@ -544,11 +547,18 @@ def compile(self, optimizer, loss,
544547
test_ins = [self.X_test, self.y, self.weights]
545548
predict_ins = [self.X_test]
546549

547-
self._train = K.function(train_ins, [train_loss], updates=updates)
548-
self._train_with_acc = K.function(train_ins, [train_loss, train_accuracy], updates=updates)
549-
self._predict = K.function(predict_ins, [self.y_test], updates=self.state_updates)
550-
self._test = K.function(test_ins, [test_loss], updates=self.state_updates)
551-
self._test_with_acc = K.function(test_ins, [test_loss, test_accuracy], updates=self.state_updates)
550+
self._train = K.function(train_ins, [train_loss],
551+
updates=updates, **kwargs)
552+
self._train_with_acc = K.function(train_ins,
553+
[train_loss, train_accuracy],
554+
updates=updates, **kwargs)
555+
self._predict = K.function(predict_ins, [self.y_test],
556+
updates=self.state_updates, **kwargs)
557+
self._test = K.function(test_ins, [test_loss],
558+
updates=self.state_updates, **kwargs)
559+
self._test_with_acc = K.function(test_ins,
560+
[test_loss, test_accuracy],
561+
updates=self.state_updates, **kwargs)
552562

553563
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
554564
validation_split=0., validation_data=None, shuffle=True,
@@ -1173,7 +1183,7 @@ class Graph(Model, containers.Graph):
11731183
11741184
Inherits from `containers.Graph`.
11751185
'''
1176-
def compile(self, optimizer, loss, sample_weight_modes={}):
1186+
def compile(self, optimizer, loss, sample_weight_modes={}, **kwargs):
11771187
'''Configure the learning process.
11781188
11791189
# Arguments
@@ -1188,6 +1198,8 @@ def compile(self, optimizer, loss, sample_weight_modes={}):
11881198
timestep-wise loss weighting on one of your graph outputs,
11891199
you will need to set the sample weight mode for this output
11901200
to "temporal".
1201+
kwargs: for Theano backend, these are passed into K.function.
1202+
Ignored for Tensorflow backend.
11911203
'''
11921204
assert type(loss) is dict, 'The "loss" argument should be a dictionary.'
11931205
assert type(sample_weight_modes) is dict, 'The "sample_weight_modes" argument should be a dictionary.'
@@ -1236,10 +1248,12 @@ def compile(self, optimizer, loss, sample_weight_modes={}):
12361248
updates += self.updates
12371249
self.loss = loss
12381250

1239-
self._train = K.function(train_ins, [train_loss], updates=updates)
1240-
self._test = K.function(test_ins, [test_loss], updates=self.state_updates)
1251+
self._train = K.function(train_ins, [train_loss],
1252+
updates=updates, **kwargs)
1253+
self._test = K.function(test_ins, [test_loss],
1254+
updates=self.state_updates, **kwargs)
12411255
self._predict = K.function(inputs=ins, outputs=ys_test,
1242-
updates=self.state_updates)
1256+
updates=self.state_updates, **kwargs)
12431257

12441258
def fit(self, data, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
12451259
validation_split=0., validation_data=None, shuffle=True,

0 commit comments

Comments
 (0)