Skip to content

Commit b8c59ac

Browse files
committed
Improve tests for TimeDistributed
1 parent 3cc2426 commit b8c59ac

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

tests/keras/layers/test_wrappers.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ def test_TimeDistributed():
1212
model = Sequential()
1313
model.add(wrappers.TimeDistributed(core.Dense(2), input_shape=(3, 4)))
1414
model.add(core.Activation('relu'))
15-
1615
model.compile(optimizer='rmsprop', loss='mse')
1716
model.fit(np.random.random((10, 3, 4)), np.random.random((10, 3, 2)), nb_epoch=1, batch_size=10)
1817

18+
# test config
1919
model.get_config()
2020

2121
# compare to TimeDistributedDense
@@ -26,7 +26,15 @@ def test_TimeDistributed():
2626
reference = Sequential()
2727
reference.add(core.TimeDistributedDense(2, input_shape=(3, 4), weights=weights))
2828
reference.add(core.Activation('relu'))
29+
reference.compile(optimizer='rmsprop', loss='mse')
30+
31+
reference_output = reference.predict(test_input)
32+
assert_allclose(test_output, reference_output, atol=1e-05)
2933

34+
# test when specifying a batch_input_shape
35+
reference = Sequential()
36+
reference.add(core.TimeDistributedDense(2, batch_input_shape=(1, 3, 4), weights=weights))
37+
reference.add(core.Activation('relu'))
3038
reference.compile(optimizer='rmsprop', loss='mse')
3139

3240
reference_output = reference.predict(test_input)
@@ -36,13 +44,22 @@ def test_TimeDistributed():
3644
model = Sequential()
3745
model.add(wrappers.TimeDistributed(convolutional.Convolution2D(5, 2, 2, border_mode='same'), input_shape=(2, 3, 4, 4)))
3846
model.add(core.Activation('relu'))
39-
4047
model.compile(optimizer='rmsprop', loss='mse')
4148
model.train_on_batch(np.random.random((1, 2, 3, 4, 4)), np.random.random((1, 2, 5, 4, 4)))
4249

4350
model = model_from_json(model.to_json())
4451
model.summary()
4552

53+
# test stacked layers
54+
model = Sequential()
55+
model.add(wrappers.TimeDistributed(core.Dense(2), input_shape=(3, 4)))
56+
model.add(wrappers.TimeDistributed(core.Dense(3)))
57+
model.add(core.Activation('relu'))
58+
model.compile(optimizer='rmsprop', loss='mse')
59+
60+
model.fit(np.random.random((10, 3, 4)), np.random.random((10, 3, 3)), nb_epoch=1, batch_size=10)
61+
62+
4663

4764
if __name__ == '__main__':
4865
pytest.main([__file__])

0 commit comments

Comments
 (0)