Skip to content

Commit 5cff118

Browse files
authored
Add execution API to enable custimization of Mars Task Service (mars-project#2894)
1 parent dc93f88 commit 5cff118

File tree

12 files changed

+804
-592
lines changed

12 files changed

+804
-592
lines changed

mars/services/subtask/worker/tests/test_subtask.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ async def actor_pool():
7878
await mo.create_actor(
7979
TaskConfigurationActor,
8080
dict(),
81+
dict(),
8182
uid=TaskConfigurationActor.default_uid(),
8283
address=pool.external_address,
8384
)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .mars import *
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
from dataclasses import dataclass
17+
from typing import List, Dict, Any, Type
18+
19+
from ....core import ChunkGraph
20+
from ....typing import BandType
21+
from ...subtask import SubtaskGraph, SubtaskResult
22+
23+
24+
@dataclass
25+
class ExecutionChunkResult:
26+
key: str # The chunk key for fetching the result.
27+
meta: Dict # The chunk meta for iterative tiling.
28+
context: Any # The context info, e.g. ray.ObjectRef.
29+
30+
31+
class TaskExecutor(ABC):
32+
name = None
33+
34+
@classmethod
35+
@abstractmethod
36+
async def create(
37+
cls, config: Dict, *, session_id: str, address: str, task, **kwargs
38+
) -> "TaskExecutor":
39+
name = config.get("backend", "mars")
40+
backend_config = config.get(name, {})
41+
executor_cls = _name_to_task_executor_cls[name]
42+
if executor_cls.create.__func__ is TaskExecutor.create.__func__:
43+
raise NotImplementedError(
44+
f"The {executor_cls} should implement the abstract classmethod `create`."
45+
)
46+
return await executor_cls.create(
47+
backend_config, session_id=session_id, address=address, task=task, **kwargs
48+
)
49+
50+
async def __aenter__(self):
51+
"""Called when begin to execute the task."""
52+
53+
@abstractmethod
54+
async def execute_subtask_graph(
55+
self,
56+
stage_id: str,
57+
subtask_graph: SubtaskGraph,
58+
chunk_graph: ChunkGraph,
59+
context: Any = None,
60+
) -> List[ExecutionChunkResult]:
61+
"""Execute a subtask graph and returns result."""
62+
63+
async def __aexit__(self, exc_type, exc_val, exc_tb):
64+
"""Called when finish the task."""
65+
66+
@abstractmethod
67+
async def get_available_band_slots(self) -> Dict[BandType, int]:
68+
"""Get available band slots."""
69+
70+
@abstractmethod
71+
async def get_progress(self) -> float:
72+
"""Get the execution progress."""
73+
74+
@abstractmethod
75+
async def cancel(self):
76+
"""Cancel execution."""
77+
78+
# The following APIs are for compatible with mars backend, they
79+
# will be removed as soon as possible.
80+
async def set_subtask_result(self, subtask_result: SubtaskResult):
81+
"""Set the subtask result."""
82+
83+
def get_stage_processors(self):
84+
"""Get stage processors."""
85+
86+
87+
_name_to_task_executor_cls: Dict[str, Type[TaskExecutor]] = {}
88+
89+
90+
def register_executor_cls(executor_cls: Type[TaskExecutor]):
91+
_name_to_task_executor_cls[executor_cls.name] = executor_cls
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .executor import MarsTaskExecutor

0 commit comments

Comments
 (0)