Skip to content

Commit 121c36c

Browse files
authored
Add ActivityContext (cadence-workflow#26)
1 parent a5a257c commit 121c36c

File tree

4 files changed

+225
-34
lines changed

4 files changed

+225
-34
lines changed

cadence/_internal/activity/_activity_executor.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import asyncio
21
import inspect
32
from concurrent.futures import ThreadPoolExecutor
43
from logging import getLogger
54
from traceback import format_exception
65
from typing import Any, Callable
6+
from google.protobuf.duration import to_timedelta
7+
from google.protobuf.timestamp import to_datetime
78

8-
from cadence._internal.type_utils import get_fn_parameters
9+
from cadence._internal.activity._context import _Context, _SyncContext
10+
from cadence.activity import ActivityInfo
911
from cadence.api.v1.common_pb2 import Failure
1012
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \
1113
RespondActivityTaskCompletedRequest
@@ -19,35 +21,31 @@ def __init__(self, client: Client, task_list: str, identity: str, max_workers: i
1921
self._data_converter = client.data_converter
2022
self._registry = registry
2123
self._identity = identity
24+
self._task_list = task_list
2225
self._thread_pool = ThreadPoolExecutor(max_workers=max_workers,
2326
thread_name_prefix=f'{task_list}-activity-')
2427

2528
async def execute(self, task: PollForActivityTaskResponse):
26-
activity_type = task.activity_type.name
27-
try:
28-
activity_fn = self._registry(activity_type)
29-
except KeyError as e:
30-
_logger.error("Activity type not found.", extra={'activity_type': activity_type})
31-
await self._report_failure(task, e)
32-
return
33-
34-
await self._execute_fn(activity_fn, task)
35-
36-
async def _execute_fn(self, activity_fn: Callable[[Any], Any], task: PollForActivityTaskResponse):
3729
try:
38-
type_hints = get_fn_parameters(activity_fn)
39-
params = await self._client.data_converter.from_data(task.input, type_hints)
40-
if inspect.iscoroutinefunction(activity_fn):
41-
result = await activity_fn(*params)
42-
else:
43-
result = await self._invoke_sync_activity(activity_fn, params)
30+
context = self._create_context(task)
31+
result = await context.execute(task.input)
4432
await self._report_success(task, result)
4533
except Exception as e:
4634
await self._report_failure(task, e)
4735

48-
async def _invoke_sync_activity(self, activity_fn: Callable[[Any], Any], params: list[Any]) -> Any:
49-
loop = asyncio.get_running_loop()
50-
return await loop.run_in_executor(self._thread_pool, activity_fn, *params)
36+
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
37+
activity_type = task.activity_type.name
38+
try:
39+
activity_fn = self._registry(activity_type)
40+
except KeyError:
41+
raise KeyError(f"Activity type not found: {activity_type}") from None
42+
43+
info = self._create_info(task)
44+
45+
if inspect.iscoroutinefunction(activity_fn):
46+
return _Context(self._client, info, activity_fn)
47+
else:
48+
return _SyncContext(self._client, info, activity_fn, self._thread_pool)
5149

5250
async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception):
5351
try:
@@ -71,6 +69,23 @@ async def _report_success(self, task: PollForActivityTaskResponse, result: Any):
7169
except Exception:
7270
_logger.exception('Exception reporting activity complete')
7371

72+
def _create_info(self, task: PollForActivityTaskResponse) -> ActivityInfo:
73+
return ActivityInfo(
74+
task_token=task.task_token,
75+
workflow_type=task.workflow_type.name,
76+
workflow_domain=task.workflow_domain,
77+
workflow_id=task.workflow_execution.workflow_id,
78+
workflow_run_id=task.workflow_execution.run_id,
79+
activity_id=task.activity_id,
80+
activity_type=task.activity_type.name,
81+
task_list=self._task_list,
82+
heartbeat_timeout=to_timedelta(task.heartbeat_timeout),
83+
scheduled_timestamp=to_datetime(task.scheduled_time),
84+
started_timestamp=to_datetime(task.started_time),
85+
start_to_close_timeout=to_timedelta(task.start_to_close_timeout),
86+
attempt=task.attempt,
87+
)
88+
7489

7590
def _to_failure(exception: Exception) -> Failure:
7691
stacktrace = "".join(format_exception(exception))
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import asyncio
2+
from concurrent.futures.thread import ThreadPoolExecutor
3+
from typing import Callable, Any
4+
5+
from cadence import Client
6+
from cadence._internal.type_utils import get_fn_parameters
7+
from cadence.activity import ActivityInfo, ActivityContext
8+
from cadence.api.v1.common_pb2 import Payload
9+
10+
11+
class _Context(ActivityContext):
12+
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]):
13+
self._client = client
14+
self._info = info
15+
self._activity_fn = activity_fn
16+
17+
async def execute(self, payload: Payload) -> Any:
18+
params = await self._to_params(payload)
19+
with self._activate():
20+
return await self._activity_fn(*params)
21+
22+
async def _to_params(self, payload: Payload) -> list[Any]:
23+
type_hints = get_fn_parameters(self._activity_fn)
24+
return await self._client.data_converter.from_data(payload, type_hints)
25+
26+
def client(self) -> Client:
27+
return self._client
28+
29+
def info(self) -> ActivityInfo:
30+
return self._info
31+
32+
class _SyncContext(_Context):
33+
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor):
34+
super().__init__(client, info, activity_fn)
35+
self._executor = executor
36+
37+
async def execute(self, payload: Payload) -> Any:
38+
params = await self._to_params(payload)
39+
loop = asyncio.get_running_loop()
40+
return await loop.run_in_executor(self._executor, self._run, params)
41+
42+
def _run(self, args: list[Any]) -> Any:
43+
with self._activate():
44+
return self._activity_fn(*args)
45+
46+
def client(self) -> Client:
47+
raise RuntimeError("client is only supported in async activities")
48+

