31 changes: 25 additions & 6 deletions keras/utils/data_utils.py
@@ -568,8 +568,9 @@ def _run(self):
                 for i in sequence:
                     if self.stop_signal.is_set():
                         return
-                    self.queue.put(
-                        executor.apply_async(get_index, (self.uid, i)), block=True)
+                    future = executor.apply_async(get_index, (self.uid, i))
+                    future.idx = i
+                    self.queue.put(future, block=True)
 
             # Done with the current epoch, waiting for the final batches
             self._wait_queue()
@@ -594,8 +595,17 @@ def get(self):
         """
         try:
             while self.is_running():
-                inputs = self.queue.get(block=True).get()
-                self.queue.task_done()
+                try:
+                    future = self.queue.get(block=True)
+                    inputs = future.get(timeout=30)
+                    self.queue.task_done()
+                except mp.TimeoutError:
+                    idx = future.idx
+                    warnings.warn(
+                        'The input {} could not be retrieved.'
+                        ' It could be because a worker has died.'.format(idx),
+                        UserWarning)
+                    inputs = self.sequence[idx]
                 if inputs is not None:
                     yield inputs
         except Exception:
@@ -684,8 +694,17 @@ def get(self):
         """
         try:
             while self.is_running():
-                inputs = self.queue.get(block=True).get()
-                self.queue.task_done()
+                try:
+                    future = self.queue.get(block=True)
+                    inputs = future.get(timeout=30)
+                    self.queue.task_done()
+                except mp.TimeoutError:
+                    warnings.warn(
+                        'An input could not be retrieved.'
+                        ' It could be because a worker has died.'
+                        'We do not have any information on the lost sample.',
+                        UserWarning)
+                    continue
                 if inputs is not None:
                     yield inputs
         except StopIteration:
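The recovery pattern these data_utils.py hunks introduce can be illustrated on its own. Below is a minimal standalone sketch using plain multiprocessing (it is not the Keras code itself): each AsyncResult is tagged with the index it was submitted for, the wait is bounded by a timeout, and on multiprocessing.TimeoutError the sample is recomputed in the parent process, mirroring the `inputs = self.sequence[idx]` fallback above. The worker function slow_square is a made-up stand-in for get_index.

import multiprocessing as mp
import warnings


def slow_square(i):
    # Stand-in for get_index(uid, i): any picklable work function will do.
    return i * i


if __name__ == '__main__':
    with mp.Pool(processes=2) as pool:
        futures = []
        for i in range(8):
            future = pool.apply_async(slow_square, (i,))
            future.idx = i  # tag the AsyncResult, as the diff does with future.idx
            futures.append(future)

        results = []
        for future in futures:
            try:
                results.append(future.get(timeout=30))
            except mp.TimeoutError:
                idx = future.idx
                warnings.warn('The input {} could not be retrieved.'
                              ' Recomputing it in the main process.'.format(idx),
                              UserWarning)
                results.append(slow_square(idx))  # fallback, like self.sequence[idx]
        print(results)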
5 changes: 3 additions & 2 deletions tests/keras/engine/test_training.py
@@ -466,10 +466,11 @@ def gen_data():
                         epochs=5,
                         initial_epoch=0,
                         validation_data=val_seq,
-                        callbacks=[tracker_cb])
+                        callbacks=[tracker_cb],
+                        max_queue_size=1)
     assert trained_epochs == [0, 1, 2, 3, 4]
     assert trained_batches == list(range(12)) * 5
-    assert len(val_seq.logs) == 12 * 5
+    assert 12 * 5 <= len(val_seq.logs) <= (12 * 5) + 2  # the queue may be full.
 
     # test for workers = 0
     trained_epochs = []
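The relaxed assertion accounts for prefetching: even with max_queue_size=1, the enqueuer may have requested one or two extra validation batches by the time training stops, so val_seq.logs can exceed 12 * 5 by up to 2. The sketch below shows the kind of call-logging Sequence this check presumes; it is an assumption about what val_seq looks like, not code from the PR.

import numpy as np
from keras.utils import Sequence


class LoggingSequence(Sequence):
    """Records every __getitem__ call, including prefetched ones."""

    def __init__(self, num_batches, batch_size=3):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.logs = []  # indices requested so far

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        self.logs.append(idx)  # a prefetched batch is logged even if never consumed
        x = np.random.random((self.batch_size, 3))
        y = np.random.random((self.batch_size, 4))
        return x, y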
5 changes: 5 additions & 0 deletions tests/keras/legacy/interface_test.py
@@ -1,7 +1,9 @@
 import pytest
 import json
 import keras
+import keras.backend as K
 import numpy as np
+import os
 
 
 def test_dense_legacy_interface():
@@ -806,6 +808,9 @@ def test_cropping3d_legacy_interface():
     assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
 
 
+@pytest.mark.skipif(K.backend() in {'tensorflow', 'cntk'}
+                    and 'TRAVIS_PYTHON_VERSION' in os.environ,
+                    reason='Generators cannot use `spawn`.')
 def test_generator_methods_interface():
     def train_generator():
         x = np.random.randn(2, 2)
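The skip reason above ('Generators cannot use `spawn`.') comes down to pickling: under the spawn start method, whatever is handed to a worker process must be picklable, and Python generator objects are not. A short standalone illustration (not part of the PR) follows.

import pickle


def sample_generator():
    yield 1
    yield 2


gen = sample_generator()
try:
    pickle.dumps(gen)
except TypeError as err:
    # e.g. "cannot pickle 'generator' object" on Python 3
    print('Generators are not picklable:', err)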
47 changes: 44 additions & 3 deletions tests/keras/utils/data_utils_test.py
@@ -13,6 +13,8 @@
 import six
 from six.moves.urllib.parse import urljoin
 from six.moves.urllib.request import pathname2url
+import warnings
+
 from flaky import flaky
 
 from keras.utils import GeneratorEnqueuer
@@ -24,24 +26,31 @@
 from keras import backend as K
 
 pytestmark = pytest.mark.skipif(
-    K.backend() in {'tensorflow', 'cntk'} and 'TRAVIS_PYTHON_VERSION' in os.environ,
+    six.PY2 and 'TRAVIS_PYTHON_VERSION' in os.environ,
     reason='Temporarily disabled until the use_multiprocessing problem is solved')
 
+skip_generators = pytest.mark.skipif(K.backend() in {'tensorflow', 'cntk'} and
+                                     'TRAVIS_PYTHON_VERSION' in os.environ,
+                                     reason='Generators do not work with `spawn`.')
+
 if sys.version_info < (3,):
     def next(x):
         return x.next()
 
 
 def use_spawn(func):
+    """Decorator to test both Unix (fork) and Windows (spawn)"""
+
     @six.wraps(func)
     def wrapper(*args, **kwargs):
-        out = func(*args, **kwargs)
         if sys.version_info > (3, 4):
             mp.set_start_method('spawn', force=True)
-            func(*args, **kwargs)
+            out = func(*args, **kwargs)
             mp.set_start_method('fork', force=True)
+        else:
+            out = func(*args, **kwargs)
         return out
 
     return wrapper
 
 
@@ -188,6 +197,7 @@ def test_generator_enqueuer_threads():
     enqueuer.stop()
 
 
+@skip_generators
 def test_generator_enqueuer_processes():
     enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
         DummySequence([3, 10, 10, 3])), use_multiprocessing=True)
Expand Down Expand Up @@ -223,6 +233,7 @@ def test_generator_enqueuer_fail_threads():
next(gen_output)


@skip_generators
def test_generator_enqueuer_fail_processes():
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
FaultSequence()), use_multiprocessing=True)
Expand Down Expand Up @@ -378,6 +389,7 @@ def test_finite_generator_enqueuer_threads():
enqueuer.stop()


@skip_generators
def test_finite_generator_enqueuer_processes():
enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(
DummySequence([3, 10, 10, 3])), use_multiprocessing=True)
@@ -391,5 +403,34 @@ def test_finite_generator_enqueuer_processes():
     enqueuer.stop()
 
 
+@pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ,
+                    reason='Takes 150s to run')
+def test_missing_inputs():
+    missing_idx = 10
+
+    class TimeOutSequence(DummySequence):
+        def __getitem__(self, item):
+            if item == missing_idx:
+                time.sleep(120)
+            return super(TimeOutSequence, self).__getitem__(item)
+
+    enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(
+        TimeOutSequence([3, 2, 2, 3])), use_multiprocessing=True)
+    enqueuer.start(3, 10)
+    gen_output = enqueuer.get()
+    with pytest.warns(UserWarning, match='An input could not be retrieved.'):
+        for _ in range(4 * missing_idx):
+            next(gen_output)
+
+    enqueuer = OrderedEnqueuer(TimeOutSequence([3, 2, 2, 3]),
+                               use_multiprocessing=True)
+    enqueuer.start(3, 10)
+    gen_output = enqueuer.get()
+    warning_msg = "The input {} could not be retrieved.".format(missing_idx)
+    with pytest.warns(UserWarning, match=warning_msg):
+        for _ in range(11):
+            next(gen_output)
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
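For reference, the warning matching that the new test_missing_inputs relies on works as sketched below: pytest.warns(..., match=...) treats match as a regular expression searched against the warning message. The warning text is copied from the diff; the function names here are illustrative only.

import warnings

import pytest


def retrieve(idx):
    # Hypothetical stand-in for the enqueuer's timeout branch.
    warnings.warn('The input {} could not be retrieved.'.format(idx),
                  UserWarning)


def test_warns_when_input_is_missing():
    with pytest.warns(UserWarning, match='The input 10 could not be retrieved.'):
        retrieve(10)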