Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions keras/engine/training_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import numpy as np

from .training_utils import iter_sequence_infinite
from .. import backend as K
from ..utils.data_utils import Sequence
from ..utils.data_utils import GeneratorEnqueuer
Expand Down Expand Up @@ -121,7 +122,7 @@ def fit_generator(model,
elif val_gen:
val_data = validation_data
if isinstance(val_data, Sequence):
val_enqueuer_gen = iter(val_data)
val_enqueuer_gen = iter_sequence_infinite(generator)
else:
val_enqueuer_gen = val_data
else:
Expand Down Expand Up @@ -160,7 +161,7 @@ def fit_generator(model,
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator

Expand Down Expand Up @@ -315,7 +316,7 @@ def evaluate_generator(model, generator,
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator

Expand Down Expand Up @@ -420,7 +421,7 @@ def predict_generator(model, generator,
output_generator = enqueuer.get()
else:
if is_sequence:
output_generator = iter(generator)
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator

Expand Down
14 changes: 14 additions & 0 deletions keras/engine/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,17 @@ def check_num_samples(ins,
if hasattr(ins[0], 'shape'):
return int(ins[0].shape[0])
return None # Edge case where ins == [static_learning_phase]


def iter_sequence_infinite(seq):
"""Iterate indefinitely over a Sequence.

# Arguments
seq: Sequence object

# Returns
Generator yielding batches.
"""
while True:
for item in seq:
yield item
7 changes: 3 additions & 4 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,9 @@ def on_epoch_end(self):
pass

def __iter__(self):
"""Create an infinite generator that iterate over the Sequence."""
while True:
for item in (self[i] for i in range(len(self))):
yield item
"""Create a generator that iterate over the Sequence."""
for item in (self[i] for i in range(len(self))):
yield item


# Global variables to be shared across processes
Expand Down