Skip to content

Commit ff00edb

Browse files
committed
Add test for lstm
1 parent 408ea68 commit ff00edb

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

code/lstm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def train_lstm(
511511
if numpy.mod(uidx, dispFreq) == 0:
512512
print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
513513

514-
if numpy.mod(uidx, saveFreq) == 0:
514+
if saveto and numpy.mod(uidx, saveFreq) == 0:
515515
print 'Saving...',
516516

517517
if best_p is not None:
@@ -571,10 +571,10 @@ def train_lstm(
571571
test_err = pred_error(f_pred, prepare_data, test, kf_test)
572572

573573
print 'Train ', train_err, 'Valid ', valid_err, 'Test ', test_err
574-
575-
numpy.savez(saveto, train_err=train_err,
576-
valid_err=valid_err, test_err=test_err,
577-
history_errs=history_errs, **best_p)
574+
if saveto:
575+
numpy.savez(saveto, train_err=train_err,
576+
valid_err=valid_err, test_err=test_err,
577+
history_errs=history_errs, **best_p)
578578
print 'The code run for %d epochs, with %f sec/epochs' % (
579579
(eidx + 1), (end_time - start_time) / (1. * (eidx + 1)))
580580
print >> sys.stderr, ('Training took %.1fs' %

code/test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import rnnrbm
1313
import SdA
1414
import rnnslu
15+
import lstm
1516

1617

1718
def test_rnnslu():
@@ -61,6 +62,10 @@ def test_rnnrbm():
6162
rnnrbm.test_rnnrbm(num_epochs=1)
6263

6364

65+
def test_lstm():
66+
lstm.train_lstm(max_epochs=1, test_size=1000, saveto='')
67+
68+
6469
def speed():
6570
"""
6671
This fonction modify the configuration theano and don't restore it!

0 commit comments

Comments
 (0)