Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Improve type-check for Sequence
  • Loading branch information
Dref360 committed Oct 23, 2018
commit 452d92b8ec106630affa564b9add4d807a3e4089
40 changes: 21 additions & 19 deletions keras/engine/training_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import numpy as np

from .training_utils import is_sequence
from .training_utils import iter_sequence_infinite
from .. import backend as K
from ..utils.data_utils import Sequence
Expand Down Expand Up @@ -40,15 +41,15 @@ def fit_generator(model,
if do_validation:
model._make_test_function()

is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
if steps_per_epoch is None:
if is_sequence:
if use_sequence_api:
steps_per_epoch = len(generator)
else:
raise ValueError('`steps_per_epoch=None` is only valid for a'
Expand All @@ -59,10 +60,11 @@ def fit_generator(model,

# python 2 has 'next', 3 has '__next__'
# avoid any explicit version checks
val_use_sequence_api = is_sequence(validation_data)
val_gen = (hasattr(validation_data, 'next') or
hasattr(validation_data, '__next__') or
isinstance(validation_data, Sequence))
if (val_gen and not isinstance(validation_data, Sequence) and
val_use_sequence_api)
if (val_gen and not val_use_sequence_api and
not validation_steps):
raise ValueError('`validation_steps=None` is only valid for a'
' generator based on the `keras.utils.Sequence`'
Expand Down Expand Up @@ -108,7 +110,7 @@ def fit_generator(model,
if val_gen and workers > 0:
# Create an Enqueuer that can be reused
val_data = validation_data
if isinstance(val_data, Sequence):
if is_sequence(val_data):
val_enqueuer = OrderedEnqueuer(
val_data,
use_multiprocessing=use_multiprocessing)
Expand All @@ -122,7 +124,7 @@ def fit_generator(model,
val_enqueuer_gen = val_enqueuer.get()
elif val_gen:
val_data = validation_data
if isinstance(val_data, Sequence):
if is_sequence(val_data):
val_enqueuer_gen = iter_sequence_infinite(val_data)
validation_steps = validation_steps or len(val_data)
else:
Expand All @@ -149,7 +151,7 @@ def fit_generator(model,
cbk.validation_data = val_data

if workers > 0:
if is_sequence:
if use_sequence_api:
enqueuer = OrderedEnqueuer(
generator,
use_multiprocessing=use_multiprocessing,
Expand All @@ -161,7 +163,7 @@ def fit_generator(model,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
if use_sequence_api:
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
Expand Down Expand Up @@ -284,15 +286,15 @@ def evaluate_generator(model, generator,
steps_done = 0
outs_per_batch = []
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
if steps is None:
if is_sequence:
if use_sequence_api:
steps = len(generator)
else:
raise ValueError('`steps=None` is only valid for a generator'
Expand All @@ -303,7 +305,7 @@ def evaluate_generator(model, generator,

try:
if workers > 0:
if is_sequence:
if use_sequence_api:
enqueuer = OrderedEnqueuer(
generator,
use_multiprocessing=use_multiprocessing)
Expand All @@ -314,7 +316,7 @@ def evaluate_generator(model, generator,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
if use_sequence_api:
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
Expand Down Expand Up @@ -387,15 +389,15 @@ def predict_generator(model, generator,

steps_done = 0
all_outs = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
' class.'))
if steps is None:
if is_sequence:
if use_sequence_api:
steps = len(generator)
else:
raise ValueError('`steps=None` is only valid for a generator'
Expand All @@ -406,7 +408,7 @@ def predict_generator(model, generator,

try:
if workers > 0:
if is_sequence:
if use_sequence_api:
enqueuer = OrderedEnqueuer(
generator,
use_multiprocessing=use_multiprocessing)
Expand All @@ -417,7 +419,7 @@ def predict_generator(model, generator,
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
if is_sequence:
if use_sequence_api:
output_generator = iter_sequence_infinite(generator)
else:
output_generator = generator
Expand Down
15 changes: 15 additions & 0 deletions keras/engine/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .. import backend as K
from .. import losses
from ..utils import Sequence
from ..utils.generic_utils import to_list


Expand Down Expand Up @@ -589,3 +590,17 @@ def iter_sequence_infinite(seq):
while True:
for item in seq:
yield item


def is_sequence(seq):
"""Determine if an object follows the Sequence API.

# Arguments
seq: a possible Sequence object

# Returns
boolean, whether the object follows the Sequence API.
"""
# TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
return (getattr(seq, 'use_sequence_api', False)
or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))
2 changes: 2 additions & 0 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def __getitem__(self, idx):
```
"""

use_sequence_api = True

@abstractmethod
def __getitem__(self, index):
"""Gets batch at position `index`.
Expand Down
35 changes: 35 additions & 0 deletions tests/integration_tests/test_image_data_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest

from keras.preprocessing.image import ImageDataGenerator
from keras.utils.test_utils import get_test_data
from keras.models import Sequential
from keras import layers
Expand Down Expand Up @@ -41,5 +42,39 @@ def test_image_classification():
model = Sequential.from_config(config)


def test_image_data_generator_training():
np.random.seed(1337)
img_gen = ImageDataGenerator(rescale=1.) # Dummy ImageDataGenerator
input_shape = (16, 16, 3)
(x_train, y_train), (x_test, y_test) = get_test_data(num_train=500,
num_test=200,
input_shape=input_shape,
classification=True,
num_classes=4)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

model = Sequential([
layers.Conv2D(filters=8, kernel_size=3,
activation='relu',
input_shape=input_shape),
layers.MaxPooling2D(pool_size=2),
layers.Conv2D(filters=4, kernel_size=(3, 3),
activation='relu', padding='same'),
layers.GlobalAveragePooling2D(),
layers.Dense(y_test.shape[-1], activation='softmax')
])
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
history = model.fit_generator(img_gen.flow(x_train, y_train, batch_size=16),
epochs=10,
validation_data=img_gen.flow(x_test, y_test,
batch_size=16),
verbose=0)
assert history.history['val_acc'][-1] > 0.75
model.evaluate_generator(img_gen.flow(x_train, y_train, batch_size=16))


if __name__ == '__main__':
pytest.main([__file__])
2 changes: 1 addition & 1 deletion tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from keras import backend as K

pytestmark = pytest.mark.skipif(
K.backend() == 'tensorflow',
K.backend() == 'tensorflow' and 'TRAVIS_PYTHON_VERSION' in os.environ,
reason='Temporarily disabled until the use_multiprocessing problem is solved')

if sys.version_info < (3,):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from keras import backend as K

pytestmark = pytest.mark.skipif(
K.backend() == 'tensorflow',
K.backend() == 'tensorflow' and 'TRAVIS_PYTHON_VERSION' in os.environ,
reason='Temporarily disabled until the use_multiprocessing problem is solved')

STEPS_PER_EPOCH = 100
Expand Down