|
4 | 4 | from keras.models import Sequential |
5 | 5 | from keras.layers.core import Dense, Activation |
6 | 6 | from keras.utils import np_utils |
7 | | -from keras.wrappers.scikit_learn import KerasClassifier |
| 7 | +from keras.wrappers.scikit_learn import * |
8 | 8 | import numpy as np |
9 | 9 |
|
10 | | -nb_classes = 10 |
11 | 10 | batch_size = 128 |
12 | 11 | nb_epoch = 1 |
13 | 12 |
|
| 13 | +nb_classes = 10 |
14 | 14 | max_train_samples = 5000 |
15 | 15 | max_test_samples = 1000 |
16 | 16 |
|
17 | 17 | np.random.seed(1337) # for reproducibility |
18 | 18 |
|
19 | | -# the data, shuffled and split between tran and test sets |
| 19 | +############################################ |
| 20 | +# scikit-learn classification wrapper test # |
| 21 | +############################################ |
| 22 | +print('Beginning scikit-learn classification wrapper test') |
| 23 | + |
| 24 | +print('Loading data') |
20 | 25 | (X_train, y_train), (X_test, y_test) = mnist.load_data() |
21 | 26 |
|
22 | | -X_train = X_train.reshape(60000,784)[:max_train_samples] |
23 | | -X_test = X_test.reshape(10000,784)[:max_test_samples] |
| 27 | +X_train = X_train.reshape(60000, 784)[:max_train_samples] |
| 28 | +X_test = X_test.reshape(10000, 784)[:max_test_samples] |
24 | 29 | X_train = X_train.astype('float32') |
25 | 30 | X_test = X_test.astype('float32') |
26 | 31 | X_train /= 255 |
|
29 | 34 | Y_train = np_utils.to_categorical(y_train, nb_classes)[:max_train_samples] |
30 | 35 | Y_test = np_utils.to_categorical(y_test, nb_classes)[:max_test_samples] |
31 | 36 |
|
32 | | -############################# |
33 | | -# scikit-learn wrapper test # |
34 | | -############################# |
35 | | -print('Beginning scikit-learn wrapper test') |
36 | | - |
37 | 37 | print('Defining model') |
38 | 38 | model = Sequential() |
39 | 39 | model.add(Dense(784, 50)) |
|
63 | 63 | print(classifier.get_params()) |
64 | 64 |
|
65 | 65 | print('Testing set params') |
66 | | -classifier.set_params(optimizer='sgd', loss='mse') |
| 66 | +classifier.set_params(optimizer='sgd', loss='binary_crossentropy') |
67 | 67 | print(classifier.get_params()) |
68 | 68 |
|
69 | 69 | print('Testing attributes') |
|
73 | 73 | print(classifier.config_) |
74 | 74 | print('Weights') |
75 | 75 | print(classifier.weights_) |
| 76 | +print('Compiled model') |
| 77 | +print(classifier.compiled_model_) |
| 78 | + |
| 79 | +######################################## |
| 80 | +# scikit-learn regression wrapper test # |
| 81 | +######################################## |
| 82 | +print('Beginning scikit-learn regression wrapper test') |
| 83 | + |
| 84 | +print('Generating data') |
| 85 | +X_train = np.random((5000, 100)) |
| 86 | +X_test = np.random((1000, 100)) |
| 87 | +y_train = np.random(5000) |
| 88 | +y_test = np.random(1000) |
| 89 | + |
| 90 | +print('Defining model') |
| 91 | +model = Sequential() |
| 92 | +model.add(Dense(100, 50)) |
| 93 | +model.add(Activation('relu')) |
| 94 | +model.add(Dense(50, 10)) |
| 95 | +model.add(Activation('linear')) |
| 96 | + |
| 97 | +print('Creating wrapper') |
| 98 | +regressor = KerasRegressor(model) |
| 99 | + |
| 100 | +print('Fitting model') |
| 101 | +regressor.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch) |
| 102 | + |
| 103 | +print('Testing score function') |
| 104 | +score = regressor.score(X_train, y_train) |
| 105 | +print('Score: ', score) |
| 106 | + |
| 107 | +print('Testing predict function') |
| 108 | +preds = regressor.predict(X_test) |
| 109 | +print('Preds.shape: ', preds.shape) |
| 110 | + |
| 111 | +print('Testing get params') |
| 112 | +print(regressor.get_params()) |
| 113 | + |
| 114 | +print('Testing set params') |
| 115 | +regressor.set_params(optimizer='sgd', loss='mean_absolute_error') |
| 116 | +print(regressor.get_params()) |
| 117 | + |
| 118 | +print('Testing attributes') |
| 119 | +print('Config') |
| 120 | +print(regressor.config_) |
| 121 | +print('Weights') |
| 122 | +print(regressor.weights_) |
| 123 | +print('Compiled model') |
| 124 | +print(regressor.compiled_model_) |
76 | 125 |
|
77 | 126 | print('Test script complete.') |
0 commit comments