From 8947e8af4b634972fe8cb895ea1270eac1b77643 Mon Sep 17 00:00:00 2001 From: Dref360 Date: Thu, 30 Aug 2018 13:14:57 -0400 Subject: [PATCH] Better UX --- keras/engine/training_generator.py | 9 +++++---- keras/engine/training_utils.py | 14 ++++++++++++++ keras/utils/data_utils.py | 7 +++---- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/keras/engine/training_generator.py b/keras/engine/training_generator.py index 733e5b2fe672..9d3aec0c8cff 100644 --- a/keras/engine/training_generator.py +++ b/keras/engine/training_generator.py @@ -7,6 +7,7 @@ import warnings import numpy as np +from .training_utils import iter_sequence_infinite from .. import backend as K from ..utils.data_utils import Sequence from ..utils.data_utils import GeneratorEnqueuer @@ -121,7 +122,7 @@ def fit_generator(model, elif val_gen: val_data = validation_data if isinstance(val_data, Sequence): - val_enqueuer_gen = iter(val_data) + val_enqueuer_gen = iter_sequence_infinite(generator) else: val_enqueuer_gen = val_data else: @@ -160,7 +161,7 @@ def fit_generator(model, output_generator = enqueuer.get() else: if is_sequence: - output_generator = iter(generator) + output_generator = iter_sequence_infinite(generator) else: output_generator = generator @@ -315,7 +316,7 @@ def evaluate_generator(model, generator, output_generator = enqueuer.get() else: if is_sequence: - output_generator = iter(generator) + output_generator = iter_sequence_infinite(generator) else: output_generator = generator @@ -420,7 +421,7 @@ def predict_generator(model, generator, output_generator = enqueuer.get() else: if is_sequence: - output_generator = iter(generator) + output_generator = iter_sequence_infinite(generator) else: output_generator = generator diff --git a/keras/engine/training_utils.py b/keras/engine/training_utils.py index 9b280175e4fc..2f13685b05bc 100644 --- a/keras/engine/training_utils.py +++ b/keras/engine/training_utils.py @@ -575,3 +575,17 @@ def check_num_samples(ins, if hasattr(ins[0], 'shape'): return int(ins[0].shape[0]) return None # Edge case where ins == [static_learning_phase] + + +def iter_sequence_infinite(seq): + """Iterate indefinitely over a Sequence. + + # Arguments + seq: Sequence object + + # Returns + Generator yielding batches. + """ + while True: + for item in seq: + yield item diff --git a/keras/utils/data_utils.py b/keras/utils/data_utils.py index 59553e097cac..776dcfc86d9c 100644 --- a/keras/utils/data_utils.py +++ b/keras/utils/data_utils.py @@ -368,10 +368,9 @@ def on_epoch_end(self): pass def __iter__(self): - """Create an infinite generator that iterate over the Sequence.""" - while True: - for item in (self[i] for i in range(len(self))): - yield item + """Create a generator that iterate over the Sequence.""" + for item in (self[i] for i in range(len(self))): + yield item # Global variables to be shared across processes