We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 408ea68 commit ff00edbCopy full SHA for ff00edb
code/lstm.py
@@ -511,7 +511,7 @@ def train_lstm(
511
if numpy.mod(uidx, dispFreq) == 0:
512
print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
513
514
- if numpy.mod(uidx, saveFreq) == 0:
+ if saveto and numpy.mod(uidx, saveFreq) == 0:
515
print 'Saving...',
516
517
if best_p is not None:
@@ -571,10 +571,10 @@ def train_lstm(
571
test_err = pred_error(f_pred, prepare_data, test, kf_test)
572
573
print 'Train ', train_err, 'Valid ', valid_err, 'Test ', test_err
574
-
575
- numpy.savez(saveto, train_err=train_err,
576
- valid_err=valid_err, test_err=test_err,
577
- history_errs=history_errs, **best_p)
+ if saveto:
+ numpy.savez(saveto, train_err=train_err,
+ valid_err=valid_err, test_err=test_err,
+ history_errs=history_errs, **best_p)
578
print 'The code run for %d epochs, with %f sec/epochs' % (
579
(eidx + 1), (end_time - start_time) / (1. * (eidx + 1)))
580
print >> sys.stderr, ('Training took %.1fs' %
code/test.py
@@ -12,6 +12,7 @@
12
import rnnrbm
13
import SdA
14
import rnnslu
15
+import lstm
16
17
18
def test_rnnslu():
@@ -61,6 +62,10 @@ def test_rnnrbm():
61
62
rnnrbm.test_rnnrbm(num_epochs=1)
63
64
65
+def test_lstm():
66
+ lstm.train_lstm(max_epochs=1, test_size=1000, saveto='')
67
+
68
69
def speed():
70
"""
71
This fonction modify the configuration theano and don't restore it!
0 commit comments