@@ -373,22 +373,22 @@ def pred_error(f_pred, prepare_data, data, iterator, verbose=False):
     return valid_err
 
 
-def test_lstm(
+def train_lstm(
     dim_proj=128,  # word embedding dimension and LSTM number of hidden units.
-    patience=10,  # number of epoch to wait before early stop if no progress
+    patience=10,  # Number of epochs to wait before early stop if no progress
     max_epochs=5000,  # The maximum number of epochs to run
-    dispFreq=10,  # display to stdout the training progress every N updates
-    decay_c=0.,  # weight decay for the classifier applied to the U weights.
-    lrate=0.0001,  # learning rate for sgd (not used for adadelta and rmsprop)
-    n_words=10000,  # vocabulary size
-    optimizer=adadelta,  # sgd, adadelta and rmsprop available, sgd very hard to use, not recommanded (probably need momentum and decay learning rate).
+    dispFreq=10,  # Display to stdout the training progress every N updates
+    decay_c=0.,  # Weight decay for the classifier applied to the U weights.
+    lrate=0.0001,  # Learning rate for sgd (not used for adadelta and rmsprop)
+    n_words=10000,  # Vocabulary size
+    optimizer=adadelta,  # sgd, adadelta and rmsprop available; sgd is very hard to use and not recommended (probably needs momentum and a decaying learning rate).
     encoder='lstm',  # TODO: can be removed, must be lstm.
     saveto='lstm_model.npz',  # The best model will be saved there
-    validFreq=10000,  # after 1000
-    saveFreq=100000,  # save the parameters after every saveFreq updates
-    maxlen=100,  # longer sequence get ignored
-    batch_size=64,  # the batch size during training.
-    valid_batch_size=64,  # The batch size during validation
+    validFreq=390,  # Compute the validation error after this number of updates.
+    saveFreq=1040,  # Save the parameters after every saveFreq updates
+    maxlen=100,  # Sequences longer than this get ignored
+    batch_size=16,  # The batch size during training.
+    valid_batch_size=64,  # The batch size used for the validation/test set.
     dataset='imdb',
 
     # Parameter for extra option
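A note on the retuned defaults in this hunk: with batch_size=16, the new validFreq=390 plausibly corresponds to about one pass over the maxlen-filtered IMDB training split between validation checks (390 * 16 = 6240 samples), with saveFreq a few epochs apart. A minimal sketch of deriving the frequency from the data instead of hard-coding it; the helper name is illustrative and not part of this commit:

    # Illustrative helper (not in the tutorial): validate about once per
    # epoch regardless of batch size, instead of hard-coding validFreq.
    def default_valid_freq(n_train_samples, batch_size):
        return max(1, n_train_samples // batch_size)

    # e.g. default_valid_freq(6240, 16) == 390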
@@ -400,11 +400,13 @@ def test_lstm(
 
     # Model options
     model_options = locals().copy()
+    print "model options", model_options
 
     load_data, prepare_data = get_dataset(dataset)
 
     print 'Loading data'
-    train, valid, test = load_data(n_words=n_words, valid_portion=0.01)
+    train, valid, test = load_data(n_words=n_words, valid_portion=0.01,
+                                   maxlen=maxlen)
 
     ydim = numpy.max(train[1]) + 1
 
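This hunk threads maxlen through to load_data, so over-long reviews are dropped when the dataset is read (the new print line logs the effective model options, and ydim then recovers the number of target classes from the labels). A sketch of the kind of length filtering this implies, assuming load_data keeps parallel lists of token sequences and labels; the function name is illustrative:

    # Illustrative sketch of maxlen filtering inside load_data: drop every
    # (sequence, label) pair whose token sequence is maxlen or longer.
    def remove_long_seqs(seqs, labels, maxlen):
        kept = [(s, l) for s, l in zip(seqs, labels) if len(s) < maxlen]
        if not kept:
            return [], []
        new_seqs, new_labels = zip(*kept)
        return list(new_seqs), list(new_labels)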
@@ -569,7 +571,7 @@ def test_lstm(
     theano.config.scan.allow_gc = False
 
     # See function train_lstm for all possible parameters and their definitions.
-    test_lstm(
+    train_lstm(
         #reload_model="lstm_model.npz",
-        max_epochs=10,
+        max_epochs=100,
     )
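The entry point now calls the renamed train_lstm and trains for 100 epochs by default. Resuming from the parameters saved in saveto amounts to enabling the commented-out option; a usage sketch under that assumption:

    # Usage sketch: resume training from a previously saved model, using
    # the reload_model option shown (commented out) in the diff above.
    train_lstm(
        reload_model="lstm_model.npz",  # load saved weights before training
        max_epochs=100,
    )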