cadence/activity.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from abc import ABC, abstractmethod
2+
from contextlib import contextmanager
3+
from contextvars import ContextVar
4+
from dataclasses import dataclass
5+
from datetime import timedelta, datetime
6+
from typing import Iterator
7+
8+
from cadence import Client
9+
10+
11+
@dataclass(frozen=True)
12+
class ActivityInfo:
13+
task_token: bytes
14+
workflow_type: str
15+
workflow_domain: str
16+
workflow_id: str
17+
workflow_run_id: str
18+
activity_id: str
19+
activity_type: str
20+
task_list: str
21+
heartbeat_timeout: timedelta
22+
scheduled_timestamp: datetime
23+
started_timestamp: datetime
24+
start_to_close_timeout: timedelta
25+
attempt: int
26+
27+
def client() -> Client:
28+
return ActivityContext.get().client()
29+
30+
def in_activity() -> bool:
31+
return ActivityContext.is_set()
32+
33+
def info() -> ActivityInfo:
34+
return ActivityContext.get().info()
35+
36+
37+
38+
class ActivityContext(ABC):
39+
_var: ContextVar['ActivityContext'] = ContextVar("activity")
40+
41+
@abstractmethod
42+
def info(self) -> ActivityInfo:
43+
...
44+
45+
@abstractmethod
46+
def client(self) -> Client:
47+
...
48+
49+
@contextmanager
50+
def _activate(self) -> Iterator[None]:
51+
token = ActivityContext._var.set(self)
52+
yield None
53+
ActivityContext._var.reset(token)
54+
55+
@staticmethod
56+
def is_set() -> bool:
57+
return ActivityContext._var.get(None) is not None
58+
59+
@staticmethod
60+
def get() -> 'ActivityContext':
61+
return ActivityContext._var.get()

tests/cadence/_internal/activity/test_activity_executor.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import asyncio
2+
from datetime import timedelta, datetime
23
from unittest.mock import Mock, AsyncMock, PropertyMock
34

45
import pytest
6+
from google.protobuf.timestamp_pb2 import Timestamp
7+
from google.protobuf.duration import from_timedelta
58

6-
from cadence import Client
9+
from cadence import activity, Client
710
from cadence._internal.activity import ActivityExecutor
8-
from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure
11+
from cadence.activity import ActivityInfo
12+
from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType
913
from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \
1014
RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest
1115
from cadence.data_converter import DefaultDataConverter
@@ -19,7 +23,6 @@ def client() -> Client:
1923
return client
2024

2125

22-
@pytest.mark.asyncio
2326
async def test_activity_async_success(client):
2427
worker_stub = client.worker_stub
2528
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
@@ -37,7 +40,6 @@ async def activity_fn():
3740
identity='identity',
3841
))
3942

40-
@pytest.mark.asyncio
4143
async def test_activity_async_failure(client):
4244
worker_stub = client.worker_stub
4345
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
@@ -64,7 +66,6 @@ async def activity_fn():
6466
identity='identity',
6567
)
6668

