Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Export trained model
Added option to export model
  • Loading branch information
mathetes87 committed Feb 15, 2018
commit c57f941bcdfb4b249707ac9ab86230d3dfdc3099
13 changes: 13 additions & 0 deletions DeepFM.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from time import time
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
from yellowfin import YFOptimizer
import os, sys


class DeepFM(BaseEstimator, TransformerMixin):
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(self, feature_size, field_size,
self.l2_reg = l2_reg

self.epoch = epoch
self.current_epoch = 0
self.batch_size = batch_size
self.learning_rate = learning_rate
self.optimizer_type = optimizer_type
Expand Down Expand Up @@ -285,6 +287,7 @@ def fit(self, Xi_train, Xv_train, y_train,
# evaluate training and validation datasets
train_result = self.evaluate(Xi_train, Xv_train, y_train)
self.train_result.append(train_result)
self.current_epoch += 1
if has_valid:
valid_result = self.evaluate(Xi_valid, Xv_valid, y_valid)
self.valid_result.append(valid_result)
Expand Down Expand Up @@ -383,3 +386,13 @@ def evaluate(self, Xi, Xv, y):
y_pred = self.predict(Xi, Xv)
return self.eval_metric(y, y_pred)

def export_model(self, filename='DeepFM_model.epoch'):
"""
:param filename: name of the model to be saved
"""
filepath = os.path.join('models', filename)
with self.sess.graph.as_default():
self.sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(self.sess, filepath, global_step=self.current_epoch)

39 changes: 39 additions & 0 deletions train_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import tensorflow as tf
from sklearn.metrics import roc_auc_score
from DeepFM import DeepFM

# params
dfm_params = {
"use_fm": True,
"use_deep": True,
"embedding_size": 8,
"dropout_fm": [1.0, 1.0],
"deep_layers": [32, 32],
"dropout_deep": [0.5, 0.5, 0.5],
"deep_layers_activation": tf.nn.relu,
"epoch": 30,
"batch_size": 512*4,
"learning_rate": 0.001,
"optimizer_type": "adam",
"batch_norm": 1,
"batch_norm_decay": 0.995,
"l2_reg": 0.01,
"verbose": True,
"eval_metric": roc_auc_score,
"random_seed": 2018
}
# prepare training and validation data in the required format
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)
# init a DeepFM model
dfm = DeepFM(feature_size, field_size, **dfm_params)
# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)
# export model
dfm.export_model()
# make prediction
print dfm.predict(Xi_valid, Xv_valid)
# evaluate a trained model
print dfm.evaluate(Xi_valid, Xv_valid, y_valid)