Skip to content

Commit 1213216

Browse files
authored
Add initial ActivityExecutor (cadence-workflow#20)
1 parent 1110e1c commit 1213216

File tree

13 files changed

+397
-21
lines changed

13 files changed

+397
-21
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
3+
from ._activity_executor import (
4+
ActivityExecutor
5+
)
6+
7+
__all__ = [
8+
"ActivityExecutor",
9+
]
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import asyncio
2+
import inspect
3+
from concurrent.futures import ThreadPoolExecutor
4+
from logging import getLogger
5+
from traceback import format_exception
6+
from typing import Any, Callable
7+
8+
from cadence._internal.type_utils import get_fn_parameters
9+
from cadence.api.v1.common_pb2 import Failure
10+
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \
11+
RespondActivityTaskCompletedRequest
12+
from cadence.client import Client
13+
14+
_logger = getLogger(__name__)
15+
16+
class ActivityExecutor:
17+
def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], Callable]):
18+
self._client = client
19+
self._data_converter = client.data_converter
20+
self._registry = registry
21+
self._identity = identity
22+
self._thread_pool = ThreadPoolExecutor(max_workers=max_workers,
23+
thread_name_prefix=f'{task_list}-activity-')
24+
25+
async def execute(self, task: PollForActivityTaskResponse):
26+
activity_type = task.activity_type.name
27+
try:
28+
activity_fn = self._registry(activity_type)
29+
except KeyError as e:
30+
_logger.error("Activity type not found.", extra={'activity_type': activity_type})
31+
await self._report_failure(task, e)
32+
return
33+
34+
await self._execute_fn(activity_fn, task)
35+
36+
async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse):
37+
try:
38+
type_hints = get_fn_parameters(activity_fn)
39+
params = await self._client.data_converter.from_data(task.input, type_hints)
40+
if inspect.iscoroutinefunction(activity_fn):
41+
result = await activity_fn(*params)
42+
else:
43+
result = await self._invoke_sync_activity(activity_fn, params)
44+
await self._report_success(task, result)
45+
except Exception as e:
46+
await self._report_failure(task, e)
47+
48+
async def _invoke_sync_activity(self, activity_fn: Callable[[Any], Any], params: list[Any]) -> Any:
49+
loop = asyncio.get_running_loop()
50+
return await loop.run_in_executor(self._thread_pool, activity_fn, *params)
51+
52+
async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception):
53+
try:
54+
await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
55+
task_token=task.task_token,
56+
failure=_to_failure(error),
57+
identity=self._identity,
58+
))
59+
except Exception:
60+
_logger.exception('Exception reporting activity failure')
61+
62+
async def _report_success(self, task: PollForActivityTaskResponse, result: Any):
63+
as_payload = await self._data_converter.to_data(result)
64+
65+
try:
66+
await self._client.worker_stub.RespondActivityTaskCompleted(RespondActivityTaskCompletedRequest(
67+
task_token=task.task_token,
68+
result=as_payload,
69+
identity=self._identity,
70+
))
71+
except Exception:
72+
_logger.exception('Exception reporting activity complete')
73+
74+
75+
def _to_failure(exception: Exception) -> Failure:
76+
stacktrace = "".join(format_exception(exception))
77+
78+
return Failure(
79+
reason=type(exception).__name__,
80+
details=stacktrace.encode(),
81+
)

cadence/_internal/type_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from inspect import signature, Parameter
2+
from typing import Callable, List, Type, get_type_hints
3+
4+
def get_fn_parameters(fn: Callable) -> List[Type | None]:
5+
args = signature(fn).parameters
6+
hints = get_type_hints(fn)
7+
result = []
8+
for name, param in args.items():
9+
if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
10+
type_hint = hints.get(name, None)
11+
result.append(type_hint)
12+
13+
return result
14+
15+
def validate_fn_parameters(fn: Callable) -> None:
16+
args = signature(fn).parameters
17+
for name, param in args.items():
18+
if param.kind not in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
19+
raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid")

cadence/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
66
from grpc.aio import Channel
77

8+
from cadence.data_converter import DataConverter
9+
810

911
class ClientOptions(TypedDict, total=False):
1012
domain: str
1113
identity: str
14+
data_converter: DataConverter
1215

1316
class Client:
1417
def __init__(self, channel: Channel, options: ClientOptions) -> None:
@@ -17,6 +20,9 @@ def __init__(self, channel: Channel, options: ClientOptions) -> None:
1720
self._options = options
1821
self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}"
1922

23+
@property
24+
def data_converter(self) -> DataConverter:
25+
return self._options["data_converter"]
2026

2127
@property
2228
def domain(self) -> str:

cadence/data_converter.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class DataConverter(Protocol):
1010

1111
@abstractmethod
12-
async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]:
12+
async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]:
1313
raise NotImplementedError()
1414

1515
@abstractmethod
@@ -23,7 +23,10 @@ def __init__(self) -> None:
2323
self._fallback_decoder = JSONDecoder(strict=False)
2424

2525

