Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add kwargs to multi model
  • Loading branch information
somefreestring committed Feb 9, 2020
commit 6e8dcfbf9e9d91f6fab01bea1c0a87c3a6433754
10 changes: 5 additions & 5 deletions pandas_ml_utils/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,15 @@ class MultiModel(Model):
def __init__(self,
             model_provider: Model,
             summary_provider: Callable[[pd.DataFrame], Summary] = Summary,
             loss_alpha: float = 0.5):
    """Build one sub-model per target label from a single model provider.

    :param model_provider: a ``Model`` factory; called once per target label to
        create the sub-models. Its ``features_and_labels`` and its ``kwargs``
        are forwarded to the base-class constructor (the kwargs forwarding is
        what this commit adds).
    :param summary_provider: callable turning a prediction ``pd.DataFrame``
        into a ``Summary`` (defaults to ``Summary`` itself).
    :param loss_alpha: blend factor used when aggregating sub-model losses —
        0.0 reports the mean loss, 1.0 the max loss; values in between
        interpolate. Renamed from ``alpha`` in this commit; presumably
        expected to lie in [0, 1] — not validated here.
    :raises ValueError: if ``model_provider`` is itself a ``MultiModel``;
        nesting is not supported, use a flat structure of models instead.
    """
    super().__init__(model_provider.features_and_labels, summary_provider, **model_provider.kwargs)

    if isinstance(model_provider, MultiModel):
        raise ValueError("Nesting Multi Models is not supported, you might use a flat structure of all your models")

    self.model_provider = model_provider
    # one independent sub-model per target label key
    self.models = {target: model_provider() for target in self.features_and_labels.labels.keys()}
    self.loss_alpha = loss_alpha

def fit(self, x, y, x_val, y_val, df_index_train, df_index_test) -> float:
losses = []
Expand All @@ -410,7 +410,7 @@ def fit(self, x, y, x_val, y_val, df_index_train, df_index_test) -> float:
pos += len(labels)

losses = np.array(losses)
a = self.alpha
a = self.loss_alpha

# return weighted loss between mean and max loss
return (losses.mean() * (1 - a) + a * losses.max()) if len(losses) > 0 else None
Expand All @@ -436,7 +436,7 @@ def predict_as_column_matrix(target):
axis=1)

def __call__(self, *args, **kwargs):
new_multi_model = MultiModel(self.model_provider, self.summary_provider, self.alpha)
new_multi_model = MultiModel(self.model_provider, self.summary_provider, self.loss_alpha, **self.kwargs)

if kwargs:
new_multi_model.models = {target: self.model_provider(**kwargs) for target in self.features_and_labels.get_goals().keys()}
Expand Down