算法 多源最短路径 floyd算法
算法核心:顶点i到顶点j之间插入顶点k,看是否能够缩短i和j之间的距离
数据结构:
设置地图的带权邻接矩阵为G.Edge[][],如果顶点i到顶点j有边,则G.Edge[i][j]=<i,j>边的权值,否则G.Edge[i][j]=∞。
采用两个辅助数组;dist[i][j]记录i到j的最短路径长度, p[i][j]记录i到j顶点最短路径上j顶点的前驱
初始化:
初始化dist[i][j]=G.Edge[i][j],如果顶点i到顶点j有边,初始化p[i][j]=i,否则p[i][j]=-1
插点:
如果dist[i][j]>dist[i][k]+dist[k][j]则dist[i][j]=dist[i][k]+dist[k][j],记录顶点j的前驱为:p[i][j]=p[k][j]
实现代码:
#!/usr/bin/env python # -*- coding:utf-8 -*- graphNodeList = [] #图存储节点信息 graphAdjMatrix = [] #有向权重图邻接矩阵 dist = [] #存储最短路径开销,二维数组 p = [] #每个顶点的直接前驱的索引,二维数组 MAX = 100000 #str切分后给列表赋值格式 def acquireNode(): node_list = input("请输入途中所有的点,以空格分隔:") for nodeInfo in node_list.strip().split(" "): graphNodeList.append(nodeInfo) #根据输入信息填写邻接矩阵 def acquireSideUndig(): print("请输入图的所有边信息,格式为边的两个顶点,使用空格分隔。 eg:a b,如果全部信息输入完毕请输入end") while True: tempSide = input(">:") if tempSide.strip().lower() == "end": print("输入结束") break tempNodeList = tempSide.strip().split(" ") if len(tempNodeList) == 2: if tempNodeList[0] in graphNodeList and tempNodeList[1] in graphNodeList: createUndigraphAdjMatrix(tempNodeList[0], tempNodeList[1]) else: print("边信息输入有误,请重新输入") continue else: print("输入有误请重新输入") continue def acquireSideDig(): print("请输入图的所有边信息,格式为边的两个顶点,使用空格分隔。 eg:a b,如果全部信息输入完毕请输入end") while True: tempSide = input(">:") if tempSide.strip().lower() == "end": print("输入结束") break tempNodeList = tempSide.strip().split(" ") if len(tempNodeList) == 2: if tempNodeList[0] in graphNodeList and tempNodeList[1] in graphNodeList: createDigraphAdjMatrix(tempNodeList[0], tempNodeList[1]) else: print("边信息输入有误,请重新输入") continue else: print("输入有误请重新输入") continue def acquireSideDigWeight(): print("请输入图的所有边信息,格式为边的两个顶点和权重,使用空格分隔。 eg:a b weight,如果全部信息输入完毕请输入end") while True: tempSide = input(">:") if tempSide.strip().lower() == "end": print("输入结束") break tempNodeList = tempSide.strip().split(" ") if len(tempNodeList) == 3: if tempNodeList[0] in graphNodeList and tempNodeList[1] in graphNodeList: createDigraphAdjMatrixWeight(tempNodeList[0], tempNodeList[1], int(tempNodeList[2])) else: print("边信息输入有误,请重新输入") continue else: print("输入有误请重新输入") continue #初始化邻接矩阵;注意多维数组初始化格式以及数据的浅拷贝坑 def initGraphAdjMatrixWeight(nodeNum): for row in range(nodeNum): tempList = [] for column in range(nodeNum): if row == column: tempList.append(0) else: tempList.append(MAX) graphAdjMatrix.append(tempList) #根据输入顶点信息完成邻接表 def createUndigraphAdjMatrix(node0, node1): tempIndex1 = graphNodeList.index(node0) tempIndex2 = graphNodeList.index(node1) graphAdjMatrix[tempIndex1][tempIndex2] = 1 graphAdjMatrix[tempIndex2][tempIndex1] = 1 def createDigraphAdjMatrix(node0, node1): tempIndex1 = graphNodeList.index(node0) tempIndex2 = graphNodeList.index(node1) graphAdjMatrix[tempIndex1][tempIndex2] = 1 def createDigraphAdjMatrixWeight(node0, node1, weight): tempIndex1 = graphNodeList.index(node0) tempIndex2 = graphNodeList.index(node1) graphAdjMatrix[tempIndex1][tempIndex2] = weight def printAdjMatrix(nodeNum): for row in range(nodeNum): for column in range(nodeNum): print(graphAdjMatrix[row][column], end=" ") print("") def maindigAdjMatWeight(): acquireNode() initGraphAdjMatrixWeight(len(graphNodeList)) acquireSideDigWeight() """ def init_dist(nodenum=len(graphNodeList)): for row in range(nodenum): dist.append([]) for column in range(nodeum): dist[row].append(graphAdjMatrix[row][column]) """ def init_dist(nodenum): print(len(graphNodeList)) for row in range(nodenum): dist.append([]) for column in range(nodenum): dist[row].append(graphAdjMatrix[row][column]) def init_p(nodenum): for row in range(nodenum): p.append([]) for column in range(nodenum): if 0 < graphAdjMatrix[row][column] < MAX: p[row].append(row) else: p[row].append(-1) def mainfloyd(): init_dist(len(graphNodeList)) init_p(len(graphNodeList)) #遍历所有顶点,每个顶点都对dist进行全部遍历 for index, nodeinfo in enumerate(graphNodeList): for row in range(len(graphNodeList)): for column in range(len(graphNodeList)): #每一个点对dist中每一个值进行插点操作,符合条件的更新对应dist和p列表的值;循环完毕即求解到所有最短路径 if dist[row][column] > dist[row][index] + dist[index][column]: dist[row][column] = dist[row][index] + dist[index][column] p[row][column] = p[index][column] def printpath(): #遍历所有顶点组合,对符合条件的两个顶点打印最短路径值和最短路径经过的点 for row in range(len(graphNodeList)): for column in range(len(graphNodeList)): #判断两个顶点之间存在路径可达 if 0 < dist[row][column] < MAX: #输出两个点之间最短路径的权重 print("顶点%s到顶点%s之间最短路径长度为%d。" % (graphNodeList[row], graphNodeList[column], dist[row][column])) #输出最短路径经过的点 print("顶点%s到顶点%s之间最短路径经过的点为:" % (graphNodeList[row], graphNodeList[column]), end=" ") temppathlist = [] temppathlist.append(graphNodeList[row]) temppathlist.append(graphNodeList[column]) while p[row][column] != row: temppathlist.insert(1, p[row][column]) column = p[row][column] for pathnode in temppathlist: print(pathnode, end=" ") print() if __name__ == "__main__": maindigAdjMatWeight() mainfloyd() printpath()
浙公网安备 33010602011771号