Skip to content

Commit cd7ede0

Browse files
update docstrings
1 parent d490e5b commit cd7ede0

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

src/tisthemachinelearner/base.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@ def __init__(self, base_model, custom=False, **kwargs):
2525
Initialize a scikit-learn model dynamically.
2626
2727
Parameters:
28+
2829
- base_model (str): The class name of the scikit-learn model (e.g., 'LogisticRegression').
30+
2931
- custom (bool): Whether the model is a custom nnetsauce model.
32+
3033
- **kwargs: Additional parameters to pass to the model constructor.
34+
3135
"""
3236
modules = [
3337
"linear_model",
@@ -61,11 +65,15 @@ def _load_model(self, base_model, modules):
6165
Load a model class from scikit-learn modules.
6266
6367
Parameters:
68+
6469
- base_model (str): The class name of the scikit-learn model.
70+
6571
- modules (list): List of scikit-learn submodules to search.
6672
6773
Returns:
74+
6875
- class: The loaded scikit-learn model class.
76+
6977
"""
7078
for module_name in modules:
7179
try:
@@ -84,9 +92,14 @@ def fit(self, X, y, **kwargs):
8492
Fit the model to the training data.
8593
8694
Parameters:
95+
8796
- X (array-like): Training data features.
97+
8898
- y (array-like): Target values.
89-
- **kwargs: Additional parameters to pass to the model fit method.
99+
100+
- **kwargs: Additional parameters to pass to the
101+
model fit method.
102+
90103
"""
91104
self.model.fit(X, y, **kwargs)
92105
return self
@@ -96,10 +109,15 @@ def predict(self, X, **kwargs):
96109
Predict using the trained model.
97110
98111
Parameters:
112+
99113
- X (array-like): Input data.
114+
100115
- **kwargs: Additional parameters to pass to the model predict method.
116+
101117
Returns:
118+
102119
- array-like: Predictions.
120+
103121
"""
104122
return self.model.predict(X, **kwargs)
105123

@@ -108,11 +126,15 @@ def score(self, X, y, **kwargs):
108126
Return the score of the model on the given test data and labels.
109127
110128
Parameters:
129+
111130
- X (array-like): Test data features.
131+
112132
- y (array-like): True labels.
133+
113134
- **kwargs: Additional parameters to pass to the model score method.
114135
115136
Returns:
137+
116138
- float: The score.
117139
"""
118140
return self.model.score(X, y, **kwargs)

src/tisthemachinelearner/finitedifftrainer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,38 @@ class FiniteDiffRegressor(BaseModel, RegressorMixin):
1818
1919
Parameters
2020
----------
21+
2122
base_model : str
2223
The name of the base model (e.g., 'RidgeCV').
24+
2325
lr : float, optional
2426
Learning rate for optimization (default=1e-4).
27+
2528
optimizer : {'gd', 'sgd', 'adam', 'cd'}, optional
2629
Optimization algorithm: gradient descent ('gd'), stochastic gradient descent ('sgd'),
2730
Adam ('adam'), or coordinate descent ('cd'). Default is 'gd'.
31+
2832
eps : float, optional
2933
Scaling factor for adaptive finite difference step size (default=1e-3).
34+
3035
batch_size : int, optional
3136
Batch size for 'sgd' optimizer (default=32).
37+
3238
alpha : float, optional
3339
Elastic net penalty strength (default=0.0).
40+
3441
l1_ratio : float, optional
3542
Elastic net mixing parameter (0 = Ridge, 1 = Lasso, default=0.0).
43+
3644
type_loss : {'mse', 'quantile'}, optional
3745
Type of loss function to use (default='mse').
46+
3847
q : float, optional
3948
Quantile for quantile loss (default=0.5).
49+
4050
**kwargs
4151
Additional parameters to pass to the scikit-learn model.
52+
4253
"""
4354

4455
def __init__(self, base_model,
@@ -70,10 +81,13 @@ def _loss(self, X, y, **kwargs):
7081
7182
Parameters
7283
----------
84+
7385
X : array-like of shape (n_samples, n_features)
7486
Input data.
87+
7588
y : array-like of shape (n_samples,)
7689
Target values.
90+
7791
**kwargs
7892
Additional keyword arguments for loss calculation.
7993
@@ -98,13 +112,16 @@ def _compute_grad(self, X, y):
98112
99113
Parameters
100114
----------
115+
101116
X : array-like of shape (n_samples, n_features)
102117
Input data.
118+
103119
y : array-like of shape (n_samples,)
104120
Target values.
105121
106122
Returns
107123
-------
124+
108125
ndarray
109126
Gradient array with the same shape as W_.
110127
"""
@@ -147,23 +164,31 @@ def fit(self, X, y, epochs=10, verbose=True, show_progress=True, sample_weight=N
147164
148165
Parameters
149166
----------
167+
150168
X : array-like of shape (n_samples, n_features)
151169
Training data.
170+
152171
y : array-like of shape (n_samples,)
153172
Target values.
173+
154174
epochs : int, optional
155175
Number of optimization steps (default=10).
176+
156177
verbose : bool, optional
157178
Whether to print progress messages (default=True).
179+
158180
show_progress : bool, optional
159181
Whether to show tqdm progress bar (default=True).
182+
160183
sample_weight : array-like, optional
161184
Sample weights.
185+
162186
**kwargs
163187
Additional keyword arguments.
164188
165189
Returns
166190
-------
191+
167192
self : object
168193
Returns self.
169194
"""
@@ -255,18 +280,23 @@ def predict(self, X, level=95, method='splitconformal', **kwargs):
255280
256281
Parameters
257282
----------
283+
258284
X : array-like of shape (n_samples, n_features)
259285
Input data.
286+
260287
level : int, optional
261288
Level of confidence for prediction intervals (default=95).
289+
262290
method : {'splitconformal', 'localconformal'}, optional
263291
Method for conformal prediction (default='splitconformal').
292+
264293
**kwargs
265294
Additional keyword arguments. Use `return_pi=True` for prediction intervals,
266295
or `return_std=True` for standard deviation estimates.
267296
268297
Returns
269298
-------
299+
270300
array or tuple
271301
Model predictions, or a tuple with prediction intervals or standard deviations if requested.
272302
"""

src/tisthemachinelearner/regressor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ class Regressor(BaseModel, RegressorMixin):
88
Wrapper for scikit-learn regressor models.
99
1010
Parameters:
11+
1112
- model_name (str): The name of the scikit-learn regressor model.
13+
1214
- **kwargs: Additional parameters to pass to the scikit-learn model.
1315
1416
Examples:

0 commit comments

Comments
 (0)