@@ -409,6 +409,7 @@ def train(run_name, start_epoch, stop_epoch, img_w):
409
409
pool_size = 2
410
410
time_dense_size = 32
411
411
rnn_size = 512
412
+ minibatch_size = 32
412
413
413
414
if K .image_data_format () == 'channels_first' :
414
415
input_shape = (1 , img_w , img_h )
@@ -420,7 +421,7 @@ def train(run_name, start_epoch, stop_epoch, img_w):
420
421
421
422
img_gen = TextImageGenerator (monogram_file = os .path .join (fdir , 'wordlist_mono_clean.txt' ),
422
423
bigram_file = os .path .join (fdir , 'wordlist_bi_clean.txt' ),
423
- minibatch_size = 32 ,
424
+ minibatch_size = minibatch_size ,
424
425
img_w = img_w ,
425
426
img_h = img_h ,
426
427
downsample_factor = (pool_size ** 2 ),
@@ -479,9 +480,13 @@ def train(run_name, start_epoch, stop_epoch, img_w):
479
480
480
481
viz_cb = VizCallback (run_name , test_func , img_gen .next_val ())
481
482
482
- model .fit_generator (generator = img_gen .next_train (), steps_per_epoch = (words_per_epoch - val_words ),
483
- epochs = stop_epoch , validation_data = img_gen .next_val (), validation_steps = val_words ,
484
- callbacks = [viz_cb , img_gen ], initial_epoch = start_epoch )
483
+ model .fit_generator (generator = img_gen .next_train (),
484
+ steps_per_epoch = (words_per_epoch - val_words ) // minibatch_size ,
485
+ epochs = stop_epoch ,
486
+ validation_data = img_gen .next_val (),
487
+ validation_steps = val_words // minibatch_size ,
488
+ callbacks = [viz_cb , img_gen ],
489
+ initial_epoch = start_epoch )
485
490
486
491
487
492
if __name__ == '__main__' :
0 commit comments