@@ -471,7 +471,8 @@ class Sequential(Model, containers.Sequential):
471471 '''
472472 def compile (self , optimizer , loss ,
473473 class_mode = "categorical" ,
474- sample_weight_mode = None ):
474+ sample_weight_mode = None ,
475+ ** kwargs ):
475476 '''Configure the learning process.
476477
477478 # Arguments
@@ -485,6 +486,8 @@ def compile(self, optimizer, loss,
485486 sample_weight_mode: if you need to do timestep-wise
486487 sample weighting (2D weights), set this to "temporal".
487488 "None" defaults to sample-wise weights (1D).
489+ kwargs: for Theano backend, these are passed into K.function.
490+ Ignored for Tensorflow backend.
488491 '''
489492 self .optimizer = optimizers .get (optimizer )
490493 self .sample_weight_mode = sample_weight_mode
@@ -544,11 +547,18 @@ def compile(self, optimizer, loss,
544547 test_ins = [self .X_test , self .y , self .weights ]
545548 predict_ins = [self .X_test ]
546549
547- self ._train = K .function (train_ins , [train_loss ], updates = updates )
548- self ._train_with_acc = K .function (train_ins , [train_loss , train_accuracy ], updates = updates )
549- self ._predict = K .function (predict_ins , [self .y_test ], updates = self .state_updates )
550- self ._test = K .function (test_ins , [test_loss ], updates = self .state_updates )
551- self ._test_with_acc = K .function (test_ins , [test_loss , test_accuracy ], updates = self .state_updates )
550+ self ._train = K .function (train_ins , [train_loss ],
551+ updates = updates , ** kwargs )
552+ self ._train_with_acc = K .function (train_ins ,
553+ [train_loss , train_accuracy ],
554+ updates = updates , ** kwargs )
555+ self ._predict = K .function (predict_ins , [self .y_test ],
556+ updates = self .state_updates , ** kwargs )
557+ self ._test = K .function (test_ins , [test_loss ],
558+ updates = self .state_updates , ** kwargs )
559+ self ._test_with_acc = K .function (test_ins ,
560+ [test_loss , test_accuracy ],
561+ updates = self .state_updates , ** kwargs )
552562
553563 def fit (self , X , y , batch_size = 128 , nb_epoch = 100 , verbose = 1 , callbacks = [],
554564 validation_split = 0. , validation_data = None , shuffle = True ,
@@ -1173,7 +1183,7 @@ class Graph(Model, containers.Graph):
11731183
11741184 Inherits from `containers.Graph`.
11751185 '''
1176- def compile (self , optimizer , loss , sample_weight_modes = {}):
1186+ def compile (self , optimizer , loss , sample_weight_modes = {}, ** kwargs ):
11771187 '''Configure the learning process.
11781188
11791189 # Arguments
@@ -1188,6 +1198,8 @@ def compile(self, optimizer, loss, sample_weight_modes={}):
11881198 timestep-wise loss weighting on one of your graph outputs,
11891199 you will need to set the sample weight mode for this output
11901200 to "temporal".
1201+ kwargs: for Theano backend, these are passed into K.function.
1202+ Ignored for Tensorflow backend.
11911203 '''
11921204 assert type (loss ) is dict , 'The "loss" argument should be a dictionary.'
11931205 assert type (sample_weight_modes ) is dict , 'The "sample_weight_modes" argument should be a dictionary.'
@@ -1236,10 +1248,12 @@ def compile(self, optimizer, loss, sample_weight_modes={}):
12361248 updates += self .updates
12371249 self .loss = loss
12381250
1239- self ._train = K .function (train_ins , [train_loss ], updates = updates )
1240- self ._test = K .function (test_ins , [test_loss ], updates = self .state_updates )
1251+ self ._train = K .function (train_ins , [train_loss ],
1252+ updates = updates , ** kwargs )
1253+ self ._test = K .function (test_ins , [test_loss ],
1254+ updates = self .state_updates , ** kwargs )
12411255 self ._predict = K .function (inputs = ins , outputs = ys_test ,
1242- updates = self .state_updates )
1256+ updates = self .state_updates , ** kwargs )
12431257
12441258 def fit (self , data , batch_size = 128 , nb_epoch = 100 , verbose = 1 , callbacks = [],
12451259 validation_split = 0. , validation_data = None , shuffle = True ,
0 commit comments