Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
allow multi model with different kwargs per target
  • Loading branch information
somefreestring committed Feb 10, 2020
commit 670d81ec76a5958a04471c4a27125f69de13e93f
22 changes: 16 additions & 6 deletions pandas_ml_utils/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
import uuid
from copy import deepcopy
from typing import List, Callable, TYPE_CHECKING, Tuple
from typing import List, Callable, TYPE_CHECKING, Tuple, Dict, Any

import dill as pickle
import numpy as np
Expand Down Expand Up @@ -257,8 +257,11 @@ def __init__(self,
for i in range(1, len(keras_model)):
if hasattr(keras_model[i], "__name__"):
self.custom_objects[keras_model[i].__name__] = keras_model[i]
else:
raise ValueError("keras custom object must have a __name__")

keras_model = keras_model[0]
_log.warning(f"keras model using custom objects {self.custom_objects}")

# eventually compile keras model
if not keras_model.optimizer:
Expand Down Expand Up @@ -378,7 +381,7 @@ def __call__(self, *args, **kwargs):
self.summary_provider,
self.epochs,
deepcopy(self.callbacks),
**deepcopy(self.kwargs), **kwargs)
**{**deepcopy(self.kwargs), **kwargs})

# copy weights before return
new_model.set_weights(self.get_weights())
Expand All @@ -390,19 +393,26 @@ class MultiModel(Model):
def __init__(self,
model_provider: Model,
summary_provider: Callable[[pd.DataFrame], Summary] = Summary,
loss_alpha: float = 0.5):
loss_alpha: float = 0.5,
target_kwargs: Dict[str, Any] = None):
super().__init__(model_provider.features_and_labels, summary_provider, **model_provider.kwargs)

if isinstance(model_provider, MultiModel):
raise ValueError("Nesting Multi Models is not supported, you might use a flat structure of all your models")

self.model_provider = model_provider
self.models = {target: model_provider() for target in self.features_and_labels.labels.keys()}
self.target_kwargs = target_kwargs
self.loss_alpha = loss_alpha

if target_kwargs:
self.models = {target: model_provider(**kwargs) for target, kwargs in target_kwargs.items()}
else:
self.models = {target: model_provider() for target in self.features_and_labels.labels.keys()}

def fit(self, x, y, x_val, y_val, df_index_train, df_index_test) -> float:
losses = []
pos = 0
# FIXME use the features and labels of each individual model
for target, labels in self.features_and_labels.labels.items():
index = range(pos, pos + len(labels))
target_y = y[:,index]
Expand Down Expand Up @@ -438,10 +448,10 @@ def predict_as_column_matrix(target):
axis=1)

def __call__(self, *args, **kwargs):
new_multi_model = MultiModel(self.model_provider, self.summary_provider, self.loss_alpha, **self.kwargs)
new_multi_model = MultiModel(self.model_provider, self.summary_provider, self.loss_alpha, self.target_kwargs)

if kwargs:
new_multi_model.models = {target: self.model_provider(**kwargs) for target in self.features_and_labels.get_goals().keys()}
raise ValueError("kwargs on cloning multi model ist currently not supported!")

return new_multi_model

Expand Down