@@ -12,10 +12,10 @@ def test_TimeDistributed():
1212 model = Sequential ()
1313 model .add (wrappers .TimeDistributed (core .Dense (2 ), input_shape = (3 , 4 )))
1414 model .add (core .Activation ('relu' ))
15-
1615 model .compile (optimizer = 'rmsprop' , loss = 'mse' )
1716 model .fit (np .random .random ((10 , 3 , 4 )), np .random .random ((10 , 3 , 2 )), nb_epoch = 1 , batch_size = 10 )
1817
18+ # test config
1919 model .get_config ()
2020
2121 # compare to TimeDistributedDense
@@ -26,7 +26,15 @@ def test_TimeDistributed():
2626 reference = Sequential ()
2727 reference .add (core .TimeDistributedDense (2 , input_shape = (3 , 4 ), weights = weights ))
2828 reference .add (core .Activation ('relu' ))
29+ reference .compile (optimizer = 'rmsprop' , loss = 'mse' )
30+
31+ reference_output = reference .predict (test_input )
32+ assert_allclose (test_output , reference_output , atol = 1e-05 )
2933
34+ # test when specifying a batch_input_shape
35+ reference = Sequential ()
36+ reference .add (core .TimeDistributedDense (2 , batch_input_shape = (1 , 3 , 4 ), weights = weights ))
37+ reference .add (core .Activation ('relu' ))
3038 reference .compile (optimizer = 'rmsprop' , loss = 'mse' )
3139
3240 reference_output = reference .predict (test_input )
@@ -36,13 +44,22 @@ def test_TimeDistributed():
3644 model = Sequential ()
3745 model .add (wrappers .TimeDistributed (convolutional .Convolution2D (5 , 2 , 2 , border_mode = 'same' ), input_shape = (2 , 3 , 4 , 4 )))
3846 model .add (core .Activation ('relu' ))
39-
4047 model .compile (optimizer = 'rmsprop' , loss = 'mse' )
4148 model .train_on_batch (np .random .random ((1 , 2 , 3 , 4 , 4 )), np .random .random ((1 , 2 , 5 , 4 , 4 )))
4249
4350 model = model_from_json (model .to_json ())
4451 model .summary ()
4552
53+ # test stacked layers
54+ model = Sequential ()
55+ model .add (wrappers .TimeDistributed (core .Dense (2 ), input_shape = (3 , 4 )))
56+ model .add (wrappers .TimeDistributed (core .Dense (3 )))
57+ model .add (core .Activation ('relu' ))
58+ model .compile (optimizer = 'rmsprop' , loss = 'mse' )
59+
60+ model .fit (np .random .random ((10 , 3 , 4 )), np .random .random ((10 , 3 , 3 )), nb_epoch = 1 , batch_size = 10 )
61+
62+
4663
4764if __name__ == '__main__' :
4865 pytest .main ([__file__ ])
0 commit comments