forked from The-Pocket/PocketFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_flow_basic.py
More file actions
224 lines (190 loc) · 8.8 KB
/
test_flow_basic.py
File metadata and controls
224 lines (190 loc) · 8.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# tests/test_flow_basic.py
import unittest
import sys
from pathlib import Path
import warnings
sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, Flow
# --- Node Definitions ---
# Nodes intended for default transitions (>>) should NOT return a specific
# action string from post. Let it return None by default.
# Nodes intended for conditional transitions (-) MUST return the action string.
class NumberNode(Node):
    """Test node that stores a fixed number under ``shared_storage['current']``.

    ``post`` is not overridden, so no explicit action string is returned by
    this class; it is meant to be wired with the default transition (``>>``).
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # Replace any previous value: this node (re)starts the running total.
        seed_value = self.number
        shared_storage['current'] = seed_value
class AddNode(Node):
    """Test node that adds a fixed number to ``shared_storage['current']``.

    ``post`` is not overridden, so no explicit action string is returned by
    this class; it is meant to be wired with the default transition (``>>``).
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # Requires 'current' to exist already (set by an earlier node).
        increment = self.number
        shared_storage['current'] += increment
class MultiplyNode(Node):
    """Test node that multiplies ``shared_storage['current']`` by a fixed number.

    ``post`` is not overridden, so no explicit action string is returned by
    this class; it is meant to be wired with the default transition (``>>``).
    """

    def __init__(self, number):
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        # Requires 'current' to exist already (set by an earlier node).
        factor = self.number
        shared_storage['current'] *= factor
class CheckPositiveNode(Node):
    """Branching test node: reports the sign of ``shared_storage['current']``.

    Unlike the arithmetic nodes, ``post`` here returns an explicit action
    string so the node can be wired with conditional transitions
    (``node - "action" >> target``).
    """

    def prep(self, shared_storage):
        # Nothing to prepare; the decision is made entirely in post().
        return None

    def post(self, shared_storage, prep_result, proc_result):
        # Zero is classified as 'positive' (the comparison is >=, not >).
        return 'positive' if shared_storage['current'] >= 0 else 'negative'
class NoOpNode(Node):
    """Placeholder test node: adds nothing on top of ``Node``.

    Used as a transition target where the successor's behavior is irrelevant.
    """
class EndSignalNode(Node):
    """Test node whose ``post`` returns a configurable signal value.

    Intended to sit at the end of a flow so tests can check which value the
    flow run hands back.
    """

    def __init__(self, signal="finished"):
        super().__init__()
        self.signal = signal

    def post(self, shared_storage, prep_result, exec_result):
        # The arguments are ignored; the configured signal is always returned.
        final_signal = self.signal
        return final_signal
