@@ -384,8 +384,8 @@ def train_lstm(
     optimizer=adadelta,  # sgd, adadelta and rmsprop available; sgd is very hard to use, not recommended (probably needs momentum and a decaying learning rate).
     encoder='lstm',  # TODO: can be removed, must be lstm.
     saveto='lstm_model.npz',  # The best model will be saved there
-    validFreq=390,  # Compute the validation error after this number of updates.
-    saveFreq=1040,  # Save the parameters after every saveFreq updates
+    validFreq=370,  # Compute the validation error after this number of updates.
+    saveFreq=1110,  # Save the parameters after every saveFreq updates
     maxlen=100,  # Sequences longer than this get ignored
     batch_size=16,  # The batch size during training.
     valid_batch_size=64,  # The batch size used for validation/test set.
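The retuned defaults look keyed to the size of the training set: roughly one validation pass per epoch and a save every few epochs. A hedged sketch of that relationship follows; suggest_freqs, n_train=5920, and save_every_epochs are illustrative names and values chosen so the output matches the new defaults, not anything taken from this commit:

    import math

    def suggest_freqs(n_train, batch_size, save_every_epochs=3):
        # One validation per epoch, one save every `save_every_epochs` epochs.
        updates_per_epoch = int(math.ceil(n_train / float(batch_size)))
        return updates_per_epoch, save_every_epochs * updates_per_epoch

    validFreq, saveFreq = suggest_freqs(n_train=5920, batch_size=16)
    print(validFreq, saveFreq)  # -> 370 1110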
@@ -467,80 +467,85 @@ def train_lstm(
     uidx = 0  # the number of updates done
     estop = False  # early stop
     start_time = time.clock()
-    for eidx in xrange(max_epochs):
-        n_samples = 0
-
-        # Get new shuffled index for the training set.
-        kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
-
-        for _, train_index in kf:
-            uidx += 1
-            use_noise.set_value(1.)
-
-            # Select the random examples for this minibatch
-            y = [train[1][t] for t in train_index]
-            x = [train[0][t] for t in train_index]
-
-            # Get the data in numpy.ndarray format.
-            # It returns something of the shape (minibatch maxlen, n samples)
-            x, mask, y = prepare_data(x, y, maxlen=maxlen)
-            if x is None:
-                print 'Minibatch with zero sample under length ', maxlen
-                continue
-            n_samples += x.shape[1]
-
-            cost = f_grad_shared(x, mask, y)
-            f_update(lrate)
-
-            if numpy.isnan(cost) or numpy.isinf(cost):
-                print 'NaN detected'
-                return 1., 1., 1.
-
-            if numpy.mod(uidx, dispFreq) == 0:
-                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
-
-            if numpy.mod(uidx, saveFreq) == 0:
-                print 'Saving...',
-
-                if best_p is not None:
-                    params = best_p
-                else:
-                    params = unzip(tparams)
-                numpy.savez(saveto, history_errs=history_errs, **params)
-                pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
-                print 'Done'
-
-            if numpy.mod(uidx, validFreq) == 0:
-                use_noise.set_value(0.)
-                train_err = pred_error(f_pred, prepare_data, train, kf)
-                valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
-                test_err = pred_error(f_pred, prepare_data, test, kf_test)
-
-                history_errs.append([valid_err, test_err])
-
-                if (uidx == 0 or
-                    valid_err <= numpy.array(history_errs)[:,
-                                                           0].min()):
-
-                    best_p = unzip(tparams)
-                    bad_counter = 0
-
-                print ('Train ', train_err, 'Valid ', valid_err,
-                       'Test ', test_err)
-
-                if (len(history_errs) > patience and
-                    valid_err >= numpy.array(history_errs)[:-patience,
-                                                           0].min()):
-                    bad_counter += 1
-                    if bad_counter > patience:
-                        print 'Early Stop!'
-                        estop = True
-                        break
-
-        print 'Seen %d samples' % n_samples
-
-        if estop:
-            break
+    try:
+        for eidx in xrange(max_epochs):
+            n_samples = 0
+
+            # Get new shuffled index for the training set.
+            kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
+
+            for _, train_index in kf:
+                uidx += 1
+                use_noise.set_value(1.)
+
+                # Select the random examples for this minibatch
+                y = [train[1][t] for t in train_index]
+                x = [train[0][t] for t in train_index]
+
+                # Get the data in numpy.ndarray format.
+                # It returns something of the shape (minibatch maxlen, n samples)
+                x, mask, y = prepare_data(x, y, maxlen=maxlen)
+                if x is None:
+                    print 'Minibatch with zero sample under length ', maxlen
+                    continue
+                n_samples += x.shape[1]
+
+                cost = f_grad_shared(x, mask, y)
+                f_update(lrate)
+
+                if numpy.isnan(cost) or numpy.isinf(cost):
+                    print 'NaN detected'
+                    return 1., 1., 1.
+
+                if numpy.mod(uidx, dispFreq) == 0:
+                    print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
+
+                if numpy.mod(uidx, saveFreq) == 0:
+                    print 'Saving...',
+
+                    if best_p is not None:
+                        params = best_p
+                    else:
+                        params = unzip(tparams)
+                    numpy.savez(saveto, history_errs=history_errs, **params)
+                    pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
+                    print 'Done'
+
+                if numpy.mod(uidx, validFreq) == 0:
+                    use_noise.set_value(0.)
+                    train_err = pred_error(f_pred, prepare_data, train, kf)
+                    valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
+                    test_err = pred_error(f_pred, prepare_data, test, kf_test)
+
+                    history_errs.append([valid_err, test_err])
+
+                    if (uidx == 0 or
+                        valid_err <= numpy.array(history_errs)[:,
+                                                               0].min()):
+
+                        best_p = unzip(tparams)
+                        bad_counter = 0
+
+                    print ('Train ', train_err, 'Valid ', valid_err,
+                           'Test ', test_err)
+
+                    if (len(history_errs) > patience and
+                        valid_err >= numpy.array(history_errs)[:-patience,
+                                                               0].min()):
+                        bad_counter += 1
+                        if bad_counter > patience:
+                            print 'Early Stop!'
+                            estop = True
+                            break
+
+            print 'Seen %d samples' % n_samples
+
+            if estop:
+                break
+
+    except KeyboardInterrupt:
+        print "Training interrupted"
+
     end_time = time.clock()
     if best_p is not None:
         zipp(best_p, tparams)
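The substantive change in this hunk is purely structural: the whole epoch loop moves one indentation level deeper inside a try block, so a Ctrl-C no longer kills the process but instead falls through to the cleanup code that restores and saves the best parameters found so far. A minimal sketch of the pattern, assuming hypothetical train_one_epoch and save_best stand-ins for the tutorial's loop body and cleanup:

    def run_training(max_epochs, train_one_epoch, save_best):
        epoch = 0
        try:
            for epoch in range(max_epochs):
                train_one_epoch(epoch)
        except KeyboardInterrupt:
            # Interrupting drops us here instead of raising out of main,
            # so the save below still runs.
            print('Training interrupted at epoch %d' % epoch)
        save_best()  # reached on normal exit, early stop, or interrupt

The early stopping inside the loop is the usual patience scheme: training halts once the validation error has failed to beat the previous best for several consecutive validations. A simplified, hedged restatement of that test (the diff tracks it incrementally with bad_counter; history here stands for the list of recorded validation errors):

    def should_stop(history, patience):
        if len(history) <= patience:
            return False
        best_before = min(history[:-patience])
        # Stop when none of the last `patience` validations improved on it.
        return all(err >= best_before for err in history[-patience:])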