forked from The-Pocket/PocketFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_flow_composition.py
More file actions
151 lines (133 loc) · 5.85 KB
/
test_flow_composition.py
File metadata and controls
151 lines (133 loc) · 5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# tests/test_flow_composition.py
import unittest
import asyncio # Keep import, might be needed if other tests use it indirectly
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, Flow
# --- Existing Nodes ---
class NumberNode(Node):
    """Node that seeds the running value in shared storage.

    During ``prep`` the number given at construction time is written to
    ``shared_storage['current']``, replacing any previous value.
    """

    def __init__(self, number):
        """Store ``number`` for later use by ``prep``."""
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        """Overwrite ``shared_storage['current']`` with this node's number."""
        seed_value = self.number
        shared_storage['current'] = seed_value
        # post() is not overridden here; per the original author's note it
        # implicitly returns None (behaviour of the base class -- confirm
        # against pocketflow.Node).
class AddNode(Node):
    """Node that adds a fixed number to the running value in shared storage.

    ``prep`` reads ``shared_storage['current']``, adds the number given at
    construction time and writes the sum back under the same key. The key
    must already exist, otherwise the lookup raises ``KeyError``.
    """

    def __init__(self, number):
        """Store the addend ``number`` for later use by ``prep``."""
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        """Increase ``shared_storage['current']`` by this node's number."""
        running_value = shared_storage['current']
        shared_storage['current'] = running_value + self.number
        # post() is not overridden here; per the original author's note it
        # implicitly returns None (behaviour of the base class -- confirm
        # against pocketflow.Node).
class MultiplyNode(Node):
    """Node that scales the running value in shared storage by a fixed factor.

    ``prep`` reads ``shared_storage['current']``, multiplies it by the number
    given at construction time and writes the product back under the same
    key. The key must already exist, otherwise the lookup raises ``KeyError``.
    """

    def __init__(self, number):
        """Store the multiplier ``number`` for later use by ``prep``."""
        super().__init__()
        self.number = number

    def prep(self, shared_storage):
        """Multiply ``shared_storage['current']`` by this node's number."""
        running_value = shared_storage['current']
        shared_storage['current'] = running_value * self.number
        # post() is not overridden here; per the original author's note it
        # implicitly returns None (behaviour of the base class -- confirm
        # against pocketflow.Node).
# --- New Nodes for Action Propagation Test ---
class SignalNode(Node):
    """Node whose ``post`` step returns a configurable action string.

    The string is also recorded in shared storage so a test can check which
    signal was emitted.
    """

    def __init__(self, signal="default_signal"):
        """Store the action string that ``post`` will return."""
        super().__init__()
        self.signal = signal

    # No prep() override: this node only signals.

    def post(self, shared_storage, prep_result, exec_result):
        """Record the signal under ``'last_signal_emitted'`` and return it.

        ``prep_result`` and ``exec_result`` are accepted but not used.
        """
        emitted = self.signal
        # Keep a trace in shared storage so the caller can verify the signal.
        shared_storage['last_signal_emitted'] = emitted
        # The returned string is the action for whoever ran this node.
        return emitted
class PathNode(Node):
    """Marker node that records which branch of a flow was executed.

    ``prep`` writes the identifier given at construction time to
    ``shared_storage['path_taken']``.
    """

    def __init__(self, path_id):
        """Store ``path_id``, the label identifying this branch."""
        super().__init__()
        self.path_id = path_id

    def prep(self, shared_storage):
        """Set ``shared_storage['path_taken']`` to this node's identifier."""
        branch_label = self.path_id
        shared_storage['path_taken'] = branch_label
        # post() is not overridden here; per the original author's note it
        # implicitly returns None (behaviour of the base class -- confirm
        # against pocketflow.Node).
# --- Test Class ---
class TestFlowComposition(unittest.TestCase):
    """Tests that use Flow objects as building blocks of other flows.

    Each test builds its own empty dict as shared storage, wires nodes and
    flows together, runs the outermost flow and inspects the dict afterwards.
    """

    # --- Arithmetic composition tests ---

    def test_flow_as_node(self):
        """A flow nested inside two wrapper flows yields (5 + 10) * 2 = 30.

        1) ``base_flow`` starts at NumberNode(5) and is linked with ``>>`` to
           AddNode(10) and then MultiplyNode(2).
        2) A second flow uses ``base_flow`` as its start.
        3) A third flow wraps the second one and is the flow that is run.
        """
        storage = {}

        base_flow = Flow(start=NumberNode(5))
        base_flow >> AddNode(10) >> MultiplyNode(2)

        outermost = Flow(start=Flow(start=base_flow))
        outermost.run(storage)

        self.assertEqual(storage['current'], 30)

    def test_nested_flow(self):
        """Nested flows with a wrapper yield (5 + 3) * 4 = 32.

        innermost: starts at NumberNode(5), linked to AddNode(3)
        middle:    starts with ``innermost``, linked to MultiplyNode(4)
        wrapper:   starts with ``middle`` and is the flow that is run
        """
        storage = {}

        innermost = Flow(start=NumberNode(5))
        innermost >> AddNode(3)

        middle = Flow(start=innermost)
        middle >> MultiplyNode(4)

        Flow(start=middle).run(storage)

        self.assertEqual(storage['current'], 32)

    def test_flow_chaining_flows(self):
        """Two flows chained with ``>>`` yield (10 + 10) * 2 = 40.

        adding_flow:   NumberNode(10) -> AddNode(10)   (running value 20)
        doubling_flow: MultiplyNode(2)                 (running value 40)
        A wrapper flow starting at ``adding_flow`` is the flow that is run.
        """
        storage = {}

        first_node = NumberNode(10)
        first_node >> AddNode(10)
        adding_flow = Flow(start=first_node)
        doubling_flow = Flow(start=MultiplyNode(2))

        # Unnamed (default) transition from the first flow to the second.
        adding_flow >> doubling_flow

        Flow(start=adding_flow).run(storage)

        self.assertEqual(storage['current'], 40)

    # --- Action propagation test ---

    def test_composition_with_action_propagation(self):
        """An outer flow branches on the action produced by an inner flow.

        The inner flow ends with a SignalNode returning "inner_done"; the
        outer flow has one branch registered for that action and one for a
        different action, and only the former is expected to execute.
        """
        storage = {}

        # Inner flow: seed the running value, then emit "inner_done".
        seed_node = NumberNode(100)
        signal_node = SignalNode("inner_done")
        seed_node >> signal_node
        inner_flow = Flow(start=seed_node)

        # Candidate continuations for the outer flow.
        branch_a = PathNode("A")
        branch_b = PathNode("B")

        # The outer flow is created empty and receives its start via start().
        outer_flow = Flow()
        outer_flow.start(inner_flow)

        # Register the branches on the inner flow object itself, keyed by
        # the action it is expected to hand back.
        inner_flow - "inner_done" >> branch_b    # expected route
        inner_flow - "other_action" >> branch_a  # must stay unused

        # Expected execution order: seed_node -> signal_node -> branch_b.
        final_action = outer_flow.run(storage)

        # State written while the inner flow ran.
        self.assertEqual(storage.get('current'), 100)
        self.assertEqual(storage.get('last_signal_emitted'), "inner_done")

        # Only branch B should have recorded itself.
        self.assertEqual(storage.get('path_taken'), "B")

        # branch_b defines no post() of its own, so the outer run is
        # expected to hand back None as its last action.
        self.assertIsNone(final_action)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()