Skip to content

Commit 0c23ebe

Browse files
authored
Spiral Spanning Tree Coverage Path Planning (AtsushiSakai#355)
* First commit of Spiral Spanning Tree Coverage * Modify followed by first code review * fix pycodestyle error * modifies following 2nd code review
1 parent 9fe14e6 commit 0c23ebe

File tree

5 files changed

+371
-0
lines changed

5 files changed

+371
-0
lines changed
156 Bytes
Loading
150 Bytes
Loading
132 Bytes
Loading
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
"""
2+
Spiral Spanning Tree Coverage Path Planner
3+
4+
author: Todd Tang
5+
paper: Spiral-STC: An On-Line Coverage Algorithm of Grid Environments
6+
by a Mobile Robot - Gabriely et.al.
7+
link: https://ieeexplore.ieee.org/abstract/document/1013479
8+
"""
9+
10+
import os
11+
import sys
12+
import math
13+
14+
import numpy as np
15+
import matplotlib.pyplot as plt
16+
17+
do_animation = True
18+
19+
20+
class SpiralSpanningTreeCoveragePlanner:
21+
def __init__(self, occ_map):
22+
self.origin_map_height = occ_map.shape[0]
23+
self.origin_map_width = occ_map.shape[1]
24+
25+
# original map resolution must be even
26+
if self.origin_map_height % 2 == 1 or self.origin_map_width % 2 == 1:
27+
sys.exit('original map width/height must be even \
28+
in grayscale .png format')
29+
30+
self.occ_map = occ_map
31+
self.merged_map_height = self.origin_map_height // 2
32+
self.merged_map_width = self.origin_map_width // 2
33+
34+
self.edge = []
35+
36+
def plan(self, start):
37+
"""plan
38+
39+
performing Spiral Spanning Tree Coverage path planning
40+
41+
:param start: the start node of Spiral Spanning Tree Coverage
42+
"""
43+
44+
visit_times = np.zeros(
45+
(self.merged_map_height, self.merged_map_width), dtype=np.int)
46+
visit_times[start[0]][start[1]] = 1
47+
48+
# generate route by
49+
# recusively call perform_spanning_tree_coverage() from start node
50+
route = []
51+
self.perform_spanning_tree_coverage(start, visit_times, route)
52+
53+
path = []
54+
# generate path from route
55+
for idx in range(len(route)-1):
56+
dp = abs(route[idx][0] - route[idx+1][0]) + \
57+
abs(route[idx][1] - route[idx+1][1])
58+
if dp == 0:
59+
# special handle for round-trip path
60+
path.append(self.get_round_trip_path(route[idx-1], route[idx]))
61+
elif dp == 1:
62+
path.append(self.move(route[idx], route[idx+1]))
63+
elif dp == 2:
64+
# special handle for non-adjacent route nodes
65+
mid_node = self.get_intermediate_node(route[idx], route[idx+1])
66+
path.append(self.move(route[idx], mid_node))
67+
path.append(self.move(mid_node, route[idx+1]))
68+
else:
69+
sys.exit('adjacent path node distance larger than 2')
70+
71+
return self.edge, route, path
72+
73+
def perform_spanning_tree_coverage(self, current_node, visit_times, route):
74+
"""perform_spanning_tree_coverage
75+
76+
recursive function for function <plan>
77+
78+
:param current_node: current node
79+
"""
80+
81+
def is_valid_node(i, j):
82+
is_i_valid_bounded = 0 <= i < self.merged_map_height
83+
is_j_valid_bounded = 0 <= j < self.merged_map_width
84+
if is_i_valid_bounded and is_j_valid_bounded:
85+
# free only when the 4 sub-cells are all free
86+
return bool(
87+
self.occ_map[2*i][2*j]
88+
and self.occ_map[2*i+1][2*j]
89+
and self.occ_map[2*i][2*j+1]
90+
and self.occ_map[2*i+1][2*j+1])
91+
92+
return False
93+
94+
# counter-clockwise neighbor finding order
95+
order = [[1, 0], [0, 1], [-1, 0], [0, -1]]
96+
97+
found = False
98+
route.append(current_node)
99+
for inc in order:
100+
ni, nj = current_node[0] + inc[0], current_node[1] + inc[1]
101+
if is_valid_node(ni, nj) and visit_times[ni][nj] == 0:
102+
neighbor_node = (ni, nj)
103+
self.edge.append((current_node, neighbor_node))
104+
found = True
105+
visit_times[ni][nj] += 1
106+
self.perform_spanning_tree_coverage(
107+
neighbor_node, visit_times, route)
108+
109+
# backtrace route from node with neighbors all visited
110+
# to first node with unvisited neighbor
111+
if not found:
112+
has_node_with_unvisited_ngb = False
113+
for node in reversed(route):
114+
# drop nodes that have been visited twice
115+
if visit_times[node[0]][node[1]] == 2:
116+
continue
117+
118+
visit_times[node[0]][node[1]] += 1
119+
route.append(node)
120+
121+
for inc in order:
122+
ni, nj = node[0] + inc[0], node[1] + inc[1]
123+
if is_valid_node(ni, nj) and visit_times[ni][nj] == 0:
124+
has_node_with_unvisited_ngb = True
125+
break
126+
127+
if has_node_with_unvisited_ngb:
128+
break
129+
130+
return route
131+
132+
def move(self, p, q):
133+
direction = self.get_vector_direction(p, q)
134+
# move east
135+
if direction == 'E':
136+
p = self.get_sub_node(p, 'SE')
137+
q = self.get_sub_node(q, 'SW')
138+
# move west
139+
elif direction == 'W':
140+
p = self.get_sub_node(p, 'NW')
141+
q = self.get_sub_node(q, 'NE')
142+
# move south
143+
elif direction == 'S':
144+
p = self.get_sub_node(p, 'SW')
145+
q = self.get_sub_node(q, 'NW')
146+
# move north
147+
elif direction == 'N':
148+
p = self.get_sub_node(p, 'NE')
149+
q = self.get_sub_node(q, 'SE')
150+
else:
151+
sys.exit('move direction error...')
152+
return [p, q]
153+
154+
def get_round_trip_path(self, last, pivot):
155+
direction = self.get_vector_direction(last, pivot)
156+
if direction == 'E':
157+
return [self.get_sub_node(pivot, 'SE'),
158+
self.get_sub_node(pivot, 'NE')]
159+
elif direction == 'S':
160+
return [self.get_sub_node(pivot, 'SW'),
161+
self.get_sub_node(pivot, 'SE')]
162+
elif direction == 'W':
163+
return [self.get_sub_node(pivot, 'NW'),
164+
self.get_sub_node(pivot, 'SW')]
165+
elif direction == 'N':
166+
return [self.get_sub_node(pivot, 'NE'),
167+
self.get_sub_node(pivot, 'NW')]
168+
else:
169+
sys.exit('get_round_trip_path: last->pivot direction error.')
170+
171+
def get_vector_direction(self, p, q):
172+
# east
173+
if p[0] == q[0] and p[1] < q[1]:
174+
return 'E'
175+
# west
176+
elif p[0] == q[0] and p[1] > q[1]:
177+
return 'W'
178+
# south
179+
elif p[0] < q[0] and p[1] == q[1]:
180+
return 'S'
181+
# north
182+
elif p[0] > q[0] and p[1] == q[1]:
183+
return 'N'
184+
else:
185+
sys.exit('get_vector_direction: Only E/W/S/N direction supported.')
186+
187+
def get_sub_node(self, node, direction):
188+
if direction == 'SE':
189+
return [2*node[0]+1, 2*node[1]+1]
190+
elif direction == 'SW':
191+
return [2*node[0]+1, 2*node[1]]
192+
elif direction == 'NE':
193+
return [2*node[0], 2*node[1]+1]
194+
elif direction == 'NW':
195+
return [2*node[0], 2*node[1]]
196+
else:
197+
sys.exit('get_sub_node: sub-node direction error.')
198+
199+
def get_interpolated_path(self, p, q):
200+
# direction p->q: southwest / northeast
201+
if (p[0] < q[0]) ^ (p[1] < q[1]):
202+
ipx = [p[0], p[0], q[0]]
203+
ipy = [p[1], q[1], q[1]]
204+
# direction p->q: southeast / northwest
205+
else:
206+
ipx = [p[0], q[0], q[0]]
207+
ipy = [p[1], p[1], q[1]]
208+
return ipx, ipy
209+
210+
def get_intermediate_node(self, p, q):
211+
p_ngb, q_ngb = set(), set()
212+
213+
for m, n in self.edge:
214+
if m == p:
215+
p_ngb.add(n)
216+
if n == p:
217+
p_ngb.add(m)
218+
if m == q:
219+
q_ngb.add(n)
220+
if n == q:
221+
q_ngb.add(m)
222+
223+
itsc = p_ngb.intersection(q_ngb)
224+
if len(itsc) == 0:
225+
sys.exit('get_intermediate_node: \
226+
no intermediate node between', p, q)
227+
elif len(itsc) == 1:
228+
return list(itsc)[0]
229+
else:
230+
sys.exit('get_intermediate_node: \
231+
more than 1 intermediate node between', p, q)
232+
233+
def visualize_path(self, edge, path, start):
234+
def coord_transform(p):
235+
return [2*p[1] + 0.5, 2*p[0] + 0.5]
236+
237+
if do_animation:
238+
last = path[0][0]
239+
trajectory = [[last[1]], [last[0]]]
240+
for p, q in path:
241+
distance = math.hypot(p[0]-last[0], p[1]-last[1])
242+
if distance <= 1.0:
243+
trajectory[0].append(p[1])
244+
trajectory[1].append(p[0])
245+
else:
246+
ipx, ipy = self.get_interpolated_path(last, p)
247+
trajectory[0].extend(ipy)
248+
trajectory[1].extend(ipx)
249+
250+
last = q
251+
252+
trajectory[0].append(last[1])
253+
trajectory[1].append(last[0])
254+
255+
for idx, state in enumerate(np.transpose(trajectory)):
256+
plt.cla()
257+
# for stopping simulation with the esc key.
258+
plt.gcf().canvas.mpl_connect(
259+
'key_release_event',
260+
lambda event: [exit(0) if event.key == 'escape' else None])
261+
262+
# draw spanning tree
263+
plt.imshow(self.occ_map, 'gray')
264+
for p, q in edge:
265+
p = coord_transform(p)
266+
q = coord_transform(q)
267+
plt.plot([p[0], q[0]], [p[1], q[1]], '-oc')
268+
sx, sy = coord_transform(start)
269+
plt.plot([sx], [sy], 'pr', markersize=10)
270+
271+
# draw move path
272+
plt.plot(trajectory[0][:idx+1], trajectory[1][:idx+1], '-k')
273+
plt.plot(state[0], state[1], 'or')
274+
plt.axis('equal')
275+
plt.grid(True)
276+
plt.pause(0.01)
277+
278+
else:
279+
# draw spanning tree
280+
plt.imshow(self.occ_map, 'gray')
281+
for p, q in edge:
282+
p = coord_transform(p)
283+
q = coord_transform(q)
284+
plt.plot([p[0], q[0]], [p[1], q[1]], '-oc')
285+
sx, sy = coord_transform(start)
286+
plt.plot([sx], [sy], 'pr', markersize=10)
287+
288+
# draw move path
289+
last = path[0][0]
290+
for p, q in path:
291+
distance = math.hypot(p[0]-last[0], p[1]-last[1])
292+
if distance == 1.0:
293+
plt.plot([last[1], p[1]], [last[0], p[0]], '-k')
294+
else:
295+
ipx, ipy = self.get_interpolated_path(last, p)
296+
plt.plot(ipy, ipx, '-k')
297+
plt.arrow(p[1], p[0], q[1]-p[1], q[0]-p[0], head_width=0.2)
298+
last = q
299+
300+
plt.show()
301+
302+
303+
def main():
304+
dir_path = os.path.dirname(os.path.realpath(__file__))
305+
img = plt.imread(os.path.join(dir_path, 'map', 'test_2.png'))
306+
STC_planner = SpiralSpanningTreeCoveragePlanner(img)
307+
start = (10, 0)
308+
edge, route, path = STC_planner.plan(start)
309+
STC_planner.visualize_path(edge, path, start)
310+
311+
312+
if __name__ == "__main__":
313+
main()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import sys
3+
import matplotlib.pyplot as plt
4+
from unittest import TestCase
5+
6+
sys.path.append(os.path.dirname(
7+
os.path.abspath(__file__)) + "/../PathPlanning/SpiralSpanningTreeCPP")
8+
try:
9+
import spiral_spanning_tree_coverage_path_planner
10+
except ImportError:
11+
raise
12+
13+
spiral_spanning_tree_coverage_path_planner.do_animation = True
14+
15+
16+
class TestPlanning(TestCase):
17+
def spiral_stc_cpp(self, img, start):
18+
num_free = 0
19+
for i in range(img.shape[0]):
20+
for j in range(img.shape[1]):
21+
num_free += img[i][j]
22+
23+
STC_planner = spiral_spanning_tree_coverage_path_planner.\
24+
SpiralSpanningTreeCoveragePlanner(img)
25+
26+
edge, route, path = STC_planner.plan(start)
27+
28+
covered_nodes = set()
29+
for p, q in edge:
30+
covered_nodes.add(p)
31+
covered_nodes.add(q)
32+
33+
# assert complete coverage
34+
self.assertEqual(len(covered_nodes), num_free / 4)
35+
36+
def test_spiral_stc_cpp_1(self):
37+
img_dir = os.path.dirname(
38+
os.path.abspath(__file__)) + \
39+
"/../PathPlanning/SpiralSpanningTreeCPP"
40+
img = plt.imread(os.path.join(img_dir, 'map', 'test.png'))
41+
start = (0, 0)
42+
self.spiral_stc_cpp(img, start)
43+
44+
def test_spiral_stc_cpp_2(self):
45+
img_dir = os.path.dirname(
46+
os.path.abspath(__file__)) + \
47+
"/../PathPlanning/SpiralSpanningTreeCPP"
48+
img = plt.imread(os.path.join(img_dir, 'map', 'test_2.png'))
49+
start = (10, 0)
50+
self.spiral_stc_cpp(img, start)
51+
52+
def test_spiral_stc_cpp_3(self):
53+
img_dir = os.path.dirname(
54+
os.path.abspath(__file__)) + \
55+
"/../PathPlanning/SpiralSpanningTreeCPP"
56+
img = plt.imread(os.path.join(img_dir, 'map', 'test_3.png'))
57+
start = (0, 0)
58+
self.spiral_stc_cpp(img, start)

0 commit comments

Comments
 (0)