Commit 396ca02

Update finitedifftrainer.py -- use seed
1 parent cf86e2a commit 396ca02

1 file changed: +7 −2 lines changed

src/tisthemachinelearner/finitedifftrainer.py

Lines changed: 7 additions & 2 deletions
@@ -47,6 +47,9 @@ class FiniteDiffRegressor(BaseModel, RegressorMixin):
     q : float, optional
         Quantile for quantile loss (default=0.5).
 
+    seed : int
+        Random seed.
+
     **kwargs
         Additional parameters to pass to the scikit-learn model.
 
@@ -72,6 +75,7 @@ def __init__(self, base_model,
         self.l1_ratio = l1_ratio
         self.type_loss = type_loss
         self.q = q
+        self.seed = seed
 
         # Training state
         self.loss_history_ = []
@@ -85,6 +89,7 @@ def _initialize_weights(self, X, y):
         input_dim = X.shape[1]
         # He initialization (good for ReLU-like activations)
         scale = np.sqrt(2.0 / input_dim)
+        np.random.seed(self.seed)
         self.model.W_ = np.random.normal(0, scale, size=self.model.W_.shape)
         self._is_initialized = True
 
@@ -108,7 +113,7 @@ def _loss(self, X, y, **kwargs):
     def _compute_grad(self, X, y):
         """Compute gradient using finite differences"""
         if not self._is_initialized:
-            self._initialize_weights(X)
+            self._initialize_weights(X, y)
 
         W = self.model.W_.copy() # Use current weights
         shape = W.shape
@@ -152,7 +157,7 @@ def fit(self, X, y, epochs=10, verbose=True, show_progress=True, sample_weight=N
         """Fit model using finite difference optimization"""
         # Initialize weights if not already done
         if not self._is_initialized:
-            self._initialize_weights(X)
+            self._initialize_weights(X, y)
 
         # Training loop
         iterator = tqdm(range(epochs)) if show_progress else range(epochs)
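
The effect of the seeding change is easiest to see in isolation: fixing the NumPy seed immediately before the He-scaled draw makes the initial weight matrix identical across runs. Below is a minimal standalone sketch of that behavior; the function name, arguments, and weight shape are illustrative only, not part of the package. Note also that np.random.seed sets NumPy's global RNG state, so later draws elsewhere in the process are affected as well.

import numpy as np

def he_init(n_features, n_hidden, seed=42):
    # Same pattern as the seeded initialization in this commit:
    # He scale sqrt(2 / fan_in), drawn right after fixing the global seed.
    scale = np.sqrt(2.0 / n_features)
    np.random.seed(seed)
    return np.random.normal(0, scale, size=(n_features, n_hidden))

W1 = he_init(10, 5, seed=123)
W2 = he_init(10, 5, seed=123)
assert np.allclose(W1, W2)  # same seed -> same starting weights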

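At the call site, a fixed seed should make repeated fits start from the same weights and therefore trace the same loss history. A hedged usage sketch, assuming the import path below and that base_model accepts a scikit-learn estimator instance; neither detail is confirmed by this diff.

import numpy as np
from sklearn.linear_model import Ridge
from tisthemachinelearner.finitedifftrainer import FiniteDiffRegressor  # assumed import path

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = rng.normal(size=100)

# Two independent fits with the same (arbitrary) seed.
reg_a = FiniteDiffRegressor(base_model=Ridge(), seed=42)  # base_model=Ridge() is an assumption
reg_a.fit(X, y, epochs=10, show_progress=False)

reg_b = FiniteDiffRegressor(base_model=Ridge(), seed=42)
reg_b.fit(X, y, epochs=10, show_progress=False)

# Same seed, same data: identical initial weights, so the finite-difference
# updates and the recorded loss histories should coincide.
print(reg_a.loss_history_[:3])
print(reg_b.loss_history_[:3])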