Skip to content
Merged
Prev Previous commit
Next Next commit
enable custom objects for keras model
  • Loading branch information
somefreestring committed Jan 5, 2020
commit 739dffd8c049d78b2a4724db654aa8c68cce652b
17 changes: 13 additions & 4 deletions pandas_ml_utils/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pandas_ml_utils.summary.summary import Summary
from pandas_ml_utils.model.features_and_labels.features_and_labels import FeaturesAndLabels
from pandas_ml_utils.utils.functions import suitable_kwargs
from pandas_ml_utils.utils.functions import suitable_kwargs, call_with_suitable_kwargs

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -245,8 +245,17 @@ def __init__(self,
else:
self.is_tensorflow = False

# create keras model
provider_args = suitable_kwargs(keras_compiled_model_provider, **kwargs)
self.keras_model = self._exec_within_session(keras_compiled_model_provider, **provider_args)
keras_model = self._exec_within_session(keras_compiled_model_provider, **provider_args)

# eventually compile keras model
if not keras_model.optimizer:
compile_args = suitable_kwargs(keras_model.compile, **kwargs)
self._exec_within_session(keras_model.compile, **compile_args)

# set all members
self.keras_model = keras_model
self.epochs = epochs
self.callbacks = callbacks
self.history = None
Expand Down Expand Up @@ -333,9 +342,9 @@ def __setstate__(self, state):
with self.graph.as_default():
self.session = tf.Session(graph=self.graph)
K.set_session(self.session)
self.keras_model = load_model(tmp_keras_file)
self.keras_model = load_model(tmp_keras_file, custom_objects=self.kwargs)
else:
self.keras_model = load_model(tmp_keras_file)
self.keras_model = load_model(tmp_keras_file, custom_objects=self.kwargs)

# clean up temp file
with contextlib.suppress(OSError):
Expand Down
4 changes: 2 additions & 2 deletions pandas_ml_utils/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def suitable_kwargs(func, **kwargs):


def call_with_suitable_kwargs(func, **kwargs):
args = suitable_kwargs(kwargs)
return func(args)
args = suitable_kwargs(func, **kwargs)
return func(**args)


def call_callable_dyamic_args(func, *args):
Expand Down
25 changes: 25 additions & 0 deletions test/unit_tests/model/test__save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,28 @@ def test_model_with_LazyDataFrame_copy(self):
self.assertEqual(model.kwargs["ldf"].kwargs['foo'](None), 'bar')
self.assertEqual(model2.kwargs["ldf"].kwargs['foo'](None), 'bar')

def test_save_load_keras_custom_loss(self):
"""given"""
features_and_labels = pmu.FeaturesAndLabels(["a"], ["b"])
name = '/tmp/pandas-ml-utils-unittest-test_model_keras_custom_loss'

def my_custom_loss(x, y):
import keras.backend as K
return K.sum(x - y)

def keras_model_provider():
model = Sequential()
model.add(Dense(1, input_dim=1, activation='sigmoid'))
return model

"""when"""
# FIXME my_custom_loss should have proper name
fit = df.fit(pmu.KerasModel(keras_model_provider, features_and_labels, optimizer='adam', loss=my_custom_loss, verbose=0, my_custom_loss=my_custom_loss))
fitted_model = fit.model

fit.save_model(name)
restored_model = pmu.Model.load(name)

"""then"""
pd.testing.assert_frame_equal(df.predict(fitted_model), df.predict(restored_model))
pd.testing.assert_frame_equal(df.backtest(fitted_model).df, df.backtest(restored_model).df)