Skip to content

Commit dc091b0

Browse files
sosoxucfchollet
authored andcommitted
fix image_ocr example steps_per_epoch and validation_steps params (keras-team#7696)
* fix image_ocr example steps_per_epoch and validation_steps params * Update image_ocr.py * Update image_ocr.py * Update image_ocr.py
1 parent f8b0dc2 commit dc091b0

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

examples/image_ocr.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ def train(run_name, start_epoch, stop_epoch, img_w):
409409
pool_size = 2
410410
time_dense_size = 32
411411
rnn_size = 512
412+
minibatch_size = 32
412413

413414
if K.image_data_format() == 'channels_first':
414415
input_shape = (1, img_w, img_h)
@@ -420,7 +421,7 @@ def train(run_name, start_epoch, stop_epoch, img_w):
420421

421422
img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
422423
bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
423-
minibatch_size=32,
424+
minibatch_size=minibatch_size,
424425
img_w=img_w,
425426
img_h=img_h,
426427
downsample_factor=(pool_size ** 2),
@@ -479,9 +480,13 @@ def train(run_name, start_epoch, stop_epoch, img_w):
479480

480481
viz_cb = VizCallback(run_name, test_func, img_gen.next_val())
481482

482-
model.fit_generator(generator=img_gen.next_train(), steps_per_epoch=(words_per_epoch - val_words),
483-
epochs=stop_epoch, validation_data=img_gen.next_val(), validation_steps=val_words,
484-
callbacks=[viz_cb, img_gen], initial_epoch=start_epoch)
483+
model.fit_generator(generator=img_gen.next_train(),
484+
steps_per_epoch=(words_per_epoch - val_words) // minibatch_size,
485+
epochs=stop_epoch,
486+
validation_data=img_gen.next_val(),
487+
validation_steps=val_words // minibatch_size,
488+
callbacks=[viz_cb, img_gen],
489+
initial_epoch=start_epoch)
485490

486491

487492
if __name__ == '__main__':

0 commit comments

Comments
 (0)