Merged
Changes from 1 commit
Test with spawn
Dref360 committed Nov 27, 2018
commit 3d5e837d444761629e42f8964e65d1f3959f8351
16 changes: 7 additions & 9 deletions keras/utils/data_utils.py
@@ -570,8 +570,7 @@ def _run(self):
                     return
                 future = executor.apply_async(get_index, (self.uid, i))
                 future.idx = i
-                self.queue.put(
-                    future, block=True)
+                self.queue.put(future, block=True)
 
             # Done with the current epoch, waiting for the final batches
             self._wait_queue()
@@ -603,10 +602,10 @@ def get(self):
             except mp.TimeoutError:
                 idx = future.idx
                 warnings.warn(
-                    "The input {} could not be retrieved."
-                    " It could be because a worker has died.".format(idx),
+                    'The input {} could not be retrieved.'
+                    ' It could be because a worker has died.'.format(idx),
                     UserWarning)
                 continue
             inputs = self.sequence[idx]
             if inputs is not None:
                 yield inputs
         except Exception:
@@ -701,10 +700,9 @@ def get(self):
                 self.queue.task_done()
             except mp.TimeoutError:
                 warnings.warn(
-                    "An input could not be retrieved."
-                    " It could be because a worker has died."
-                    "We do not have any information on the lost sample."
-                    .format(),
+                    'An input could not be retrieved.'
+                    ' It could be because a worker has died.'
+                    ' We do not have any information on the lost sample.',
                     UserWarning)
                 continue
             if inputs is not None:
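A note on the pattern these warnings guard: both `get()` implementations block on an `AsyncResult` from a `multiprocessing` pool and treat `mp.TimeoutError` as the signal that a worker may have died. A minimal self-contained sketch of that pattern, with an illustrative `get_index` stand-in and timeout value rather than Keras's actual internals:

import multiprocessing as mp
import warnings

def get_index(i):
    # Illustrative stand-in for the worker-side fetch.
    return i * i

if __name__ == '__main__':
    pool = mp.Pool(2)
    futures = [pool.apply_async(get_index, (i,)) for i in range(4)]
    for idx, future in enumerate(futures):
        try:
            # Block on the async result, but give up with a warning
            # instead of hanging forever if a worker has died.
            inputs = future.get(timeout=30)
        except mp.TimeoutError:
            warnings.warn('The input {} could not be retrieved.'
                          ' It could be because a worker has died.'.format(idx),
                          UserWarning)
            continue
        print(inputs)
    pool.close()
    pool.join()

Warning and continuing keeps the consumer loop alive instead of blocking forever on a result that will never arrive.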
14 changes: 11 additions & 3 deletions tests/keras/utils/data_utils_test.py
@@ -25,9 +25,13 @@
 from keras import backend as K
 
 pytestmark = pytest.mark.skipif(
-    K.backend() == 'tensorflow' and 'TRAVIS_PYTHON_VERSION' in os.environ,
+    six.PY2 and 'TRAVIS_PYTHON_VERSION' in os.environ,
     reason='Temporarily disabled until the use_multiprocessing problem is solved')
 
+skip_generators = pytest.mark.skipif(K.backend() in ['tensorflow', 'cntk'] and
+                                     'TRAVIS_PYTHON_VERSION' in os.environ,
+                                     reason='Generators do not work with `spawn`.')
+
 if sys.version_info < (3,):
     def next(x):
         return x.next()
@@ -38,11 +42,12 @@ def use_spawn(func):
 
     @six.wraps(func)
     def wrapper(*args, **kwargs):
-        out = func(*args, **kwargs)
         if sys.version_info > (3, 4):
             mp.set_start_method('spawn', force=True)
-            func(*args, **kwargs)
+            out = func(*args, **kwargs)
             mp.set_start_method('fork', force=True)
+        else:
+            out = func(*args, **kwargs)
         return out
 
     return wrapper
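The rewritten wrapper now runs the test exactly once per interpreter version: under `spawn` on Python 3.4+, under the default `fork` otherwise, restoring the global start method afterwards. `spawn` is the stricter mode because it boots a fresh interpreter and re-imports the test module, so worker targets and their arguments must be picklable. A small illustrative sketch of that constraint, using `mp.get_context` so the global default is left untouched:

import multiprocessing as mp

def child(q):
    # Must live at module level: 'spawn' re-imports this module in the
    # child, so the target and its arguments have to be picklable.
    q.put('hello from a spawned child')

if __name__ == '__main__':
    # 'spawn' starts a fresh interpreter; 'fork' (the POSIX default)
    # instead inherits the parent's memory.
    ctx = mp.get_context('spawn')
    queue = ctx.Queue()
    p = ctx.Process(target=child, args=(queue,))
    p.start()
    print(queue.get())
    p.join()

A context object is a per-call alternative to `set_start_method(..., force=True)`; the decorator has to mutate the global default because the enqueuers under test create their own pools internally.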
@@ -191,6 +196,7 @@ def test_generator_enqueuer_threads():
     enqueuer.stop()
 
 
+@skip_generators
 def test_generator_enqueuer_processes():
     enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
         DummySequence([3, 10, 10, 3])), use_multiprocessing=True)
@@ -224,6 +230,7 @@ def test_generator_enqueuer_fail_threads():
     next(gen_output)
 
 
+@skip_generators
 def test_generator_enqueuer_fail_processes():
     enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
         FaultSequence()), use_multiprocessing=True)
@@ -377,6 +384,7 @@ def test_finite_generator_enqueuer_threads():
     enqueuer.stop()
 
 
+@skip_generators
 def test_finite_generator_enqueuer_processes():
     enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(
         DummySequence([3, 10, 10, 3])), use_multiprocessing=True)
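The `@skip_generators` tests above are the generator-backed ones: a plain generator cannot be pickled and shipped to a spawned child, whereas a `Sequence` subclass defined at module level can. A minimal sketch of the `Sequence` path against the Keras 2.x API of this era (`ZeroSequence` and the shape are illustrative):

import numpy as np

from keras.utils import Sequence
from keras.utils.data_utils import OrderedEnqueuer

class ZeroSequence(Sequence):
    """Defined at module level so `spawn` can pickle and re-import it."""

    def __init__(self, shape):
        self.shape = shape

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        return np.zeros(self.shape)

if __name__ == '__main__':
    enqueuer = OrderedEnqueuer(ZeroSequence((3, 4)), use_multiprocessing=True)
    enqueuer.start(workers=2, max_queue_size=5)
    gen = enqueuer.get()
    batch = next(gen)
    assert batch.shape == (3, 4)
    enqueuer.stop()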
6 changes: 3 additions & 3 deletions tests/test_multiprocessing.py
@@ -8,9 +8,9 @@
 from keras.utils import Sequence
 from keras import backend as K
 
-pytestmark = pytest.mark.skipif(
-    K.backend() == 'tensorflow' and 'TRAVIS_PYTHON_VERSION' in os.environ,
-    reason='Temporarily disabled until the use_multiprocessing problem is solved')
+pytestmark = pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ,
+                                reason='Temporarily disabled until '
+                                       'the use_multiprocessing problem is solved')
 
 STEPS_PER_EPOCH = 100
 STEPS = 100
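For reference, both the module-level `pytestmark` and the per-test `skip_generators` marker rely on the same `pytest.mark.skipif` mechanism; a minimal sketch of the pattern with illustrative names:

import os

import pytest

# Applied to every test in the module, like the pytestmark above.
pytestmark = pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ,
                                reason='Disabled on Travis.')

# A reusable marker that can be applied per-test, like skip_generators.
skip_on_travis = pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ,
                                    reason='Does not work on Travis.')

@skip_on_travis
def test_example():
    assert 1 + 1 == 2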