Skip to content
Merged
Prev Previous commit
Next Next commit
fix custom objects
  • Loading branch information
somefreestring committed Jan 6, 2020
commit d1b6ce0a761c7090e198f9dc602248fb92364c67
5 changes: 3 additions & 2 deletions pandas_ml_utils/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,16 +335,17 @@ def __setstate__(self, state):
self.__dict__.update(state)

# restore keras model and tensorflow session if needed
custom_objects = {v.__name__: v for v in self.kwargs.values() if hasattr(v, "__name__")}
if self.is_tensorflow:
from keras import backend as K
import tensorflow as tf
self.graph = tf.Graph()
with self.graph.as_default():
self.session = tf.Session(graph=self.graph)
K.set_session(self.session)
self.keras_model = load_model(tmp_keras_file, custom_objects=self.kwargs)
self.keras_model = load_model(tmp_keras_file, custom_objects=custom_objects)
else:
self.keras_model = load_model(tmp_keras_file, custom_objects=self.kwargs)
self.keras_model = load_model(tmp_keras_file, custom_objects=custom_objects)

# clean up temp file
with contextlib.suppress(OSError):
Expand Down
3 changes: 1 addition & 2 deletions test/unit_tests/model/test__save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def keras_model_provider():
return model

"""when"""
# FIXME my_custom_loss should have proper name
fit = df.fit(pmu.KerasModel(keras_model_provider, features_and_labels, optimizer='adam', loss=my_custom_loss, verbose=0, my_custom_loss=my_custom_loss))
fit = df.fit(pmu.KerasModel(keras_model_provider, features_and_labels, optimizer='adam', loss=my_custom_loss, verbose=0))
fitted_model = fit.model

fit.save_model(name)
Expand Down