1818datasets = {'imdb' : (imdb .load_data , imdb .prepare_data )}
1919
2020
21- def get_minibatches_idx (n , nb_batches , shuffle = False ):
21+ def get_minibatches_idx (n , minibatch_size , shuffle = False ):
2222 """
2323 Used to shuffle the dataset at each iteration.
2424 """
@@ -30,17 +30,16 @@ def get_minibatches_idx(n, nb_batches, shuffle=False):
3030
3131 minibatches = []
3232 minibatch_start = 0
33- for i in range (nb_batches ):
34- if i < n % nb_batches :
35- minibatch_size = n // nb_batches + 1
36- else :
37- minibatch_size = n // nb_batches
38-
33+ for i in range (n // minibatch_size ):
3934 minibatches .append (idx_list [minibatch_start :
4035 minibatch_start + minibatch_size ])
4136 minibatch_start += minibatch_size
4237
43- return zip (range (nb_batches ), minibatches )
38+ if (minibatch_start != n ):
39+ # Make a minibatch out of what is left
40+ minibatches .append (idx_list [minibatch_start :])
41+
42+ return zip (range (len (minibatches )), minibatches )
4443
4544
4645def get_dataset (name ):
@@ -446,11 +445,9 @@ def test_lstm(
446445
447446 print 'Optimization'
448447
449- kf_valid = get_minibatches_idx (len (valid [0 ]),
450- len (valid [0 ]) / valid_batch_size ,
448+ kf_valid = get_minibatches_idx (len (valid [0 ]), valid_batch_size ,
451449 shuffle = True )
452- kf_test = get_minibatches_idx (len (test [0 ]),
453- len (test [0 ]) / valid_batch_size ,
450+ kf_test = get_minibatches_idx (len (test [0 ]), valid_batch_size ,
454451 shuffle = True )
455452
456453 history_errs = []
@@ -469,8 +466,7 @@ def test_lstm(
469466 n_samples = 0
470467
471468 # Get new shuffled index for the training set.
472- kf = get_minibatches_idx (len (train [0 ]), len (train [0 ])/ batch_size ,
473- shuffle = True )
469+ kf = get_minibatches_idx (len (train [0 ]), batch_size , shuffle = True )
474470
475471 for _ , train_index in kf :
476472 uidx += 1
0 commit comments