Skip to content
Prev Previous commit
Next Next commit
fix custom objects
  • Loading branch information
somefreestring committed Jan 7, 2020
commit 7ec0a75afb841687402db691bcc857016e223a04
12 changes: 8 additions & 4 deletions test/unit_tests/model/test__save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,21 @@ def test_save_load_keras_custom_loss(self):
features_and_labels = pmu.FeaturesAndLabels(["a"], ["b"])
name = '/tmp/pandas-ml-utils-unittest-test_model_keras_custom_loss'

def my_custom_loss(x, y):
import keras.backend as K
return K.sum(x - y)
def loss_provider(foo):
def my_custom_loss(x, y):
print(foo)
import keras.backend as K
return K.sum(x - y)

return my_custom_loss

def keras_model_provider():
model = Sequential()
model.add(Dense(1, input_dim=1, activation='sigmoid'))
return model

"""when"""
fit = df.fit(pmu.KerasModel(keras_model_provider, features_and_labels, optimizer='adam', loss=my_custom_loss, verbose=0))
fit = df.fit(pmu.KerasModel(keras_model_provider, features_and_labels, optimizer='adam', loss=loss_provider("bar"), verbose=0))
fitted_model = fit.model

fit.save_model(name)
Expand Down