3636from ....cluster .api import ClusterAPI
3737from ....lifecycle .api import LifecycleAPI
3838from ....meta .api import MetaAPI
39- from ....subtask import SubtaskGraph
39+ from ....subtask import Subtask , SubtaskGraph
4040from ....subtask .utils import iter_input_data_keys , iter_output_data
4141from ..api import TaskExecutor , ExecutionChunkResult , register_executor_cls
42+ from .context import RayExecutionContext
4243
4344ray = lazy_import ("ray" )
4445logger = logging .getLogger (__name__ )
@@ -55,7 +56,7 @@ def execute_subtask(
5556 ensure_coverage ()
5657 subtask_chunk_graph = deserialize (* subtask_chunk_graph )
5758 # inputs = [i[1] for i in inputs]
58- context = dict (zip (input_keys , inputs ))
59+ context = RayExecutionContext (zip (input_keys , inputs ))
5960 # optimize chunk graph.
6061 subtask_chunk_graph = optimize (subtask_chunk_graph )
6162 # from data_key to results
@@ -117,7 +118,7 @@ async def create(
117118 session_id : str ,
118119 address : str ,
119120 task ,
120- tile_context ,
121+ tile_context : TileContext ,
121122 ** kwargs
122123 ) -> "TaskExecutor" :
123124 ray_executor = ray .remote (execute_subtask )
@@ -159,14 +160,10 @@ async def execute_subtask_graph(
159160 result_keys = {chunk .key for chunk in chunk_graph .result_chunks }
160161 for subtask in subtask_graph .topological_iter ():
161162 subtask_chunk_graph = subtask .chunk_graph
162- chunk_key_to_data_keys = get_chunk_key_to_data_keys (subtask_chunk_graph )
163- key_to_input = {
164- key : context [key ]
165- for key , _ in iter_input_data_keys (
166- subtask , subtask_chunk_graph , chunk_key_to_data_keys
167- )
168- }
169- output_keys = self ._get_output_keys (subtask_chunk_graph )
163+ key_to_input = await self ._load_subtask_inputs (
164+ stage_id , subtask , subtask_chunk_graph , context
165+ )
166+ output_keys = self ._get_subtask_output_keys (subtask_chunk_graph )
170167 output_meta_keys = result_keys & output_keys
171168 output_count = len (output_keys ) + bool (output_meta_keys )
172169 output_object_refs = self ._ray_executor .options (
@@ -250,8 +247,44 @@ async def get_progress(self) -> float:
    async def cancel(self):
        """Cancel execution.

        NOTE(review): the body is currently empty — this executor does not
        actually cancel any submitted Ray tasks; confirm whether a no-op is
        intended here or cancellation support is still pending.
        """
252249
    async def _load_subtask_inputs(
        self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
    ):
        """
        Load a dict of input key to object ref of subtask from context.

        It updates the context if the input object refs are fetched from
        the meta service.

        Parameters
        ----------
        stage_id : str
            Stage identifier, used only for logging here.
        subtask : Subtask
            The subtask whose input data keys are resolved.
        chunk_graph : ChunkGraph
            Chunk graph of the subtask; used to map chunk keys to data keys.
        context : Dict
            Mapping of data key -> object ref for results already produced.
            Mutated in place: keys fetched from the meta service are added.

        Returns
        -------
        dict
            Mapping of input data key -> object ref for this subtask.
        """
        key_to_input = {}
        # Keys missing from the context, mapped to delayed meta-API calls so
        # they can be fetched in a single batch below.
        key_to_get_meta = {}
        chunk_key_to_data_keys = get_chunk_key_to_data_keys(chunk_graph)
        for key, _ in iter_input_data_keys(
            subtask, chunk_graph, chunk_key_to_data_keys
        ):
            if key in context:
                key_to_input[key] = context[key]
            else:
                # Not in the context; defer a meta lookup for its object refs.
                key_to_get_meta[key] = self._meta_api.get_chunk_meta.delay(
                    key, fields=["object_refs"]
                )
        if key_to_get_meta:
            logger.info(
                "Fetch %s metas and update context of stage %s.",
                len(key_to_get_meta),
                stage_id,
            )
            # Single batched round trip; results align positionally with the
            # delayed calls, so keys() and the result list stay in lockstep.
            meta_list = await self._meta_api.get_chunk_meta.batch(
                *key_to_get_meta.values()
            )
            for key, meta in zip(key_to_get_meta.keys(), meta_list):
                # NOTE(review): takes the first object ref only — presumably
                # each chunk has a single primary ref; confirm for multi-ref
                # (e.g. replicated) chunks.
                object_ref = meta["object_refs"][0]
                key_to_input[key] = object_ref
                # Cache in the shared context so later subtasks in this stage
                # skip the meta-service lookup for the same key.
                context[key] = object_ref
        return key_to_input
285+
253286 @staticmethod
254- def _get_output_keys (chunk_graph ):
287+ def _get_subtask_output_keys (chunk_graph : ChunkGraph ):
255288 output_keys = {}
256289 for chunk in chunk_graph .results :
257290 if isinstance (chunk .op , VirtualOperand ):
0 commit comments