# --- Test Class ---
class TestFlowBasic(unittest.TestCase):
    """Basic ``Flow`` tests: start(), chaining, branching, cycles and the
    warning emitted when a flow ends because no matching successor exists.
    """

    # ------------------------------------------------------------------
    # Helpers (names do not start with "test", so unittest ignores them)
    # ------------------------------------------------------------------

    def _run_sign_branch(self, start_value):
        """Run seed -> check -> (+10 on 'positive' | -20 on 'negative').

        Returns a ``(shared, last_action)`` tuple for the caller to assert on.
        """
        shared = {}
        seed = NumberNode(start_value)      # post -> None
        sign_check = CheckPositiveNode()    # post -> 'positive' or 'negative'
        on_positive = AddNode(10)           # post -> None
        on_negative = AddNode(-20)          # post -> None

        flow = Flow()
        # seed -> sign_check uses the default transition; sign_check then
        # branches on the action string its post() returns.
        flow.start(seed) >> sign_check
        sign_check - "positive" >> on_positive
        sign_check - "negative" >> on_negative

        outcome = flow.run(shared)
        return shared, outcome

    def _run_recording_warnings(self, flow, shared):
        """Run *flow* while recording every warning raised during the run.

        Returns a ``(last_action, recorded_warnings)`` tuple.
        """
        with warnings.catch_warnings(record=True) as recorded:
            # "always" ensures the warning is not suppressed by earlier
            # occurrences of the same warning in this process.
            warnings.simplefilter("always")
            outcome = flow.run(shared)
        return outcome, recorded

    # ------------------------------------------------------------------
    # Linear flows
    # ------------------------------------------------------------------

    def test_start_method_initialization(self):
        """Test initializing flow with start() after creation."""
        shared = {}
        only_node = NumberNode(5)
        flow = Flow()
        flow.start(only_node)

        outcome = flow.run(shared)

        self.assertEqual(shared['current'], 5)
        # NumberNode does not override post, so no action is reported.
        self.assertIsNone(outcome)

    def test_start_method_chaining(self):
        """Test fluent chaining using start().next()..."""
        shared = {}
        flow = Flow()
        # NumberNode -> AddNode -> MultiplyNode, all via default transitions.
        (
            flow.start(NumberNode(5))
            .next(AddNode(3))
            .next(MultiplyNode(2))
        )

        outcome = flow.run(shared)

        self.assertEqual(shared['current'], 16)  # (5 + 3) * 2
        # The last node (MultiplyNode) does not override post.
        self.assertIsNone(outcome)

    def test_sequence_with_rshift(self):
        """Test a simple linear pipeline using >>"""
        shared = {}
        seed, plus_three, times_two = NumberNode(5), AddNode(3), MultiplyNode(2)
        flow = Flow()
        # Same pipeline as the chaining test, expressed with the >> operator.
        flow.start(seed) >> plus_three >> times_two

        outcome = flow.run(shared)

        self.assertEqual(shared['current'], 16)  # (5 + 3) * 2
        # The last node (times_two) does not override post.
        self.assertIsNone(outcome)

    # ------------------------------------------------------------------
    # Branching and cycles
    # ------------------------------------------------------------------

    def test_branching_positive(self):
        """Test positive branch: CheckPositiveNode returns 'positive'"""
        # Execution: seed(5) -> check -> AddNode(10)
        shared, outcome = self._run_sign_branch(5)

        self.assertEqual(shared['current'], 15)  # 5 + 10
        # The last node executed (AddNode) does not override post.
        self.assertIsNone(outcome)

    def test_branching_negative(self):
        """Test negative branch: CheckPositiveNode returns 'negative'"""
        # Execution: seed(-5) -> check -> AddNode(-20)
        shared, outcome = self._run_sign_branch(-5)

        self.assertEqual(shared['current'], -25)  # -5 + -20
        # The last node executed (AddNode) does not override post.
        self.assertIsNone(outcome)

    def test_cycle_until_negative_ends_with_signal(self):
        """Test cycle, ending on a node that returns a signal"""
        shared = {}
        seed = NumberNode(10)                    # post -> None
        sign_check = CheckPositiveNode()         # post -> 'positive' or 'negative'
        minus_three = AddNode(-3)                # post -> None
        finish = EndSignalNode("cycle_done")     # post -> "cycle_done"

        flow = Flow()
        flow.start(seed) >> sign_check
        # Keep subtracting while non-negative; leave the loop on 'negative'.
        sign_check - 'positive' >> minus_three
        sign_check - 'negative' >> finish
        # Default transition closes the loop back to the check.
        minus_three >> sign_check

        # Execution: seed -> check -> (-3 -> check) x4 -> finish
        outcome = flow.run(shared)

        self.assertEqual(shared['current'], -2)  # 10 -> 7 -> 4 -> 1 -> -2
        # The last node executed was `finish`, whose post returns its signal.
        self.assertEqual(outcome, "cycle_done")

    # ------------------------------------------------------------------
    # End-of-flow warnings
    # ------------------------------------------------------------------

    def test_flow_ends_warning_default_missing(self):
        """Test warning when default transition is needed but not found"""
        class ActionNode(Node):
            def post(self, *args):
                return "specific_action"

        first = ActionNode()
        successor = NoOpNode()
        flow = Flow()
        flow.start(first)
        # The only registered successor is keyed on "specific_action" ...
        first - "specific_action" >> successor
        # ... but post is patched on the instance to return None, so the
        # registered action is never produced.
        first.post = lambda *args: None

        outcome, recorded = self._run_recording_warnings(flow, {})

        self.assertEqual(len(recorded), 1)
        latest = recorded[-1]
        self.assertTrue(issubclass(latest.category, UserWarning))
        # The message names the returned action (None) and the known actions.
        self.assertIn("Flow ends: 'None' not found in ['specific_action']", str(latest.message))
        # The reported last action comes from the patched post.
        self.assertIsNone(outcome)

    def test_flow_ends_warning_specific_missing(self):
        """Test warning when specific action is returned but not found"""
        class ActionNode(Node):
            def post(self, *args):
                return "specific_action"

        first = ActionNode()
        successor = NoOpNode()
        flow = Flow()
        flow.start(first)
        # Only the default transition is registered ...
        first >> successor
        # ... while post returns "specific_action", which has no successor.

        outcome, recorded = self._run_recording_warnings(flow, {})

        self.assertEqual(len(recorded), 1)
        latest = recorded[-1]
        self.assertTrue(issubclass(latest.category, UserWarning))
        # The message names the returned action and the known actions.
        self.assertIn("Flow ends: 'specific_action' not found in ['default']", str(latest.message))
        # The reported last action comes from ActionNode.post.
        self.assertEqual(outcome, "specific_action")
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()