forked from The-Pocket/PocketFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_batch_node.py
More file actions
160 lines (126 loc) · 4.71 KB
/
test_batch_node.py
File metadata and controls
160 lines (126 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import unittest
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from pocketflow import Node, BatchNode, Flow
class ArrayChunkNode(BatchNode):
    """Batch node that turns ``shared['input_array']`` into per-chunk sums.

    ``prep`` cuts the input into consecutive slices of at most ``chunk_size``
    items, ``exec`` returns the sum of one slice, and ``post`` stores the
    collected results under ``'chunk_results'``.
    """

    def __init__(self, chunk_size=10):
        super().__init__()
        self.chunk_size = chunk_size

    def prep(self, shared_storage):
        """Return the input array cut into consecutive slices.

        A missing ``'input_array'`` key is treated as an empty list.
        """
        items = shared_storage.get('input_array', [])
        size = self.chunk_size
        # Slicing clamps at the end of the sequence, so the final slice is
        # simply shorter when the length is not a multiple of ``size``.
        return [
            items[offset:offset + size]
            for offset in range(0, len(items), size)
        ]

    def exec(self, chunk):
        """Return the sum of the items in a single chunk."""
        return sum(chunk)

    def post(self, shared_storage, prep_result, proc_result):
        """Publish the processing result and follow the ``"default"`` action."""
        shared_storage['chunk_results'] = proc_result
        return "default"
class SumReduceNode(Node):
    """Node that adds up the per-chunk sums and records the grand total."""

    def prep(self, shared_storage):
        """Store the sum of ``shared_storage['chunk_results']`` under ``'total'``.

        A missing ``'chunk_results'`` key is treated as an empty list, which
        yields a total of 0.
        """
        shared_storage['total'] = sum(shared_storage.get('chunk_results', []))
class TestBatchNode(unittest.TestCase):
    """Tests for ``ArrayChunkNode`` alone and chained with ``SumReduceNode``."""

    def _run_map_reduce(self, array, chunk_size):
        """Run the chunk -> reduce pipeline and return the shared storage.

        Args:
            array: Input list placed under ``'input_array'``.
            chunk_size: Chunk size handed to ``ArrayChunkNode``.

        Returns:
            The shared-storage dict after the flow has run.
        """
        shared_storage = {'input_array': array}
        chunk_node = ArrayChunkNode(chunk_size=chunk_size)
        reduce_node = SumReduceNode()
        # Wire the map step to the reduce step before building the flow.
        chunk_node >> reduce_node
        Flow(start=chunk_node).run(shared_storage)
        return shared_storage

    def test_array_chunking(self):
        """
        Test that the array is split into chunks and each chunk is summed
        """
        shared_storage = {
            'input_array': list(range(25)),  # [0,1,2,...,24]
        }
        chunk_node = ArrayChunkNode(chunk_size=10)
        chunk_node.run(shared_storage)
        results = shared_storage['chunk_results']
        # Sums of [0..9], [10..19] and the shorter tail [20..24].
        self.assertEqual(results, [45, 145, 110])

    def test_map_reduce_sum(self):
        """
        Test a complete map-reduce pipeline that sums a large array:
        1. Map: Split array into chunks and sum each chunk
        2. Reduce: Sum all the chunk sums
        """
        array = list(range(100))  # [0,1,2,...,99]
        expected_sum = sum(array)  # 4950
        shared_storage = self._run_map_reduce(array, chunk_size=10)
        self.assertEqual(shared_storage['total'], expected_sum)

    def test_uneven_chunks(self):
        """
        Test that the map-reduce works correctly with array lengths
        that don't divide evenly by chunk_size
        """
        array = list(range(25))
        expected_sum = sum(array)  # 300
        shared_storage = self._run_map_reduce(array, chunk_size=10)
        self.assertEqual(shared_storage['total'], expected_sum)

    def test_custom_chunk_size(self):
        """
        Test that the map-reduce works with different chunk sizes
        """
        array = list(range(100))
        expected_sum = sum(array)
        # Use chunk_size=15 instead of default 10
        shared_storage = self._run_map_reduce(array, chunk_size=15)
        self.assertEqual(shared_storage['total'], expected_sum)

    def test_single_element_chunks(self):
        """
        Test extreme case where chunk_size=1
        """
        array = list(range(5))
        expected_sum = sum(array)
        shared_storage = self._run_map_reduce(array, chunk_size=1)
        self.assertEqual(shared_storage['total'], expected_sum)

    def test_empty_array(self):
        """
        Test edge case of empty input array
        """
        shared_storage = self._run_map_reduce([], chunk_size=10)
        self.assertEqual(shared_storage['total'], 0)
# Allow running this file directly as a script, in addition to being
# collected by a test runner.
if __name__ == '__main__':
    unittest.main()