forked from The-Pocket/PocketFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathviz.mdc
More file actions
282 lines (232 loc) · 9.2 KB
/
viz.mdc
File metadata and controls
282 lines (232 loc) · 9.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
---
description: Guidelines for using PocketFlow, Utility Function, Viz and Debug
globs:
alwaysApply: false
---
# Visualization and Debugging
Similar to LLM wrappers, we **don't** provide built-in visualization and debugging. Here, we recommend some *minimal* (and incomplete) implementations. These examples can serve as a starting point for your own tooling.
## 1. Visualization with Mermaid
This code recursively traverses the nested graph, assigns unique IDs to each node, and treats Flow nodes as subgraphs to generate Mermaid syntax for a hierarchical visualization.
{% raw %}
```python
def build_mermaid(start):
    """Render a PocketFlow graph as Mermaid ``graph LR`` source.

    Nodes get IDs ``N1``, ``N2``, ... in first-visit order. Each ``Flow`` is
    drawn as a ``subgraph`` holding the nodes reachable from its start node.
    Edges into a Flow point at its start node; edges out of a Flow leave from
    its start node (a minimal approximation, since the Flow may actually exit
    from any of its inner nodes).

    Args:
        start: Node or Flow to start the traversal from.

    Returns:
        The Mermaid diagram source as one newline-joined string.
    """
    ids, visited, lines = {}, set(), ["graph LR"]

    def get_id(n):
        # IDs are only ever created here, so len(ids) + 1 is the next free number.
        if n not in ids:
            ids[n] = f"N{len(ids) + 1}"
        return ids[n]

    def link(a, b):
        lines.append(f"    {a} --> {b}")

    def entry_id(node):
        # A Flow is a subgraph, not a drawable node, so edges into it target
        # its start node (an empty Flow falls back to its own ID).
        if isinstance(node, Flow) and node.start_node:
            return get_id(node.start_node)
        return get_id(node)

    def walk(node, parent=None):
        if node in visited:
            if parent:
                link(parent, entry_id(node))
            return
        visited.add(node)
        if isinstance(node, Flow):
            inner = node.start_node
            if inner and parent:
                link(parent, get_id(inner))
            lines.append(f"\n    subgraph sub_flow_{get_id(node)}[{type(node).__name__}]")
            if inner:
                walk(inner)
            # Close the subgraph before visiting successors; otherwise nodes
            # that come after this Flow would be declared inside its subgraph.
            lines.append("    end\n")
            # Successors are linked from the start node; for an empty Flow the
            # Flow's predecessor (if any) is linked to them directly.
            exit_id = get_id(inner) if inner else parent
            for nxt in node.successors.values():
                walk(nxt, exit_id)
        else:
            nid = get_id(node)
            lines.append(f"    {nid}['{type(node).__name__}']")
            if parent:
                link(parent, nid)
            for nxt in node.successors.values():
                walk(nxt, nid)

    walk(start)
    return "\n".join(lines)
```
{% endraw %}
For example, suppose we have a complex Flow for data science:
```python
# Node types for the example; only the batch node overrides a hook.
class DataPrepBatchNode(BatchNode):
    def prep(self, shared):
        return []  # an empty list of batch items


class ValidateDataNode(Node):
    pass


class FeatureExtractionNode(Node):
    pass


class TrainModelNode(Node):
    pass


class EvaluateModelNode(Node):
    pass


class ModelFlow(Flow):
    pass


class DataScienceFlow(Flow):
    pass


# Inner flow: feature extraction -> training -> evaluation.
feature_node = FeatureExtractionNode()
train_node = TrainModelNode()
evaluate_node = EvaluateModelNode()
feature_node >> train_node
train_node >> evaluate_node
model_flow = ModelFlow(start=feature_node)

# Outer flow: data prep -> validation -> the nested model flow.
data_prep_node = DataPrepBatchNode()
validate_node = ValidateDataNode()
data_prep_node >> validate_node
validate_node >> model_flow
data_science_flow = DataScienceFlow(start=data_prep_node)

result = build_mermaid(start=data_science_flow)
```
The code generates a Mermaid diagram:
```mermaid
graph LR
subgraph sub_flow_N1[DataScienceFlow]
N2['DataPrepBatchNode']
N3['ValidateDataNode']
N2 --> N3
N3 --> N4
subgraph sub_flow_N5[ModelFlow]
N4['FeatureExtractionNode']
N6['TrainModelNode']
N4 --> N6
N7['EvaluateModelNode']
N6 --> N7
end
end
```
## 2. Interactive D3.js Visualization
For more complex flows, a static diagram may not be sufficient. We provide a D3.js-based interactive visualization that allows for dragging nodes, showing group boundaries for flows, and connecting flows at their boundaries.
### Converting Flow to JSON
First, we convert the PocketFlow graph to JSON format suitable for D3.js:
```python
def flow_to_json(start):
    """Build a JSON-serializable description of a flow for D3.js.

    The graph traversal that fills ``nodes`` and ``links`` is elided here;
    the code shown is the post-processing that lifts node links crossing a
    group (Flow) boundary into group-level links.

    Returns a dict with:
    - nodes: every non-Flow node together with the group it belongs to
    - links: node-to-node connections kept after filtering (same group, or an
      endpoint outside any group)
    - group_links: one connection per (source group, target group) pair
    - flows: group id (as a string) -> Flow class name, for group labels
    """
    nodes, links = [], []
    group_links = []  # Flow-to-Flow connections
    ids = {}
    node_types = {}
    flow_nodes = {}  # group id -> Flow instance
    ctr = 1
    # Implementation details...

    # Post-processing: replace cross-group node links with group links.
    group_of = {entry["id"]: entry["group"] for entry in nodes}
    seen_pairs = {(gl["source"], gl["target"]) for gl in group_links}
    kept_links = []
    for lnk in links:
        src_grp = group_of.get(lnk["source"], 0)
        dst_grp = group_of.get(lnk["target"], 0)
        if not (src_grp != dst_grp and src_grp > 0 and dst_grp > 0):
            # Same group, or an endpoint outside every group: keep it.
            kept_links.append(lnk)
            continue
        # Crosses groups: record the pair once (first action wins) and drop
        # the direct node-to-node edge.
        if (src_grp, dst_grp) not in seen_pairs:
            seen_pairs.add((src_grp, dst_grp))
            group_links.append({
                "source": src_grp,
                "target": dst_grp,
                "action": lnk["action"],
            })

    flow_labels = {str(gid): flow.__class__.__name__ for gid, flow in flow_nodes.items()}
    return {
        "nodes": nodes,
        "links": kept_links,
        "group_links": group_links,
        "flows": flow_labels,
    }
```
### Creating the Visualization
Then, we generate an HTML file with D3.js visualization:
```python
def create_d3_visualization(json_data, output_dir="./viz", filename="flow_viz"):
    """Write a flow's JSON data and an HTML page that visualizes it with D3.js.

    Args:
        json_data: The dict produced by ``flow_to_json``.
        output_dir: Directory to write into; created if it does not exist.
        filename: Base name (no extension) for the ``.json`` and ``.html`` files.

    Returns:
        The path of the generated HTML file.
    """
    import json
    import os

    os.makedirs(output_dir, exist_ok=True)

    # Save the raw graph data alongside the page so it can be inspected or reused.
    json_path = os.path.join(output_dir, f"{filename}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=2)

    # Minimal HTML shell that loads D3.js and embeds the data. The full
    # force-layout script is elided; key features it implements:
    # 1. Nodes can be dragged to reorganize the layout
    # 2. Flows are shown as dashed rectangles (groups)
    # 3. Inter-group connections shown as dashed lines connecting at group boundaries
    # 4. Edge labels show transition actions
    # "</" is escaped so embedded strings cannot close the <script> tag early.
    data_js = json.dumps(json_data).replace("</", "<\\/")
    html_content = "\n".join([
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        '  <meta charset="utf-8">',
        '  <script src="https://d3js.org/d3.v7.min.js"></script>',
        "</head>",
        "<body>",
        '  <svg id="flow-viz" width="960" height="600"></svg>',
        "  <script>",
        f"    const data = {data_js};",
        "    // ...D3.js force-simulation, drag and group-drawing code goes here...",
        "  </script>",
        "</body>",
        "</html>",
        "",
    ])

    html_path = os.path.join(output_dir, f"{filename}.html")
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    print(f"Visualization created at {html_path}")
    return html_path
```
### Convenience Function
A convenience function to visualize flows:
```python
def visualize_flow(flow, flow_name):
    """Print a Mermaid diagram of *flow* and write a D3.js visualization of it.

    The D3.js files are named after *flow_name*, lower-cased with spaces
    replaced by underscores.
    """
    print(f"\n--- {flow_name} Mermaid Diagram ---")
    mermaid_src = build_mermaid(start=flow)
    print(mermaid_src)

    print(f"\n--- {flow_name} D3.js Visualization ---")
    graph_json = flow_to_json(flow)
    slug = flow_name.lower().replace(" ", "_")
    create_d3_visualization(graph_json, filename=slug)
```
### Usage Example
```python
from visualize import visualize_flow
# Build a flow that contains nested sub-flows
# ...flow definition...

# Print the Mermaid diagram and write the D3.js visualization files
visualize_flow(flow=data_science_flow, flow_name="Data Science Flow")
```
### Customizing the Visualization
You can customize the visualization by adjusting the force simulation parameters:
```javascript
const simulation = d3.forceSimulation(data.nodes)
  // Target length of each link: connected nodes settle about this far apart
  .force("link", d3.forceLink(data.links).id(d => d.id).distance(100))
  // Node-to-node charge: negative values repel. More negative values push
  // nodes farther apart; values closer to zero let them sit closer together
  .force("charge", d3.forceManyBody().strength(-30))
  // Centers the entire graph in the SVG
  .force("center", d3.forceCenter(width / 2, height / 2))
  // Prevents overlap by treating each node as a circle of this radius, so
  // node centers are kept roughly 2 * radius apart
  .force("collide", d3.forceCollide().radius(50));
```
## 3. Call Stack Debugging
It would be useful to print the Node call stacks for debugging. This can be achieved by inspecting the runtime call stack:
```python
import inspect
def get_node_call_stack():
    """Return the class names of the PocketFlow nodes currently on the call stack.

    Walks outward from the caller's frame and records ``type(self).__name__``
    for every frame whose ``self`` is a ``BaseNode``. Each node instance is
    reported once (its innermost frame wins), so a node calling its own
    methods does not appear twice.

    Returns:
        Node class names ordered innermost-first; an empty list when no node
        is active or the interpreter does not expose frames.
    """
    node_names = []
    seen_ids = set()
    # Walk raw frames instead of using inspect.stack(): this skips source-context
    # lookups and leaves no list of frames behind to form a reference cycle.
    frame = inspect.currentframe()
    try:
        frame = frame.f_back if frame is not None else None  # skip this function
        while frame is not None:
            caller_self = frame.f_locals.get("self")
            if isinstance(caller_self, BaseNode) and id(caller_self) not in seen_ids:
                seen_ids.add(id(caller_self))
                node_names.append(type(caller_self).__name__)
            frame = frame.f_back
    finally:
        # Drop the frame reference explicitly, as the inspect docs recommend.
        del frame
    return node_names
```
For example, suppose we have a complex Flow for data science:
```python
# Node types for the example; EvaluateModelNode prints the active call stack.
class DataPrepBatchNode(BatchNode):
    def prep(self, shared):
        return []  # an empty list of batch items


class ValidateDataNode(Node):
    pass


class FeatureExtractionNode(Node):
    pass


class TrainModelNode(Node):
    pass


class EvaluateModelNode(Node):
    def prep(self, shared):
        # Show which nodes and flows are active when this node runs.
        print("Call stack:", get_node_call_stack())


class ModelFlow(Flow):
    pass


class DataScienceFlow(Flow):
    pass


# Inner flow: feature extraction -> training -> evaluation.
feature_node = FeatureExtractionNode()
train_node = TrainModelNode()
evaluate_node = EvaluateModelNode()
feature_node >> train_node
train_node >> evaluate_node
model_flow = ModelFlow(start=feature_node)

# Outer flow: data prep -> validation -> the nested model flow.
data_prep_node = DataPrepBatchNode()
validate_node = ValidateDataNode()
data_prep_node >> validate_node
validate_node >> model_flow
data_science_flow = DataScienceFlow(start=data_prep_node)

data_science_flow.run({})
```
The output would be: `Call stack: ['EvaluateModelNode', 'ModelFlow', 'DataScienceFlow']`