-
Notifications
You must be signed in to change notification settings - Fork 19.6k
[Don't Merge][Need discussion] First-step toward a working mp on Travis #11521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
452d92b
e906b5d
24cadfc
a57bf5a
3d5e837
dc05b52
e2a257a
1bc99f6
c8fd39c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -568,8 +568,10 @@ def _run(self): | |
| for i in sequence: | ||
| if self.stop_signal.is_set(): | ||
| return | ||
| future = executor.apply_async(get_index, (self.uid, i)) | ||
| future.idx = i | ||
| self.queue.put( | ||
| executor.apply_async(get_index, (self.uid, i)), block=True) | ||
| future, block=True) | ||
|
|
||
| # Done with the current epoch, waiting for the final batches | ||
| self._wait_queue() | ||
|
|
@@ -594,8 +596,17 @@ def get(self): | |
| """ | ||
| try: | ||
| while self.is_running(): | ||
| inputs = self.queue.get(block=True).get() | ||
| self.queue.task_done() | ||
| try: | ||
| future = self.queue.get(block=True) | ||
| inputs = future.get(timeout=30) | ||
| self.queue.task_done() | ||
| except mp.TimeoutError: | ||
| idx = future.idx | ||
| warnings.warn( | ||
| "The input {} could not be retrieved." | ||
|
||
| " It could be because a worker has died.".format(idx), | ||
| UserWarning) | ||
| continue | ||
| if inputs is not None: | ||
| yield inputs | ||
| except Exception: | ||
|
|
@@ -684,8 +695,18 @@ def get(self): | |
| """ | ||
| try: | ||
| while self.is_running(): | ||
| inputs = self.queue.get(block=True).get() | ||
| self.queue.task_done() | ||
| try: | ||
| future = self.queue.get(block=True) | ||
| inputs = future.get(timeout=30) | ||
| self.queue.task_done() | ||
| except mp.TimeoutError: | ||
| warnings.warn( | ||
| "An input could not be retrieved." | ||
| " It could be because a worker has died." | ||
| "We do not have any information on the lost sample." | ||
| .format(), | ||
|
||
| UserWarning) | ||
| continue | ||
| if inputs is not None: | ||
| yield inputs | ||
| except StopIteration: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,8 @@ | |
| import six | ||
| from six.moves.urllib.parse import urljoin | ||
| from six.moves.urllib.request import pathname2url | ||
| import time | ||
| import warnings | ||
|
|
||
| from keras.utils import GeneratorEnqueuer | ||
| from keras.utils import OrderedEnqueuer | ||
|
|
@@ -33,6 +35,7 @@ def next(x): | |
|
|
||
| def use_spawn(func): | ||
| """Decorator to test both Unix (fork) and Windows (spawn)""" | ||
|
|
||
| @six.wraps(func) | ||
| def wrapper(*args, **kwargs): | ||
| out = func(*args, **kwargs) | ||
|
|
@@ -41,6 +44,7 @@ def wrapper(*args, **kwargs): | |
| func(*args, **kwargs) | ||
| mp.set_start_method('fork', force=True) | ||
| return out | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
|
|
@@ -386,5 +390,38 @@ def test_finite_generator_enqueuer_processes(): | |
| enqueuer.stop() | ||
|
|
||
|
|
||
| @pytest.mark.skipif('TRAVIS_PYTHON_VERSION' in os.environ, | ||
| reason='Takes 150s to run') | ||
| def test_missing_inputs(): | ||
| missing_idx = 10 | ||
|
|
||
| class TimeOutSequence(DummySequence): | ||
| def __getitem__(self, item): | ||
| if item == missing_idx: | ||
| time.sleep(120) | ||
| return super(TimeOutSequence, self).__getitem__(item) | ||
|
|
||
| enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs( | ||
| TimeOutSequence([3, 2, 2, 3])), use_multiprocessing=True) | ||
| enqueuer.start(3, 10) | ||
| gen_output = enqueuer.get() | ||
| with warnings.catch_warnings(record=True) as w: | ||
| warnings.simplefilter("always") | ||
| for _ in range(4 * missing_idx): | ||
| next(gen_output) | ||
| assert 'An input could not be retrieved.' in str(w[-1].message) | ||
|
||
|
|
||
| enqueuer = OrderedEnqueuer(TimeOutSequence([3, 2, 2, 3]), | ||
| use_multiprocessing=True) | ||
| enqueuer.start(3, 10) | ||
| gen_output = enqueuer.get() | ||
| with warnings.catch_warnings(record=True) as w: | ||
| warnings.simplefilter("always") | ||
| for _ in range(11): | ||
| next(gen_output) | ||
| assert "The input {} could not be retrieved.".format( | ||
| missing_idx) in str(w[-1].message) | ||
|
||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| pytest.main([__file__]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the line needs to be broken.