diff --git a/keras/utils/data_utils.py b/keras/utils/data_utils.py index 6bcc5cf70ff5..82eb431b8081 100644 --- a/keras/utils/data_utils.py +++ b/keras/utils/data_utils.py @@ -568,8 +568,9 @@ def _run(self): for i in sequence: if self.stop_signal.is_set(): return - self.queue.put( - executor.apply_async(get_index, (self.uid, i)), block=True) + future = executor.apply_async(get_index, (self.uid, i)) + future.idx = i + self.queue.put(future, block=True) # Done with the current epoch, waiting for the final batches self._wait_queue() @@ -594,8 +595,17 @@ def get(self): """ try: while self.is_running(): - inputs = self.queue.get(block=True).get() - self.queue.task_done() + try: + future = self.queue.get(block=True) + inputs = future.get(timeout=30) + self.queue.task_done() + except mp.TimeoutError: + idx = future.idx + warnings.warn( + 'The input {} could not be retrieved.' + ' It could be because a worker has died.'.format(idx), + UserWarning) + inputs = self.sequence[idx] if inputs is not None: yield inputs except Exception: @@ -684,8 +694,17 @@ def get(self): """ try: while self.is_running(): - inputs = self.queue.get(block=True).get() - self.queue.task_done() + try: + future = self.queue.get(block=True) + inputs = future.get(timeout=30) + self.queue.task_done() + except mp.TimeoutError: + warnings.warn( + 'An input could not be retrieved.' + ' It could be because a worker has died.' + 'We do not have any information on the lost sample.', + UserWarning) + continue if inputs is not None: yield inputs except StopIteration: diff --git a/tests/keras/engine/test_training.py b/tests/keras/engine/test_training.py index 5a510998fe74..d24ff0a77ce8 100644 --- a/tests/keras/engine/test_training.py +++ b/tests/keras/engine/test_training.py @@ -466,10 +466,11 @@ def gen_data(): epochs=5, initial_epoch=0, validation_data=val_seq, - callbacks=[tracker_cb]) + callbacks=[tracker_cb], + max_queue_size=1) assert trained_epochs == [0, 1, 2, 3, 4] assert trained_batches == list(range(12)) * 5 - assert len(val_seq.logs) == 12 * 5 + assert 12 * 5 <= len(val_seq.logs) <= (12 * 5) + 2 # the queue may be full. # test for workers = 0 trained_epochs = [] diff --git a/tests/keras/legacy/interface_test.py b/tests/keras/legacy/interface_test.py index 0e3ba412d250..f6a98f90c39d 100644 --- a/tests/keras/legacy/interface_test.py +++ b/tests/keras/legacy/interface_test.py @@ -1,7 +1,9 @@ import pytest import json import keras +import keras.backend as K import numpy as np +import os def test_dense_legacy_interface(): @@ -806,6 +808,9 @@ def test_cropping3d_legacy_interface(): assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config()) +@pytest.mark.skipif(K.backend() in {'tensorflow', 'cntk'} + and 'TRAVIS_PYTHON_VERSION' in os.environ, + reason='Generators cannot use `spawn`.') def test_generator_methods_interface(): def train_generator(): x = np.random.randn(2, 2) diff --git a/tests/keras/utils/data_utils_test.py b/tests/keras/utils/data_utils_test.py index 0d0949ce6eff..ff5b32250007 100644 --- a/tests/keras/utils/data_utils_test.py +++ b/tests/keras/utils/data_utils_test.py @@ -13,6 +13,8 @@ import six from six.moves.urllib.parse import urljoin from six.moves.urllib.request import pathname2url +import warnings + from flaky import flaky from keras.utils import GeneratorEnqueuer @@ -24,9 +26,13 @@ from keras import backend as K pytestmark = pytest.mark.skipif( - K.backend() in {'tensorflow', 'cntk'} and 'TRAVIS_PYTHON_VERSION' in os.environ, + six.PY2 and 'TRAVIS_PYTHON_VERSION' in os.environ, reason='Temporarily disabled until the use_multiprocessing problem is solved') +skip_generators = pytest.mark.skipif(K.backend() in {'tensorflow', 'cntk'} and + 'TRAVIS_PYTHON_VERSION' in os.environ, + reason='Generators do not work with `spawn`.') + if sys.version_info < (3,): def next(x): return x.next() @@ -34,14 +40,17 @@ def next(x): def use_spawn(func): """Decorator to test both Unix (fork) and Windows (spawn)""" + @six.wraps(func) def wrapper(*args, **kwargs): - out = func(*args, **kwargs) if sys.version_info > (3, 4): mp.set_start_method('spawn', force=True) - func(*args, **kwargs) + out = func(*args, **kwargs) mp.set_start_method('fork', force=True) + else: + out = func(*args, **kwargs) return out + return wrapper @@ -188,6 +197,7 @@ def test_generator_enqueuer_threads(): enqueuer.stop() +@skip_generators def test_generator_enqueuer_processes(): enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs( DummySequence([3, 10, 10, 3])), use_multiprocessing=True) @@ -223,6 +233,7 @@ def test_generator_enqueuer_fail_threads(): next(gen_output) +@skip_generators def test_generator_enqueuer_fail_processes(): enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs( FaultSequence()), use_multiprocessing=True) @@ -378,6 +389,7 @@ def test_finite_generator_enqueuer_threads(): enqueuer.stop() +@skip_generators def test_finite_generator_enqueuer_processes(): enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs( DummySequence([3, 10, 10, 3])), use_multiprocessing=True) @@ -391,5 +403,34 @@ def test_finite_generator_enqueuer_processes(): enqueuer.stop() +@pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ, + reason='Takes 150s to run') +def test_missing_inputs(): + missing_idx = 10 + + class TimeOutSequence(DummySequence): + def __getitem__(self, item): + if item == missing_idx: + time.sleep(120) + return super(TimeOutSequence, self).__getitem__(item) + + enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs( + TimeOutSequence([3, 2, 2, 3])), use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + with pytest.warns(UserWarning, match='An input could not be retrieved.'): + for _ in range(4 * missing_idx): + next(gen_output) + + enqueuer = OrderedEnqueuer(TimeOutSequence([3, 2, 2, 3]), + use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + warning_msg = "The input {} could not be retrieved.".format(missing_idx) + with pytest.warns(UserWarning, match=warning_msg): + for _ in range(11): + next(gen_output) + + if __name__ == '__main__': pytest.main([__file__])