forked from The-Pocket/PocketFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_async_flow.py
More file actions
228 lines (185 loc) · 8.52 KB
/
test_async_flow.py
File metadata and controls
228 lines (185 loc) · 8.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import unittest
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, AsyncNode, AsyncFlow
class AsyncNumberNode(AsyncNode):
    """
    Simple async node that sets shared_storage['current'] to a fixed number.

    The write happens in prep_async(); post_async() performs a short simulated
    async wait and then returns the "number_set" action, which a flow can use
    to select the successor node.

    Args:
        number: value written to shared_storage['current'] when the node runs.
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    async def prep_async(self, shared_storage):
        # Plain, non-awaiting work is fine inside an async hook. The value
        # returned here is not the routing action; that comes from post_async().
        shared_storage['current'] = self.number
        return "set_number"

    async def post_async(self, shared_storage, prep_result, proc_result):
        # Simulate asynchronous I/O before reporting completion.
        await asyncio.sleep(0.01)
        # Action string matched against this node's conditional transitions.
        return "number_set"
class AsyncIncrementNode(AsyncNode):
    """
    Async node that adds one to shared_storage['current'].

    A missing 'current' entry is treated as 0, so the first run stores 1.
    post_async() reports completion with the "done" action.
    """

    async def prep_async(self, shared_storage):
        previous = shared_storage.get('current', 0)
        shared_storage['current'] = previous + 1
        return "incremented"

    async def post_async(self, shared_storage, prep_result, proc_result):
        # Stand-in for real asynchronous I/O.
        await asyncio.sleep(0.01)
        return "done"
class AsyncSignalNode(AsyncNode):
    """
    Async node whose post_async() returns a configurable action string.

    The emitted signal is also recorded in
    shared_storage['last_async_signal_emitted'], so tests can verify which
    signal was produced without relying on console output.

    Args:
        signal: action string returned from post_async().
    """

    def __init__(self, signal="default_async_signal"):
        super().__init__()
        self.signal = signal

    async def prep_async(self, shared_storage):
        # Nothing to prepare; only simulate async work (returns None).
        await asyncio.sleep(0.01)

    async def post_async(self, shared_storage, prep_result, exec_result):
        # Record the signal in the shared store for later assertions.
        shared_storage['last_async_signal_emitted'] = self.signal
        await asyncio.sleep(0.01)  # Simulate async work
        # The returned string is the action the enclosing flow routes on.
        return self.signal
class AsyncPathNode(AsyncNode):
    """
    Async marker node that records which branch of an outer flow executed.

    Args:
        path_id: identifier stored in shared_storage['async_path_taken'].
    """

    def __init__(self, path_id):
        super().__init__()
        self.path_id = path_id

    async def prep_async(self, shared_storage):
        # Pretend to do some async work, then mark this path as taken.
        await asyncio.sleep(0.01)
        shared_storage['async_path_taken'] = self.path_id

    async def post_async(self, shared_storage, prep_result, exec_result):
        await asyncio.sleep(0.01)
        # No explicit action string: the node reports None.
        return None
class TestAsyncNode(unittest.TestCase):
    """
    Unit tests that drive individual AsyncNode subclasses through
    run_async(), without wrapping them in an AsyncFlow.
    """

    def test_async_number_node_direct_call(self):
        """
        AsyncNumberNode is built for flows, but run_async() can drive it
        directly so both its returned action and its store write are checked.
        """
        async def scenario():
            store = {}
            action = await AsyncNumberNode(42).run_async(store)
            return store, action

        store, action = asyncio.run(scenario())
        self.assertEqual(store['current'], 42)
        self.assertEqual(action, "number_set")

    def test_async_increment_node_direct_call(self):
        """AsyncIncrementNode adds one to an existing 'current' value."""
        async def scenario():
            store = {'current': 10}
            action = await AsyncIncrementNode().run_async(store)
            return store, action

        store, action = asyncio.run(scenario())
        self.assertEqual(store['current'], 11)
        self.assertEqual(action, "done")
class TestAsyncFlow(unittest.TestCase):
    """
    Test how AsyncFlow orchestrates multiple async nodes.
    """

    def test_simple_async_flow(self):
        """
        Flow:
          1) AsyncNumberNode(5)   -> sets 'current' to 5, returns "number_set"
          2) AsyncIncrementNode() -> increments 'current' to 6
        """
        # Create our nodes
        start = AsyncNumberNode(5)
        inc_node = AsyncIncrementNode()

        # Route the "number_set" action from start to inc_node
        start - "number_set" >> inc_node

        # Create an AsyncFlow with start
        flow = AsyncFlow(start)

        # AsyncFlow is driven through run_async(); asyncio.run() supplies
        # the event loop for this synchronous test method.
        shared_storage = {}
        asyncio.run(flow.run_async(shared_storage))
        self.assertEqual(shared_storage['current'], 6)

    def test_async_flow_branching(self):
        """
        An async node chooses its successor by returning "positive_branch" or
        "negative_branch" from post_async(), based on shared_storage['value'].

        Both branches are exercised, plus the case where 'value' is missing
        (it is defaulted to 0, which counts as positive).
        """

        class BranchingAsyncNode(AsyncNode):
            async def prep_async(self, shared_storage):
                # Default a missing value to 0 in the store the flow passes
                # in, so post_async() can always read it.
                shared_storage.setdefault("value", 0)
                return shared_storage["value"]

            async def post_async(self, shared_storage, prep_result, proc_result):
                await asyncio.sleep(0.01)
                if shared_storage["value"] >= 0:
                    return "positive_branch"
                else:
                    return "negative_branch"

        class PositiveNode(Node):
            def post(self, shared, prep_res, exec_res):
                # Write to the shared store handed over by the flow rather
                # than to a variable captured from the test method.
                shared["path"] = "positive"

        class NegativeNode(Node):
            def post(self, shared, prep_res, exec_res):
                shared["path"] = "negative"

        def run_with(shared_storage):
            # Build a fresh graph per case so the runs are independent.
            start = BranchingAsyncNode()
            positive_node = PositiveNode()
            negative_node = NegativeNode()

            # Condition-based chaining
            start - "positive_branch" >> positive_node
            start - "negative_branch" >> negative_node

            asyncio.run(AsyncFlow(start).run_async(shared_storage))
            return shared_storage

        cases = [
            ({"value": 10}, "positive"),
            ({"value": -5}, "negative"),
            ({}, "positive"),  # missing 'value' defaults to 0
        ]
        for initial, expected_path in cases:
            with self.subTest(initial=initial):
                shared_storage = run_with(dict(initial))
                self.assertEqual(shared_storage["path"], expected_path,
                                 f"Should have taken the {expected_path} branch")

    def test_async_composition_with_action_propagation(self):
        """
        Test AsyncFlow branches based on action from nested AsyncFlow's last node.
        """
        async def run_test():
            shared_storage = {}

            # 1. Define an inner async flow ending with AsyncSignalNode.
            # AsyncNumberNode's post_async returns "number_set", which routes
            # execution to inner_end_node.
            inner_start_node = AsyncNumberNode(200)
            inner_end_node = AsyncSignalNode("async_inner_done")  # post_async -> "async_inner_done"
            inner_start_node - "number_set" >> inner_end_node

            # Inner flow executes start -> end; its result is "async_inner_done"
            inner_flow = AsyncFlow(start=inner_start_node)

            # 2. Define target async nodes for the outer flow branches
            path_a_node = AsyncPathNode("AsyncA")  # post_async -> None
            path_b_node = AsyncPathNode("AsyncB")  # post_async -> None

            # 3. Define the outer async flow starting with the inner async flow
            outer_flow = AsyncFlow(start=inner_flow)

            # 4. Define branches FROM the inner_flow object based on its returned action
            inner_flow - "async_inner_done" >> path_b_node  # This path should be taken
            inner_flow - "other_action" >> path_a_node      # This path should NOT be taken

            # 5. Run the outer async flow and capture the last action
            # Execution: inner_start -> inner_end -> path_b
            last_action_outer = await outer_flow.run_async(shared_storage)

            # 6. Return results for assertion
            return shared_storage, last_action_outer

        # Run the async test function
        shared_storage, last_action_outer = asyncio.run(run_test())

        # 7. Assert the results
        # Check state after inner flow execution
        self.assertEqual(shared_storage.get('current'), 200)  # From AsyncNumberNode
        self.assertEqual(shared_storage.get('last_async_signal_emitted'), "async_inner_done")
        # Check that the correct outer path was taken
        self.assertEqual(shared_storage.get('async_path_taken'), "AsyncB")
        # Check the action returned by the outer flow. The last node executed was
        # path_b_node, which returns None from its post_async method.
        self.assertIsNone(last_action_outer)
if __name__ == '__main__':
    # Allow running this file directly; module='__main__' is unittest's
    # default and is spelled out only for clarity.
    unittest.main(module='__main__')