Skip to content

Commit 7c96643

Browse files
committed
Updated wrapper test script.
1 parent c9461d7 commit 7c96643

File tree

1 file changed

+60
-11
lines changed

1 file changed

+60
-11
lines changed

tests/manual/check_wrappers.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,28 @@
44
from keras.models import Sequential
55
from keras.layers.core import Dense, Activation
66
from keras.utils import np_utils
7-
from keras.wrappers.scikit_learn import KerasClassifier
7+
from keras.wrappers.scikit_learn import *
88
import numpy as np
99

10-
nb_classes = 10
1110
batch_size = 128
1211
nb_epoch = 1
1312

13+
nb_classes = 10
1414
max_train_samples = 5000
1515
max_test_samples = 1000
1616

1717
np.random.seed(1337) # for reproducibility
1818

19-
# the data, shuffled and split between tran and test sets
19+
############################################
20+
# scikit-learn classification wrapper test #
21+
############################################
22+
print('Beginning scikit-learn classification wrapper test')
23+
24+
print('Loading data')
2025
(X_train, y_train), (X_test, y_test) = mnist.load_data()
2126

22-
X_train = X_train.reshape(60000,784)[:max_train_samples]
23-
X_test = X_test.reshape(10000,784)[:max_test_samples]
27+
X_train = X_train.reshape(60000, 784)[:max_train_samples]
28+
X_test = X_test.reshape(10000, 784)[:max_test_samples]
2429
X_train = X_train.astype('float32')
2530
X_test = X_test.astype('float32')
2631
X_train /= 255
@@ -29,11 +34,6 @@
2934
Y_train = np_utils.to_categorical(y_train, nb_classes)[:max_train_samples]
3035
Y_test = np_utils.to_categorical(y_test, nb_classes)[:max_test_samples]
3136

32-
#############################
33-
# scikit-learn wrapper test #
34-
#############################
35-
print('Beginning scikit-learn wrapper test')
36-
3737
print('Defining model')
3838
model = Sequential()
3939
model.add(Dense(784, 50))
@@ -63,7 +63,7 @@
6363
print(classifier.get_params())
6464

6565
print('Testing set params')
66-
classifier.set_params(optimizer='sgd', loss='mse')
66+
classifier.set_params(optimizer='sgd', loss='binary_crossentropy')
6767
print(classifier.get_params())
6868

6969
print('Testing attributes')
@@ -73,5 +73,54 @@
7373
print(classifier.config_)
7474
print('Weights')
7575
print(classifier.weights_)
76+
print('Compiled model')
77+
print(classifier.compiled_model_)
78+
79+
########################################
80+
# scikit-learn regression wrapper test #
81+
########################################
82+
print('Beginning scikit-learn regression wrapper test')
83+
84+
print('Generating data')
85+
X_train = np.random((5000, 100))
86+
X_test = np.random((1000, 100))
87+
y_train = np.random(5000)
88+
y_test = np.random(1000)
89+
90+
print('Defining model')
91+
model = Sequential()
92+
model.add(Dense(100, 50))
93+
model.add(Activation('relu'))
94+
model.add(Dense(50, 10))
95+
model.add(Activation('linear'))
96+
97+
print('Creating wrapper')
98+
regressor = KerasRegressor(model)
99+
100+
print('Fitting model')
101+
regressor.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch)
102+
103+
print('Testing score function')
104+
score = regressor.score(X_train, y_train)
105+
print('Score: ', score)
106+
107+
print('Testing predict function')
108+
preds = regressor.predict(X_test)
109+
print('Preds.shape: ', preds.shape)
110+
111+
print('Testing get params')
112+
print(regressor.get_params())
113+
114+
print('Testing set params')
115+
regressor.set_params(optimizer='sgd', loss='mean_absolute_error')
116+
print(regressor.get_params())
117+
118+
print('Testing attributes')
119+
print('Config')
120+
print(regressor.config_)
121+
print('Weights')
122+
print(regressor.weights_)
123+
print('Compiled model')
124+
print(regressor.compiled_model_)
76125

77126
print('Test script complete.')

0 commit comments

Comments
 (0)