3232from ...web import WebActor
3333from ...meta import MetaAPI
3434from .. import TaskAPI , TaskStatus , WebTaskAPI
35+ from ..supervisor .processor import TaskProcessor
3536from ..errors import TaskNotExist
3637
3738
@@ -45,7 +46,7 @@ async def start_pool(is_worker: bool):
4546 subprocess_start_method = "spawn" ,
4647 )
4748 else :
48- kw = dict (n_process = 0 , subprocess_start_method = "spawn" )
49+ kw = dict (n_process = 1 , subprocess_start_method = "spawn" )
4950 pool = await mo .create_actor_pool ("127.0.0.1" , ** kw )
5051 await pool .start ()
5152 return pool
@@ -57,11 +58,9 @@ async def start_pool(is_worker: bool):
5758 await asyncio .gather (sv_pool .stop (), worker_pool .stop ())
5859
5960
60- @pytest .mark .parametrize (indirect = True )
61- @pytest .fixture (params = [False , True ])
62- async def start_test_service (actor_pools , request ):
63- sv_pool , worker_pool = actor_pools
64-
61+ async def _start_services (
62+ supervisor_pool , worker_pool , request , task_processor_cls = None
63+ ):
6564 config = {
6665 "services" : [
6766 "cluster" ,
@@ -75,37 +74,49 @@ async def start_test_service(actor_pools, request):
7574 ],
7675 "cluster" : {
7776 "backend" : "fixed" ,
78- "lookup_address" : sv_pool .external_address ,
77+ "lookup_address" : supervisor_pool .external_address ,
7978 "resource" : {"numa-0" : Resource (num_cpus = 2 )},
8079 },
8180 "meta" : {"store" : "dict" },
8281 "scheduling" : {},
8382 "task" : {},
8483 }
84+ if task_processor_cls :
85+ config ["task" ]["task_processor_cls" ] = task_processor_cls
8586 if request :
8687 config ["services" ].append ("web" )
87-
88- await start_services (NodeRole .SUPERVISOR , config , address = sv_pool .external_address )
88+ await start_services (
89+ NodeRole .SUPERVISOR , config , address = supervisor_pool .external_address
90+ )
8991 await start_services (NodeRole .WORKER , config , address = worker_pool .external_address )
9092
9193 session_id = "test_session"
92- session_api = await SessionAPI .create (sv_pool .external_address )
94+ session_api = await SessionAPI .create (supervisor_pool .external_address )
9395 await session_api .create_session (session_id )
9496
9597 if not request .param :
96- task_api = await TaskAPI .create (session_id , sv_pool .external_address )
98+ task_api = await TaskAPI .create (session_id , supervisor_pool .external_address )
9799 else :
98100 web_actor = await mo .actor_ref (
99- WebActor .default_uid (), address = sv_pool .external_address
101+ WebActor .default_uid (), address = supervisor_pool .external_address
100102 )
101103 web_address = await web_actor .get_web_address ()
102104 task_api = WebTaskAPI (session_id , web_address )
103105
104106 assert await task_api .get_task_results () == []
105107
106108 # create mock meta and storage APIs
107- _ = await MetaAPI .create (session_id , sv_pool .external_address )
109+ _ = await MetaAPI .create (session_id , supervisor_pool .external_address )
108110 storage_api = await MockStorageAPI .create (session_id , worker_pool .external_address )
111+ return task_api , storage_api , config
112+
113+
114+ @pytest .mark .parametrize (indirect = True )
115+ @pytest .fixture (params = [False , True ])
116+ async def start_test_service (actor_pools , request ):
117+ sv_pool , worker_pool = actor_pools
118+
119+ task_api , storage_api , config = await _start_services (sv_pool , worker_pool , request )
109120
110121 try :
111122 yield sv_pool .external_address , task_api , storage_api
@@ -115,6 +126,66 @@ async def start_test_service(actor_pools, request):
115126 await stop_services (NodeRole .SUPERVISOR , config , sv_pool .external_address )
116127
117128
129+ class MockTaskProcessor (TaskProcessor ):
130+ @classmethod
131+ def _get_decref_stage_chunk_keys (cls , stage_processor ):
132+ import time
133+
134+ # time.sleep to block async thread
135+ time .sleep (5 )
136+ return super ()._get_decref_stage_chunk_keys (stage_processor )
137+
138+
139+ @pytest .mark .parametrize (indirect = True )
140+ @pytest .fixture (params = [True ])
141+ async def start_test_service_with_mock (actor_pools , request ):
142+ sv_pool , worker_pool = actor_pools
143+
144+ task_api , storage_api , config = await _start_services (
145+ sv_pool ,
146+ worker_pool ,
147+ request ,
148+ task_processor_cls = "mars.services.task.tests.test_service.MockTaskProcessor" ,
149+ )
150+
151+ try :
152+ yield sv_pool .external_address , task_api , storage_api
153+ finally :
154+ await MockStorageAPI .cleanup (worker_pool .external_address )
155+ await stop_services (NodeRole .WORKER , config , worker_pool .external_address )
156+ await stop_services (NodeRole .SUPERVISOR , config , sv_pool .external_address )
157+
158+
159+ @pytest .mark .asyncio
160+ async def test_task_timeout_execution (start_test_service_with_mock ):
161+ _sv_pool_address , task_api , storage_api = start_test_service_with_mock
162+
163+ def f1 ():
164+ return np .arange (5 )
165+
166+ def f2 ():
167+ return np .arange (5 , 10 )
168+
169+ def f3 (f1r , f2r ):
170+ return np .concatenate ([f1r , f2r ]).sum ()
171+
172+ r1 = mr .spawn (f1 )
173+ r2 = mr .spawn (f2 )
174+ r3 = mr .spawn (f3 , args = (r1 , r2 ))
175+
176+ graph = TileableGraph ([r3 .data ])
177+ next (TileableGraphBuilder (graph ).build ())
178+
179+ task_id = await task_api .submit_tileable_graph (graph , fuse_enabled = False )
180+ assert await task_api .get_last_idle_time () is None
181+ assert isinstance (task_id , str )
182+
183+ await task_api .wait_task (task_id , timeout = 2 )
184+ task_result = await task_api .get_task_result (task_id )
185+
186+ assert task_result .status == TaskStatus .terminated
187+
188+
118189@pytest .mark .asyncio
119190async def test_task_execution (start_test_service ):
120191 _sv_pool_address , task_api , storage_api = start_test_service
0 commit comments