@@ -74,8 +74,9 @@ def get_dataset_file(dataset, default_dataset, origin):
7474 return dataset
7575
7676
77- def load_data (path = "imdb.pkl" , n_words = 100000 , valid_portion = 0.1 , maxlen = None ):
78- ''' Loads the dataset
77+ def load_data (path = "imdb.pkl" , n_words = 100000 , valid_portion = 0.1 , maxlen = None ,
78+ sort_by_len = True ):
79+ '''Loads the dataset
7980
8081 :type path: String
8182 :param path: The path to the dataset (here IMDB)
@@ -87,6 +88,12 @@ def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
8788 the validation set.
8889 :type maxlen: None or positive int
8990 :param maxlen: the max sequence length we use in the train/valid set.
91+ :type sort_by_len: bool
92+ :param sort_by_len: Sort by the sequence length for the train,
93+ valid and test set. This allows faster execution as it causes
94+ less padding per minibatch. Another mechanism must be used to
95+ shuffle the train set at each epoch.
96+
9097 '''
9198
9299 #############
@@ -140,6 +147,22 @@ def remove_unk(x):
140147 valid_set_x = remove_unk (valid_set_x )
141148 test_set_x = remove_unk (test_set_x )
142149
150+ def len_argsort (seq ):
151+ return sorted (range (len (seq )), key = lambda x : len (seq [x ]))
152+
153+ if sort_by_len :
154+ sorted_index = len_argsort (test_set_x )
155+ test_set_x = [test_set_x [i ] for i in sorted_index ]
156+ test_set_y = [test_set_y [i ] for i in sorted_index ]
157+
158+ sorted_index = len_argsort (valid_set_x )
159+ valid_set_x = [valid_set_x [i ] for i in sorted_index ]
160+ valid_set_y = [valid_set_y [i ] for i in sorted_index ]
161+
162+ sorted_index = len_argsort (train_set_x )
163+ train_set_x = [train_set_x [i ] for i in sorted_index ]
164+ train_set_y = [train_set_y [i ] for i in sorted_index ]
165+
143166 train = (train_set_x , train_set_y )
144167 valid = (valid_set_x , valid_set_y )
145168 test = (test_set_x , test_set_y )
0 commit comments