1 # -*- coding: utf-8 -*-
2 import math
3 import numpy
4 import matplotlib.pyplot as plt
5 from pylab import *
6
7
8
9 # 地图
10 tm = [
11 '############################################################',
12 '#..........................................................#',
13 '#.............................#............................#',
14 '#.............................#............................#',
15 '#.............................#............................#',
16 '#.......S....#.................#...........................#',
17 '#.............................#............................#',
18 '#.............................#............................#',
19 '#.............................#............................#',
20 '#........................#....##...........................#',
21 '#.............................#............................#',
22 '#.............................#............................#',
23 '#.............................#............................#',
24 '#######.#######################################............#',
25 '#....#........#............................................#',
26 '#....#........#............................................#',
27 '#....##########............................................#',
28 '#..........................................................#',
29 '#..........................................................#',
30 '#..........................................................#',
31 '#..........................................................#',
32 '#..........................................................#',
33 '#...............................##############.............#',
34 '#...............................#........E...#.............#',
35 '#...............................#............#.............#',
36 '#...............................#............#.............#',
37 '#...............................#............#.............#',
38 '#...............................###########..#.............#',
39 '#..........................................................#',
40 '#..........................................................#',
41 '############################################################']
42
43 # 因为python里string不能直接改变某一元素,所以用test_map来存储搜索时的地图
44 test_map = []
45
46
47 #########################################################
48 class Node_Elem:
49 """
50 开放列表和关闭列表的元素类型,parent用来在成功的时候回溯路径
51 """
52
53 def __init__(self, parent, x, y, dist):
54 self.parent = parent
55 self.x = x
56 self.y = y
57 self.dist = dist
58
59
60 class A_Star:
61 """
62 A星算法实现类
63 """
64
65 # 注意w,h两个参数,如果你修改了地图,需要传入一个正确值或者修改这里的默认参数
66 def __init__(self, s_x, s_y, e_x, e_y, w=60, h=30):
67 self.s_x = s_x
68 self.s_y = s_y
69 self.e_x = e_x
70 self.e_y = e_y
71
72 self.width = w
73 self.height = h
74
75 self.open = []
76 self.close = []
77 self.path = []
78
79 # 查找路径的入口函数
80 def find_path(self):
81 # 构建开始节点
82 p = Node_Elem(None, self.s_x, self.s_y, 0.0)
83 while True:
84 # 扩展F值最小的节点
85 self.extend_round(p)
86 # 如果开放列表为空,则不存在路径,返回
87 if not self.open:
88 return
89 # 获取F值最小的节点
90 idx, p = self.get_best()
91 # 找到路径,生成路径,返回
92 if self.is_target(p):
93 self.make_path(p)
94 return
95 # 把此节点压入关闭列表,并从开放列表里删除
96 self.close.append(p)
97 del self.open[idx]
98
99 def make_path(self, p):
100 # 从结束点回溯到开始点,开始点的parent == None
101 while p:
102 self.path.append((p.x, p.y))
103 p = p.parent
104
105 def is_target(self, i):
106 return i.x == self.e_x and i.y == self.e_y
107
108 def get_best(self):
109 best = None
110 bv = 1000000 # 如果你修改的地图很大,可能需要修改这个值
111 bi = -1
112 for idx, i in enumerate(self.open):
113 value = self.get_dist(i) # 获取F值
114 if value < bv: # 比以前的更好,即F值更小
115 best = i
116 bv = value
117 bi = idx
118 return bi, best
119
120 def get_dist(self, i):
121 # F = G + H
122 # G 为已经走过的路径长度, H为估计还要走多远
123 # 这个公式就是A*算法的精华了。
124 return i.dist + math.sqrt(
125 (self.e_x - i.x) * (self.e_x - i.x)
126 + (self.e_y - i.y) * (self.e_y - i.y)) * 1.2
127
128 def extend_round(self, p):
129 # 可以从8个方向走
130 xs = (-1, 0, 1, -1, 1, -1, 0, 1)
131 ys = (-1, -1, -1, 0, 0, 1, 1, 1)
132
133 for x, y in zip(xs, ys):
134 new_x, new_y = x + p.x, y + p.y
135 # 无效或者不可行走区域,则勿略
136 if not self.is_valid_coord(new_x, new_y):
137 continue
138 # 构造新的节点
139 node = Node_Elem(p, new_x, new_y, p.dist + self.get_cost(
140 p.x, p.y, new_x, new_y))
141 # 新节点在关闭列表,则忽略
142 if self.node_in_close(node):
143 continue
144 i = self.node_in_open(node)
145 if i != -1:
146 # 新节点在开放列表
147 if self.open[i].dist > node.dist:
148 # 现在的路径到比以前到这个节点的路径更好~
149 # 则使用现在的路径
150 self.open[i].parent = p
151 self.open[i].dist = node.dist
152 continue
153 self.open.append(node)
154
155 def get_cost(self, x1, y1, x2, y2):
156 """
157 上下左右直走,代价为1.0,斜走,代价为1.4
158 """
159 if x1 == x2 or y1 == y2:
160 return 1.0
161 return 1.4
162
163 def node_in_close(self, node):
164 for i in self.close:
165 if node.x == i.x and node.y == i.y:
166 return True
167 return False
168
169 def node_in_open(self, node):
170 for i, n in enumerate(self.open):
171 if node.x == n.x and node.y == n.y:
172 return i
173 return -1
174
175 def is_valid_coord(self, x, y):
176 if x < 0 or x >= self.width or y < 0 or y >= self.height:
177 return False
178 return test_map[y][x] != '#'
179
180 def get_searched(self):
181 l = []
182 for i in self.open:
183 l.append((i.x, i.y))
184 for i in self.close:
185 l.append((i.x, i.y))
186 return l
187
188
189 #########################################################
190 def print_test_map():
191 """
192 打印搜索后的地图
193 """
194 for line in test_map:
195 print(''.join(line))
196
197
198 def get_start_XY():
199 return get_symbol_XY('S')
200
201
202 def get_end_XY():
203 return get_symbol_XY('E')
204
205
206 def get_symbol_XY(s):
207 for y, line in enumerate(test_map):
208 try:
209 x = line.index(s)
210 except:
211 continue
212 else:
213 break
214 return x, y
215
216
217 #########################################################
218 def mark_path(l):
219 mark_symbol(l, '*')
220
221
222 def mark_searched(l):
223 mark_symbol(l, ' ')
224
225
226 def mark_symbol(l, s):
227 for x, y in l:
228 test_map[y][x] = s
229
230
231 def mark_start_end(s_x, s_y, e_x, e_y):
232 test_map[s_y][s_x] = 'S'
233 test_map[e_y][e_x] = 'E'
234
235
236 def tm_to_test_map():
237 for line in tm:
238 test_map.append(list(line))
239
240
241 def find_path():
242 s_x, s_y = get_start_XY()
243 e_x, e_y = get_end_XY()
244 a_star = A_Star(s_x, s_y, e_x, e_y)
245 a_star.find_path()
246 searched = a_star.get_searched()
247 path = a_star.path
248 # 标记已搜索区域
249 mark_searched(searched)
250 # 标记路径
251 mark_path(path)
252 print("path length is %d" % (len(path)))
253 print("searched squares count is %d" % (len(searched)))
254 # 标记开始、结束点
255 mark_start_end(s_x, s_y, e_x, e_y)
256
257
258 if __name__ == "__main__":
259 # 把字符串转成列表
260 tm_to_test_map()
261 find_path()
262 print_test_map()