python | 算法-网络延迟时间-Dijkstra算法应用

写在前面:
我自己用python练习算法与数据结构的典型算法汇总在这里:汇总-算法与数据结构-python版,欢迎翻阅!

1️⃣ 参考链接: https://github.com/algorithmzuo/algorithmbasic2020/blob/master/src/class16/Code06_NetworkDelayTime.java

2️⃣ 所用例子

代码详情

# leetcode 743题, 用这道题来练习Dijkstra算法
# 参考: https://github.com/algorithmzuo/algorithmbasic2020/blob/master/src/class16/Code06_NetworkDelayTime.java
# python中优先级队列的实现参考:https://geek-docs.com/python/python-examples/python-priority-queue.html
from queue import PriorityQueue

class NetworkDelayTime:
    """LeetCode 743 (Network Delay Time), solved two ways with Dijkstra.

    Both methods take the same arguments:
      times -- directed edges, each ``[source, target, delay]``
      n     -- number of nodes in the network
      k     -- the node the signal starts from
    and return the time until every node has received the signal, i.e. the
    largest shortest-path delay from ``k``, or -1 if fewer than ``n`` nodes
    are reachable.
    """

    @staticmethod
    def _build_graph(times):
        """Return an adjacency map {source: [[target, delay], ...]}."""
        nexts = {}
        for source, target, delay in times:
            nexts.setdefault(source, []).append([target, delay])
        return nexts

    # Method 1: ordinary priority queue, skipping nodes that are already settled.
    def net_delay_time1(self, times, n, k):
        """Dijkstra with a plain priority queue and lazy deletion.

        A node may be queued several times with different tentative delays;
        only the first (smallest) pop counts, later stale entries are skipped.
        """
        nexts = self._build_graph(times)
        # queue.PriorityQueue pops the smallest tuple first, so it acts as a
        # min-heap. Entries are (delay, node): ties on delay break by node.
        # See https://docs.python.org/3/library/queue.html#queue.PriorityQueue
        heap = PriorityQueue()
        heap.put((0, k))
        settled = set()  # nodes whose shortest delay is final (O(1) lookups)
        result = 0       # largest shortest delay found so far
        while not heap.empty() and len(settled) < n:
            delay, cur = heap.get()
            if cur in settled:
                continue  # stale entry: a shorter delay was popped earlier
            settled.add(cur)
            result = max(result, delay)
            for target, weight in nexts.get(cur, ()):
                if target not in settled:
                    heap.put((delay + weight, target))
        return -1 if len(settled) < n else result

    # Method 2: "enhanced" (indexed) heap.
    def net_delay_time2(self, times, n, k):
        """Dijkstra driven by the indexed ``Heap`` class.

        Each reachable node is expected to be popped exactly once, with its
        final delay. That is correct only if ``Heap.add`` lowers the delay of a
        node already in the heap and ignores nodes already popped.
        """
        nexts = self._build_graph(times)
        heap = Heap()
        heap.add(k, 0)
        num = 0        # number of nodes popped (settled)
        max_delay = 0  # largest settled delay
        while not heap.empty():
            node, delay = heap.pop()
            num += 1
            max_delay = max(max_delay, delay)
            for next_node, next_delay in nexts.get(node, ()):
                heap.add(next_node, delay + next_delay)
        return -1 if num < n else max_delay

class Heap:
    """Indexed binary min-heap of ``[node, delay]`` records, ordered by delay.

    Besides the heap array it keeps each queued node's position, so a node
    already in the heap can have its delay lowered in place (decrease-key).
    It also records every node that has been popped so it is never re-added.
    """

    def __init__(self):
        self.heap = []     # [node, delay] records in heap order
        self.index = {}    # node -> current position in self.heap
        self.used = set()  # nodes already popped (their delay is final)
        self.size = 0      # number of records currently in the heap

    def empty(self):
        """Return True when no records are left in the heap."""
        return self.size == 0

    def add(self, node, delay):
        """Insert ``node``, lower its delay, or ignore it.

        - a node that was already popped is ignored;
        - a node still in the heap keeps the smaller of its old and new delay;
        - any other node is inserted.
        """
        if node in self.used:
            return
        if node in self.index:
            pos = self.index[node]
            if delay < self.heap[pos][1]:
                self.heap[pos][1] = delay
                # A smaller key can only need to move toward the root.
                self.heap_insert(pos)
            return
        self.index[node] = self.size
        self.heap.append([node, delay])
        self.heap_insert(self.size)
        self.size += 1

    def heap_insert(self, index):
        """Sift the record at ``index`` up while it is smaller than its parent."""
        while index > 0:
            parent = (index - 1) // 2
            if self.heap[index][1] >= self.heap[parent][1]:
                break
            self.swap(index, parent)
            index = parent

    def swap(self, index1, index2):
        """Swap two records and keep ``self.index`` in sync."""
        record1 = self.heap[index1]
        record2 = self.heap[index2]
        self.heap[index1] = record2
        self.heap[index2] = record1
        self.index[record1[0]] = index2
        self.index[record2[0]] = index1

    def pop(self):
        """Remove and return the ``[node, delay]`` record with the smallest delay.

        The popped node is remembered in ``self.used`` so later ``add`` calls
        for it are ignored. Raises IndexError if the heap is empty.
        """
        out = self.heap[0]
        self.size -= 1
        self.swap(0, self.size)
        self.heap.pop()
        del self.index[out[0]]
        self.used.add(out[0])
        self.heapify(0)
        return out

    def heapify(self, index):
        """Sift the record at ``index`` down toward the smaller child."""
        left = index * 2 + 1
        while left < self.size:
            right = left + 1
            smallest = right if right < self.size and \
                self.heap[right][1] < self.heap[left][1] else left
            # Compare against the *smaller* child; stop once the parent is
            # no larger than it.
            if self.heap[smallest][1] >= self.heap[index][1]:
                break
            self.swap(index, smallest)
            index = smallest
            left = index * 2 + 1


# Tests: each case is (times, n, k, expected result).
if __name__ == "__main__":
    cases = [
        ([[2, 1, 1], [2, 3, 1], [3, 4, 1]], 4, 2, 2),
        ([[1, 2, 1]], 2, 1, 1),
        ([[1, 2, 1]], 2, 2, -1),
        # Node 2 is first reached with delay 5, then via node 3 with delay 2;
        # a correct Dijkstra must report the smaller one.
        ([[1, 2, 5], [1, 3, 1], [3, 2, 1]], 3, 1, 2),
    ]
    solution = NetworkDelayTime()
    # Check method 1, then method 2; each line should print True.
    for method in (solution.net_delay_time1, solution.net_delay_time2):
        print(all(method(times, n, k) == expected
                  for times, n, k, expected in cases))
posted @ 2022-10-29 16:55  万国码aaa  阅读(121)  评论(0)  编辑  收藏  举报