Skip to content

Commit df860fd

Browse files
committed
Update IRNN example
1 parent 361a7cf commit df860fd

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

examples/mnist_irnn.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
Optimizer is replaced with RMSprop which yields more stable and steady
2424
improvement.
2525
26-
Reaches 0.93 train/test accuracy after 900 epochs (which roughly corresponds
27-
to 1687500 steps in the original paper.)
26+
Reaches 0.93 train/test accuracy after 900 epochs
27+
(which roughly corresponds to 1687500 steps in the original paper.)
2828
'''
2929

3030
batch_size = 32
@@ -34,7 +34,6 @@
3434

3535
learning_rate = 1e-6
3636
clip_norm = 1.0
37-
BPTT_truncate = 28*28
3837

3938
# the data, shuffled and split between train and test sets
4039
(X_train, y_train), (X_test, y_test) = mnist.load_data()
@@ -58,8 +57,7 @@
5857
model.add(SimpleRNN(output_dim=hidden_units,
5958
init=lambda shape: normal(shape, scale=0.001),
6059
inner_init=lambda shape: identity(shape, scale=1.0),
61-
activation='relu', truncate_gradient=BPTT_truncate,
62-
input_shape=(None, 1)))
60+
activation='relu', input_shape=X_train.shape[1:]))
6361
model.add(Dense(nb_classes))
6462
model.add(Activation('softmax'))
6563
rmsprop = RMSprop(lr=learning_rate)
@@ -74,7 +72,7 @@
7472

7573
print('Compare to LSTM...')
7674
model = Sequential()
77-
model.add(LSTM(hidden_units, input_shape=(None, 1)))
75+
model.add(LSTM(hidden_units, input_shape=X_train.shape[1:]))
7876
model.add(Dense(nb_classes))
7977
model.add(Activation('softmax'))
8078
rmsprop = RMSprop(lr=learning_rate)

0 commit comments

Comments
 (0)