Skip to content

Commit 74b2e0c

Browse files
committed
Filter for the max seq len when we load the dataset
1 parent 04c02d4 commit 74b2e0c

File tree

2 files changed

+37
-18
lines changed

2 files changed

+37
-18
lines changed

code/imdb.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,19 @@ def get_dataset_file(dataset, default_dataset, origin):
7474
return dataset
7575

7676

77-
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1):
77+
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
7878
''' Loads the dataset
7979
80-
:type dataset: string
81-
:param dataset: the path to the dataset (here IMDB)
80+
:type path: String
81+
:param path: The path to the dataset (here IMDB)
82+
:type n_words: int
83+
:param n_words: The number of words to keep in the vocabulary.
84+
All extra words are set to unknown (1).
85+
:type valid_portion: float
86+
:param valid_portion: The proportion of the full train set used for
87+
the validation set.
88+
:type maxlen: None or positive int
89+
:param maxlen: the max sequence length we use in the train/valid set.
8290
'''
8391

8492
#############
@@ -98,6 +106,15 @@ def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1):
98106
train_set = cPickle.load(f)
99107
test_set = cPickle.load(f)
100108
f.close()
109+
if maxlen:
110+
new_train_set_x = []
111+
new_train_set_y = []
112+
for x, y in zip(train_set[0], train_set[1]):
113+
if len(x) < maxlen:
114+
new_train_set_x.append(x)
115+
new_train_set_y.append(y)
116+
train_set = (new_train_set_x, new_train_set_y)
117+
del new_train_set_x, new_train_set_y
101118

102119
# split training set into validation set
103120
train_set_x, train_set_y = train_set

code/lstm.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -373,22 +373,22 @@ def pred_error(f_pred, prepare_data, data, iterator, verbose=False):
373373
return valid_err
374374

375375

376-
def test_lstm(
376+
def train_lstm(
377377
dim_proj=128, # word embeding dimension and LSTM number of hidden units.
378-
patience=10, # number of epoch to wait before early stop if no progress
378+
patience=10, # Number of epoch to wait before early stop if no progress
379379
max_epochs=5000, # The maximum number of epoch to run
380-
dispFreq=10, # display to stdout the training progress every N updates
381-
decay_c=0., # weight decay for the classifier applied to the U weights.
382-
lrate=0.0001, # learning rate for sgd (not used for adadelta and rmsprop)
383-
n_words=10000, # vocabulary size
384-
optimizer=adadelta, # sgd, adadelta and rmsprop available, sgd very hard to use, not recommanded (probably need momentum and decay learning rate).
380+
dispFreq=10, # Display to stdout the training progress every N updates
381+
decay_c=0., # Weight decay for the classifier applied to the U weights.
382+
lrate=0.0001, # Learning rate for sgd (not used for adadelta and rmsprop)
383+
n_words=10000, # Vocabulary size
384+
optimizer=adadelta, # sgd, adadelta and rmsprop available, sgd very hard to use, not recommended (probably need momentum and decaying learning rate).
385385
encoder='lstm', # TODO: can be removed must be lstm.
386386
saveto='lstm_model.npz', # The best model will be saved there
387-
validFreq=10000, # after 1000
388-
saveFreq=100000, # save the parameters after every saveFreq updates
389-
maxlen=100, # longer sequence get ignored
390-
batch_size=64, # the batch size during training.
391-
valid_batch_size=64, # The batch size during validation
387+
validFreq=390, # Compute the validation error after this number of updates.
388+
saveFreq=1040, # Save the parameters after every saveFreq updates
389+
maxlen=100, # Sequences longer than this get ignored
390+
batch_size=16, # The batch size during training.
391+
valid_batch_size=64, # The batch size used for validation/test set.
392392
dataset='imdb',
393393

394394
# Parameter for extra option
@@ -400,11 +400,13 @@ def test_lstm(
400400

401401
# Model options
402402
model_options = locals().copy()
403+
print "model options", model_options
403404

404405
load_data, prepare_data = get_dataset(dataset)
405406

406407
print 'Loading data'
407-
train, valid, test = load_data(n_words=n_words, valid_portion=0.01)
408+
train, valid, test = load_data(n_words=n_words, valid_portion=0.01,
409+
maxlen=maxlen)
408410

409411
ydim = numpy.max(train[1])+1
410412

@@ -569,7 +571,7 @@ def test_lstm(
569571
theano.config.scan.allow_gc = False
570572

571573
# See function train for all possible parameters and their definitions.
572-
test_lstm(
574+
train_lstm(
573575
#reload_model="lstm_model.npz",
574-
max_epochs=10,
576+
max_epochs=100,
575577
)

0 commit comments

Comments
 (0)