26-
async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]:
26+
async def from_data(self, payload: Payload, type_hints: List[Type | None]) -> List[Any]:
27+
if not payload.data:
28+
return DefaultDataConverter._convert_into([], type_hints)
29+
2730
if len(type_hints) > 1:
2831
payload_str = payload.data.decode()
2932
# Handle payloads from the Go client, which are a series of json objects rather than a json array
@@ -37,7 +40,7 @@ async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]
3740
return DefaultDataConverter._convert_into([as_value], type_hints)
3841

3942

40-
def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> List[Any]:
43+
def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type | None]) -> List[Any]:
4144
results: List[Any] = []
4245
start, end = 0, len(payload)
4346
while start < end and len(results) < len(type_hints):
@@ -49,10 +52,12 @@ def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) ->
4952
return DefaultDataConverter._convert_into(results, type_hints)
5053

5154
@staticmethod
52-
def _convert_into(values: List[Any], type_hints: List[Type]) -> List[Any]:
55+
def _convert_into(values: List[Any], type_hints: List[Type | None]) -> List[Any]:
5356
results: List[Any] = []
5457
for i, type_hint in enumerate(type_hints):
55-
if i < len(values):
58+
if not type_hint:
59+
value = values[i]
60+
elif i < len(values):
5661
value = convert(values[i], type_hint)
5762
else:
5863
value = DefaultDataConverter._get_default(type_hint)

cadence/sample/client_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from cadence.client import Client, ClientOptions
66
from cadence._internal.rpc.metadata import MetadataInterceptor
7-
from cadence.worker import Worker
7+
from cadence.worker import Worker, Registry
88

99

1010
async def main():
@@ -15,7 +15,7 @@ async def main():
1515
metadata["rpc-caller"] = "nate"
1616
async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel:
1717
client = Client(channel, ClientOptions(domain="foo"))
18-
worker = Worker(client, "task_list")
18+
worker = Worker(client, "task_list", Registry())
1919
await worker.run()
2020

2121
if __name__ == '__main__':

cadence/worker/_activity.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
import asyncio
22
from typing import Optional
33

4-
from cadence.api.v1.common_pb2 import Failure
5-
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest, \
6-
RespondActivityTaskFailedRequest
4+
from cadence._internal.activity import ActivityExecutor
5+
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest
76
from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind
87
from cadence.client import Client
8+
from cadence.worker._registry import Registry
99
from cadence.worker._types import WorkerOptions, _LONG_POLL_TIMEOUT
1010
from cadence.worker._poller import Poller
1111

1212

1313
class ActivityWorker:
14-
def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
14+
def __init__(self, client: Client, task_list: str, registry: Registry, options: WorkerOptions) -> None:
1515
self._client = client
1616
self._task_list = task_list
1717
self._identity = options["identity"]
18-
permits = asyncio.Semaphore(options["max_concurrent_activity_execution_size"])
18+
max_concurrent = options["max_concurrent_activity_execution_size"]
19+
permits = asyncio.Semaphore(max_concurrent)
20+
self._executor = ActivityExecutor(self._client, self._task_list, options["identity"], max_concurrent, registry.get_activity)
1921
self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute)
2022
# TODO: Local dispatch, local activities, actually running activities, etc
2123

@@ -35,9 +37,5 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]:
3537
return None
3638

3739
async def _execute(self, task: PollForActivityTaskResponse) -> None:
38-
await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
39-
task_token=task.task_token,
40-
identity=self._identity,
41-
failure=Failure(reason='error', details=b'not implemented'),
42-
))
40+
await self._executor.execute(task)
4341

cadence/worker/_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import logging
1010
from typing import Callable, Dict, Optional, Unpack, TypedDict
11+
from cadence._internal.type_utils import validate_fn_parameters
1112

1213

1314
logger = logging.getLogger(__name__)
@@ -107,6 +108,7 @@ def activity(
107108
options = RegisterActivityOptions(**kwargs)
108109

109110
def decorator(f: Callable) -> Callable:
111+
validate_fn_parameters(f)
110112
activity_name = options.get('name') or f.__name__
111113

112114
if activity_name in self._activities:

cadence/worker/_worker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,22 @@
33
from typing import Unpack, cast
44

55
from cadence.client import Client
6+
from cadence.worker._registry import Registry
67
from cadence.worker._activity import ActivityWorker
78
from cadence.worker._decision import DecisionWorker
89
from cadence.worker._types import WorkerOptions, _DEFAULT_WORKER_OPTIONS
910

1011

1112
class Worker:
1213

13-
def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]) -> None:
14+
def __init__(self, client: Client, task_list: str, registry: Registry, **kwargs: Unpack[WorkerOptions]) -> None:
1415
self._client = client
1516
self._task_list = task_list
1617

1718
options = WorkerOptions(**kwargs)
1819
_validate_and_copy_defaults(client, task_list, options)
1920
self._options = options
20-
self._activity_worker = ActivityWorker(client, task_list, options)
21+
self._activity_worker = ActivityWorker(client, task_list, registry, options)
2122
self._decision_worker = DecisionWorker(client, task_list, options)
2223

2324

0 commit comments

Comments
 (0)