Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,10 @@ def _run(self):
for i in sequence:
if self.stop_signal.is_set():
return
future = executor.apply_async(get_index, (self.uid, i))
future.idx = i
self.queue.put(
executor.apply_async(get_index, (self.uid, i)), block=True)
future, block=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the line needs to be broken.


# Done with the current epoch, waiting for the final batches
self._wait_queue()
Expand All @@ -594,8 +596,17 @@ def get(self):
"""
try:
while self.is_running():
inputs = self.queue.get(block=True).get()
self.queue.task_done()
try:
future = self.queue.get(block=True)
inputs = future.get(timeout=30)
self.queue.task_done()
except mp.TimeoutError:
idx = future.idx
warnings.warn(
"The input {} could not be retrieved."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: use ' as quote character for consistency.

Same below.

" It could be because a worker has died.".format(idx),
UserWarning)
continue
if inputs is not None:
yield inputs
except Exception:
Expand Down Expand Up @@ -684,8 +695,18 @@ def get(self):
"""
try:
while self.is_running():
inputs = self.queue.get(block=True).get()
self.queue.task_done()
try:
future = self.queue.get(block=True)
inputs = future.get(timeout=30)
self.queue.task_done()
except mp.TimeoutError:
warnings.warn(
"An input could not be retrieved."
" It could be because a worker has died."
"We do not have any information on the lost sample."
.format(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe the format is useful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, maybe you intended to display the index in the warning?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but we cannot know the index in a generator.

UserWarning)
continue
if inputs is not None:
yield inputs
except StopIteration:
Expand Down
37 changes: 37 additions & 0 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import six
from six.moves.urllib.parse import urljoin
from six.moves.urllib.request import pathname2url
import time
import warnings

from keras.utils import GeneratorEnqueuer
from keras.utils import OrderedEnqueuer
Expand All @@ -33,6 +35,7 @@ def next(x):

def use_spawn(func):
"""Decorator to test both Unix (fork) and Windows (spawn)"""

@six.wraps(func)
def wrapper(*args, **kwargs):
out = func(*args, **kwargs)
Expand All @@ -41,6 +44,7 @@ def wrapper(*args, **kwargs):
func(*args, **kwargs)
mp.set_start_method('fork', force=True)
return out

return wrapper


Expand Down Expand Up @@ -386,5 +390,38 @@ def test_finite_generator_enqueuer_processes():
enqueuer.stop()


@pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ,
reason='Takes 150s to run')
def test_missing_inputs():
missing_idx = 10

class TimeOutSequence(DummySequence):
def __getitem__(self, item):
if item == missing_idx:
time.sleep(120)
return super(TimeOutSequence, self).__getitem__(item)

enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(
TimeOutSequence([3, 2, 2, 3])), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
for _ in range(4 * missing_idx):
next(gen_output)
assert 'An input could not be retrieved.' in str(w[-1].message)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.


enqueuer = OrderedEnqueuer(TimeOutSequence([3, 2, 2, 3]),
use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
for _ in range(11):
next(gen_output)
assert "The input {} could not be retrieved.".format(
missing_idx) in str(w[-1].message)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test will fail if other warnings are being emitted. I would recommend using pytest's utility to check that warnings are being triggered correctly: https://docs.pytest.org/en/latest/warnings.html#warns



if __name__ == '__main__':
pytest.main([__file__])