67-
@pytest.mark.asyncio
6869
async def test_activity_args(client):
6970
worker_stub = client.worker_stub
7071
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
@@ -82,8 +83,6 @@ async def activity_fn(first: str, second: str):
8283
identity='identity',
8384
))
8485

85-
86-
@pytest.mark.asyncio
8786
async def test_activity_sync_success(client):
8887
worker_stub = client.worker_stub
8988
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
@@ -105,7 +104,6 @@ def activity_fn():
105104
identity='identity',
106105
))
107106

108-
@pytest.mark.asyncio
109107
async def test_activity_sync_failure(client):
110108
worker_stub = client.worker_stub
111109
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
@@ -132,7 +130,6 @@ def activity_fn():
132130
identity='identity',
133131
)
134132

135-
@pytest.mark.asyncio
136133
async def test_activity_unknown(client):
137134
worker_stub = client.worker_stub
138135
worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse())
@@ -148,7 +145,7 @@ def registry(name: str):
148145

149146
call = worker_stub.RespondActivityTaskFailed.call_args[0][0]
150147

151-
assert 'unknown activity: any' in call.failure.details.decode()
148+
assert 'Activity type not found: any' in call.failure.details.decode()
152149
call.failure.details = bytes()
153150
assert call == RespondActivityTaskFailedRequest(
154151
task_token=b'task_token',
@@ -158,15 +155,85 @@ def registry(name: str):
158155
identity='identity',
159156
)
160157

158+
async def test_activity_context(client):
159+
worker_stub = client.worker_stub
160+
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
161+
162+
async def activity_fn():
163+
assert fake_info("activity_type") == activity.info()
164+
assert activity.in_activity()
165+
assert activity.client() is not None
166+
return "success"
167+
168+
executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn)
169+
170+
await executor.execute(fake_task("activity_type", ""))
171+
172+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest(
173+
task_token=b'task_token',
174+
result=Payload(data='"success"'.encode()),
175+
identity='identity',
176+
))
177+
178+
async def test_activity_context_sync(client):
179+
worker_stub = client.worker_stub
180+
worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse())
181+
182+
def activity_fn():
183+
assert fake_info("activity_type") == activity.info()
184+
assert activity.in_activity()
185+
with pytest.raises(RuntimeError):
186+
activity.client()
187+
return "success"
188+
189+
executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn)
190+
191+
await executor.execute(fake_task("activity_type", ""))
192+
193+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest(
194+
task_token=b'task_token',
195+
result=Payload(data='"success"'.encode()),
196+
identity='identity',
197+
))
198+
199+
200+
def fake_info(activity_type: str) -> ActivityInfo:
201+
return ActivityInfo(
202+
task_token=b'task_token',
203+
workflow_domain="workflow_domain",
204+
workflow_id="workflow_id",
205+
workflow_run_id="run_id",
206+
activity_id="activity_id",
207+
activity_type=activity_type,
208+
attempt=1,
209+
workflow_type="workflow_type",
210+
task_list="task_list",
211+
heartbeat_timeout=timedelta(seconds=1),
212+
scheduled_timestamp=datetime(2020, 1, 2 ,3),
213+
started_timestamp=datetime(2020, 1, 2 ,4),
214+
start_to_close_timeout=timedelta(seconds=2),
215+
)
216+
161217
def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse:
162218
return PollForActivityTaskResponse(
163219
task_token=b'task_token',
220+
workflow_domain="workflow_domain",
221+
workflow_type=WorkflowType(name="workflow_type"),
164222
workflow_execution=WorkflowExecution(
165223
workflow_id="workflow_id",
166224
run_id="run_id",
167225
),
168226
activity_id="activity_id",
169227
activity_type=ActivityType(name=activity_type),
170228
input=Payload(data=input_json.encode()),
171-
attempt=0,
172-
)
229+
attempt=1,
230+
heartbeat_timeout=from_timedelta(timedelta(seconds=1)),
231+
scheduled_time=from_datetime(datetime(2020, 1, 2, 3)),
232+
started_time=from_datetime(datetime(2020, 1, 2, 4)),
233+
start_to_close_timeout=from_timedelta(timedelta(seconds=2)),
234+
)
235+
236+
def from_datetime(time: datetime) -> Timestamp:
237+
t = Timestamp()
238+
t.FromDatetime(time)
239+
return t

0 commit comments

Comments
 (0)