Skip to content

Commit f702803

Browse files
authored
Fix timeout for wait_task (mars-project#2883)
1 parent eddcc3b commit f702803

File tree

5 files changed

+126
-19
lines changed

5 files changed

+126
-19
lines changed

mars/services/task/api/web.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import base64
1617
import json
1718
from typing import Callable, List, Optional, Union
@@ -152,8 +153,18 @@ async def wait_task(self, session_id: str, task_id: str):
152153
timeout = self.get_argument("timeout", None) or None
153154
timeout = float(timeout) if timeout is not None else None
154155
oscar_api = await self._get_oscar_task_api(session_id)
155-
res = await oscar_api.wait_task(task_id, timeout)
156-
self.write(json.dumps(_json_serial_task_result(res)))
156+
if timeout:
157+
try:
158+
res = await asyncio.wait_for(
159+
asyncio.shield(oscar_api.wait_task(task_id, timeout)),
160+
timeout=timeout,
161+
)
162+
self.write(json.dumps(_json_serial_task_result(res)))
163+
except asyncio.TimeoutError:
164+
self.write(json.dumps({}))
165+
else:
166+
res = await oscar_api.wait_task(task_id, timeout)
167+
self.write(json.dumps(_json_serial_task_result(res)))
157168

158169
@web_api("(?P<task_id>[^/]+)", method="delete")
159170
async def cancel_task(self, session_id: str, task_id: str):

mars/services/task/supervisor/manager.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ..core import Task, new_task_id, TaskStatus
3535
from ..errors import TaskNotExist
3636
from .preprocessor import TaskPreprocessor
37-
from .processor import TaskProcessorActor
37+
from .processor import TaskProcessorActor, TaskProcessor
3838

3939
logger = logging.getLogger(__name__)
4040

@@ -43,15 +43,18 @@ class TaskConfigurationActor(mo.Actor):
4343
def __init__(
4444
self,
4545
task_conf: Dict[str, Any],
46+
task_processor_cls: Type[TaskProcessor] = None,
4647
task_preprocessor_cls: Type[TaskPreprocessor] = None,
4748
):
4849
for name, value in task_conf.items():
4950
setattr(task_options, name, value)
51+
self._task_processor_cls = task_processor_cls
5052
self._task_preprocessor_cls = task_preprocessor_cls
5153

5254
def get_config(self):
5355
return {
5456
"task_options": task_options,
57+
"task_processor_cls": self._task_processor_cls,
5558
"task_preprocessor_cls": self._task_preprocessor_cls,
5659
}
5760

@@ -77,6 +80,7 @@ def __init__(self, session_id: str):
7780
self._session_id = session_id
7881

7982
self._config = None
83+
self._task_processor_cls = None
8084
self._task_preprocessor_cls = None
8185
self._last_idle_time = None
8286

@@ -104,9 +108,10 @@ async def __post_create__(self):
104108
TaskConfigurationActor.default_uid(), address=self.address
105109
)
106110
task_conf = await configuration_ref.get_config()
107-
self._config, self._task_preprocessor_cls = (
111+
self._config, self._task_preprocessor_cls, self._task_processor_cls = (
108112
task_conf["task_options"],
109113
task_conf["task_preprocessor_cls"],
114+
task_conf["task_processor_cls"],
110115
)
111116
self._task_preprocessor_cls = self._get_task_preprocessor_cls()
112117

@@ -159,6 +164,7 @@ async def submit_tileable_graph(
159164
self._session_id,
160165
parent_task_id,
161166
task_name=task_name,
167+
task_processor_cls=self._task_processor_cls,
162168
address=self.address,
163169
uid=uid,
164170
)

mars/services/task/supervisor/processor.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import importlib
1617
import itertools
1718
import logging
1819
import operator
@@ -525,11 +526,18 @@ class TaskProcessorActor(mo.Actor):
525526
_processed_task_ids: Set[str]
526527
_cur_processor: Optional[TaskProcessor]
527528

528-
def __init__(self, session_id: str, task_id: str, task_name: str = None):
529+
def __init__(
530+
self,
531+
session_id: str,
532+
task_id: str,
533+
task_name: str = None,
534+
task_processor_cls: Type[TaskPreprocessor] = None,
535+
):
529536
self.session_id = session_id
530537
self.task_id = task_id
531538
self.task_name = task_name
532539

540+
self._task_processor_cls = self._get_task_processor_cls(task_processor_cls)
533541
self._task_id_to_processor = dict()
534542
self._cur_processor = None
535543
self._subtask_decref_events = dict()
@@ -559,7 +567,7 @@ async def add_task(
559567
task_preprocessor = task_preprocessor_cls(
560568
task, tiled_context=tiled_context, config=config
561569
)
562-
processor = TaskProcessor(
570+
processor = self._task_processor_cls(
563571
task,
564572
task_preprocessor,
565573
self._cluster_api,
@@ -572,6 +580,15 @@ async def add_task(
572580
# tell self to start running
573581
await self.ref().start.tell()
574582

583+
@classmethod
584+
def _get_task_processor_cls(cls, task_processor_cls):
585+
if task_processor_cls is not None:
586+
assert isinstance(task_processor_cls, str)
587+
module, name = task_processor_cls.rsplit(".", 1)
588+
return getattr(importlib.import_module(module), name)
589+
else:
590+
return TaskProcessor
591+
575592
def _get_unprocessed_task_processor(self):
576593
for processor in self._task_id_to_processor.values():
577594
if processor.result.status == TaskStatus.pending:

mars/services/task/supervisor/service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ class TaskSupervisorService(AbstractService):
3737
async def start(self):
3838
task_config = self._config.get("task", dict())
3939
options = task_config.get("default_config", dict())
40+
task_processor_cls = task_config.get("task_processor_cls")
4041
task_preprocessor_cls = task_config.get("task_preprocessor_cls")
4142
await mo.create_actor(
4243
TaskConfigurationActor,
4344
options,
45+
task_processor_cls=task_processor_cls,
4446
task_preprocessor_cls=task_preprocessor_cls,
4547
address=self._address,
4648
uid=TaskConfigurationActor.default_uid(),

mars/services/task/tests/test_service.py

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ...web import WebActor
3333
from ...meta import MetaAPI
3434
from .. import TaskAPI, TaskStatus, WebTaskAPI
35+
from ..supervisor.processor import TaskProcessor
3536
from ..errors import TaskNotExist
3637

3738

@@ -45,7 +46,7 @@ async def start_pool(is_worker: bool):
4546
subprocess_start_method="spawn",
4647
)
4748
else:
48-
kw = dict(n_process=0, subprocess_start_method="spawn")
49+
kw = dict(n_process=1, subprocess_start_method="spawn")
4950
pool = await mo.create_actor_pool("127.0.0.1", **kw)
5051
await pool.start()
5152
return pool
@@ -57,11 +58,9 @@ async def start_pool(is_worker: bool):
5758
await asyncio.gather(sv_pool.stop(), worker_pool.stop())
5859

5960

60-
@pytest.mark.parametrize(indirect=True)
61-
@pytest.fixture(params=[False, True])
62-
async def start_test_service(actor_pools, request):
63-
sv_pool, worker_pool = actor_pools
64-
61+
async def _start_services(
62+
supervisor_pool, worker_pool, request, task_processor_cls=None
63+
):
6564
config = {
6665
"services": [
6766
"cluster",
@@ -75,37 +74,49 @@ async def start_test_service(actor_pools, request):
7574
],
7675
"cluster": {
7776
"backend": "fixed",
78-
"lookup_address": sv_pool.external_address,
77+
"lookup_address": supervisor_pool.external_address,
7978
"resource": {"numa-0": Resource(num_cpus=2)},
8079
},
8180
"meta": {"store": "dict"},
8281
"scheduling": {},
8382
"task": {},
8483
}
84+
if task_processor_cls:
85+
config["task"]["task_processor_cls"] = task_processor_cls
8586
if request:
8687
config["services"].append("web")
87-
88-
await start_services(NodeRole.SUPERVISOR, config, address=sv_pool.external_address)
88+
await start_services(
89+
NodeRole.SUPERVISOR, config, address=supervisor_pool.external_address
90+
)
8991
await start_services(NodeRole.WORKER, config, address=worker_pool.external_address)
9092

9193
session_id = "test_session"
92-
session_api = await SessionAPI.create(sv_pool.external_address)
94+
session_api = await SessionAPI.create(supervisor_pool.external_address)
9395
await session_api.create_session(session_id)
9496

9597
if not request.param:
96-
task_api = await TaskAPI.create(session_id, sv_pool.external_address)
98+
task_api = await TaskAPI.create(session_id, supervisor_pool.external_address)
9799
else:
98100
web_actor = await mo.actor_ref(
99-
WebActor.default_uid(), address=sv_pool.external_address
101+
WebActor.default_uid(), address=supervisor_pool.external_address
100102
)
101103
web_address = await web_actor.get_web_address()
102104
task_api = WebTaskAPI(session_id, web_address)
103105

104106
assert await task_api.get_task_results() == []
105107

106108
# create mock meta and storage APIs
107-
_ = await MetaAPI.create(session_id, sv_pool.external_address)
109+
_ = await MetaAPI.create(session_id, supervisor_pool.external_address)
108110
storage_api = await MockStorageAPI.create(session_id, worker_pool.external_address)
111+
return task_api, storage_api, config
112+
113+
114+
@pytest.mark.parametrize(indirect=True)
115+
@pytest.fixture(params=[False, True])
116+
async def start_test_service(actor_pools, request):
117+
sv_pool, worker_pool = actor_pools
118+
119+
task_api, storage_api, config = await _start_services(sv_pool, worker_pool, request)
109120

110121
try:
111122
yield sv_pool.external_address, task_api, storage_api
@@ -115,6 +126,66 @@ async def start_test_service(actor_pools, request):
115126
await stop_services(NodeRole.SUPERVISOR, config, sv_pool.external_address)
116127

117128

129+
class MockTaskProcessor(TaskProcessor):
130+
@classmethod
131+
def _get_decref_stage_chunk_keys(cls, stage_processor):
132+
import time
133+
134+
# time.sleep to block async thread
135+
time.sleep(5)
136+
return super()._get_decref_stage_chunk_keys(stage_processor)
137+
138+
139+
@pytest.mark.parametrize(indirect=True)
140+
@pytest.fixture(params=[True])
141+
async def start_test_service_with_mock(actor_pools, request):
142+
sv_pool, worker_pool = actor_pools
143+
144+
task_api, storage_api, config = await _start_services(
145+
sv_pool,
146+
worker_pool,
147+
request,
148+
task_processor_cls="mars.services.task.tests.test_service.MockTaskProcessor",
149+
)
150+
151+
try:
152+
yield sv_pool.external_address, task_api, storage_api
153+
finally:
154+
await MockStorageAPI.cleanup(worker_pool.external_address)
155+
await stop_services(NodeRole.WORKER, config, worker_pool.external_address)
156+
await stop_services(NodeRole.SUPERVISOR, config, sv_pool.external_address)
157+
158+
159+
@pytest.mark.asyncio
160+
async def test_task_timeout_execution(start_test_service_with_mock):
161+
_sv_pool_address, task_api, storage_api = start_test_service_with_mock
162+
163+
def f1():
164+
return np.arange(5)
165+
166+
def f2():
167+
return np.arange(5, 10)
168+
169+
def f3(f1r, f2r):
170+
return np.concatenate([f1r, f2r]).sum()
171+
172+
r1 = mr.spawn(f1)
173+
r2 = mr.spawn(f2)
174+
r3 = mr.spawn(f3, args=(r1, r2))
175+
176+
graph = TileableGraph([r3.data])
177+
next(TileableGraphBuilder(graph).build())
178+
179+
task_id = await task_api.submit_tileable_graph(graph, fuse_enabled=False)
180+
assert await task_api.get_last_idle_time() is None
181+
assert isinstance(task_id, str)
182+
183+
await task_api.wait_task(task_id, timeout=2)
184+
task_result = await task_api.get_task_result(task_id)
185+
186+
assert task_result.status == TaskStatus.terminated
187+
188+
118189
@pytest.mark.asyncio
119190
async def test_task_execution(start_test_service):
120191
_sv_pool_address, task_api, storage_api = start_test_service

0 commit comments

Comments
 (0)