Skip to content

Commit cf86e2a

Browse files
Update finitedifftrainer.py -- update weights initialization
1 parent 1899d79 commit cf86e2a

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

src/tisthemachinelearner/finitedifftrainer.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,13 @@ def __init__(self, base_model,
7979
self._cd_index = 0
8080
self._is_initialized = False
8181

82-
def _initialize_weights(self, X):
82+
def _initialize_weights(self, X, y):
8383
"""Initialize weights using proper neural network initialization"""
84-
input_dim = X.shape[1]
85-
86-
# Get model architecture details
87-
n_hidden = getattr(self.model, 'n_hidden_features')
88-
n_clusters = getattr(self.model, 'n_clusters')
89-
90-
# Determine weight shape
91-
if n_clusters >= 0:
92-
shape = (input_dim, n_hidden + n_clusters)
93-
else:
94-
shape = (input_dim, n_hidden)
95-
84+
self.model.fit(X, y)
85+
input_dim = X.shape[1]
9686
# He initialization (good for ReLU-like activations)
9787
scale = np.sqrt(2.0 / input_dim)
98-
self.model.W_ = np.random.normal(0, scale, size=shape)
88+
self.model.W_ = np.random.normal(0, scale, size=self.model.W_.shape)
9989
self._is_initialized = True
10090

10191
def _loss(self, X, y, **kwargs):

0 commit comments

Comments
 (0)