1212import rnnrbm
1313import SdA
1414import rnnslu
15+ import lstm
1516
1617
1718def test_rnnslu ():
@@ -61,13 +62,17 @@ def test_rnnrbm():
6162 rnnrbm .test_rnnrbm (num_epochs = 1 )
6263
6364
65+ def test_lstm ():
66+ lstm .train_lstm (max_epochs = 1 , test_size = 1000 , saveto = '' )
67+
68+
6469def speed ():
6570 """
6671 This fonction modify the configuration theano and don't restore it!
6772 """
6873
6974 algo = ['logistic_sgd' , 'logistic_cg' , 'mlp' , 'convolutional_mlp' ,
70- 'dA' , 'SdA' , 'DBN' , 'rbm' , 'rnnrbm' , 'rnnslu' ]
75+ 'dA' , 'SdA' , 'DBN' , 'rbm' , 'rnnrbm' , 'rnnslu' , 'lstm' ]
7176 to_exec = [True ] * len (algo )
7277# to_exec = [False] * len(algo)
7378# to_exec[-1] = True
@@ -82,9 +87,9 @@ def speed():
8287 # 7.1-2 (python 2.7.2, mkl unknow). BLAS with only 1 thread.
8388
8489 expected_times_64 = numpy .asarray ([9.8 , 22.5 , 76.1 , 73.7 , 116.4 ,
85- 346.9 , 381.9 , 558.1 , 186.3 , 50.8 ])
90+ 346.9 , 381.9 , 558.1 , 186.3 , 50.8 , 113.6 ])
8691 expected_times_32 = numpy .asarray ([8.1 , 17.9 , 42.5 , 66.5 , 71 ,
87- 191.2 , 226.8 , 432.8 , 176.2 , 36.9 ])
92+ 191.2 , 226.8 , 432.8 , 176.2 , 36.9 , 78.0 ])
8893
8994 # Number with just 1 decimal are new value that are faster with
9095 # the Theano version 0.5rc2 Other number are older. They are not
@@ -104,8 +109,8 @@ def speed():
104109# 1.35324519 1.7356905 1.12937868]
105110
106111 expected_times_gpu = numpy .asarray ([3.0 , 7.55523491 , 18.99226785 ,
107- 5.8 , 21 .5 ,
108- 11.8 , 47.9 , 290.1 , 255.4 , 72.4 ])
112+ 5.8 , 20 .5 ,
113+ 11.8 , 47.9 , 290.1 , 255.4 , 72.4 , 17.0 ])
109114 expected_times_64 = [s for idx , s in enumerate (expected_times_64 )
110115 if to_exec [idx ]]
111116 expected_times_32 = [s for idx , s in enumerate (expected_times_32 )
@@ -162,6 +167,8 @@ def do_tests():
162167 # 60 is recommended
163168 'savemodel' : False }
164169 time_test (m , l , 9 , rnnslu .main , param = s )
170+ time_test (m , l , 10 , lstm .train_lstm , max_epochs = 1 , test_size = 1000 ,
171+ saveto = '' )
165172 return numpy .asarray (l )
166173
167174 #test in float64 in FAST_RUN mode on the cpu
0 commit comments