1
- from sklearn .base import BaseEstimator , ClassifierMixin , RegressorMixin
2
1
import importlib
2
+ import nnetsauce as ns
3
+ from sklearn .base import BaseEstimator
3
4
4
5
5
6
class BaseModel (BaseEstimator ):
6
7
"""
7
8
Base class for dynamically loading and wrapping scikit-learn models.
8
9
"""
9
- def __init__ (self , base_model , ** kwargs ):
10
+ # Custom parameters that should only be passed to nnetsauce models
11
+ CUSTOM_PARAMS = [
12
+ 'n_hidden_features' ,
13
+ 'activation_name' ,
14
+ 'bias' ,
15
+ 'dropout' ,
16
+ 'direct_link' ,
17
+ 'n_clusters' ,
18
+ 'cluster_encode' ,
19
+ 'type_clust'
20
+ ]
21
+
22
+ def __init__ (self , base_model , custom = False , ** kwargs ):
10
23
"""
11
24
Initialize a scikit-learn model dynamically.
12
25
13
26
Parameters:
14
27
- base_model (str): The class name of the scikit-learn model (e.g., 'LogisticRegression').
15
- - **kwargs: Additional parameters to pass to the scikit-learn model constructor.
28
+ - custom (bool): Whether the model is a custom nnetsauce model.
29
+ - **kwargs: Additional parameters to pass to the model constructor.
16
30
"""
17
31
sklearn_modules = [
18
32
"linear_model" ,
@@ -27,8 +41,16 @@ def __init__(self, base_model, **kwargs):
27
41
"kernel_ridge" ,
28
42
]
29
43
self .base_model = base_model
30
- self .model_params = kwargs
31
- self .model = self ._load_model (base_model , sklearn_modules )(** kwargs )
44
+ self .custom = custom
45
+
46
+ # Split kwargs into base and custom parameters
47
+ self .base_kwargs = {k : v for k , v in kwargs .items () if k not in self .CUSTOM_PARAMS }
48
+ self .custom_kwargs = {k : v for k , v in kwargs .items () if k in self .CUSTOM_PARAMS }
49
+
50
+ # Initialize only the base model here
51
+ self .model = self ._load_model (base_model , sklearn_modules )(** self .base_kwargs )
52
+
53
+ # Custom model wrapping is handled in derived classes
32
54
33
55
def _load_model (self , base_model , modules ):
34
56
"""
@@ -57,7 +79,7 @@ def fit(self, X, y, **kwargs):
57
79
Parameters:
58
80
- X (array-like): Training data features.
59
81
- y (array-like): Target values.
60
- - **kwargs: Additional parameters to pass to the scikit-learn model fit method.
82
+ - **kwargs: Additional parameters to pass to the model fit method.
61
83
"""
62
84
self .model .fit (X , y , ** kwargs )
63
85
return self
@@ -68,7 +90,7 @@ def predict(self, X, **kwargs):
68
90
69
91
Parameters:
70
92
- X (array-like): Input data.
71
- - **kwargs: Additional parameters to pass to the scikit-learn model predict method.
93
+ - **kwargs: Additional parameters to pass to the model predict method.
72
94
Returns:
73
95
- array-like: Predictions.
74
96
"""
@@ -81,7 +103,7 @@ def score(self, X, y, **kwargs):
81
103
Parameters:
82
104
- X (array-like): Test data features.
83
105
- y (array-like): True labels.
84
- - **kwargs: Additional parameters to pass to the scikit-learn model score method.
106
+ - **kwargs: Additional parameters to pass to the model score method.
85
107
86
108
Returns:
87
109
- float: The score.
0 commit comments