Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add target kwargs for multi model
  • Loading branch information
somefreestring committed Feb 11, 2020
commit 541cc2f699ed7bc3cbd4d6f8bad9d586f333bd47
23 changes: 7 additions & 16 deletions pandas_ml_utils/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,30 +397,21 @@ def __init__(self,
loss_alpha: float = 0.5,
target_kwargs: Dict[str, Dict[str, Any]] = None):
assert isinstance(model_provider.features_and_labels.labels, (TargetLabelEncoder, Dict))
super().__init__(model_provider.features_and_labels, # FIXME {target: self.features_and_labels.labels(kwargs) for target, kwargs in target_kwargs.items()} if target_kwargs else model_provider.features_and_labels
summary_provider,
**model_provider.kwargs)
super().__init__(
# if we have a target args and a target encoder then we need generate multiple targets with different kwargs
{self.features_and_labels.with_labels(self.features_and_labels.labels.with_kwargs(kwargs)) for target, kwargs in target_kwargs.items()} \
if target_kwargs and isinstance(model_provider.features_and_labels.labels, TargetLabelEncoder) else model_provider.features_and_labels,
summary_provider,
**model_provider.kwargs)

if isinstance(model_provider, MultiModel):
raise ValueError("Nesting Multi Models is not supported, you might use a flat structure of all your models")

self.models = {target: model_provider() for target in self.features_and_labels.labels.keys()}
self.model_provider = model_provider
self.target_kwargs = target_kwargs
self.loss_alpha = loss_alpha

if target_kwargs:
self.models = {target: model_provider(**kwargs) for target, kwargs in target_kwargs.items()}

# FIXME we need to fix current models labels -> wrap all encoders/labels into a dict
# this is a bit ugly as this will only work with label encoders what if we dont use encoders?
kwargs = {}
target = self.features_and_labels.with_labels(self.features_and_labels.labels.with_kwargs(kwargs))

{target: self.features_and_labels.labels.with_kwargs(kwargs) for target, kwargs in target_kwargs.items()}

else:
self.models = {target: model_provider() for target in self.features_and_labels.labels.keys()}

def fit(self, x, y, x_val, y_val, df_index_train, df_index_test) -> float:
losses = []
pos = 0
Expand Down