在Python中实现Dijkstra算法

26

我正在尝试使用数组在Python中实现Dijkstra算法。这是我的实现。

def extract(Q, w):
    m=0
    minimum=w[0]
    for i in range(len(w)):
        if w[i]<minimum:
            minimum=w[i]
            m=i
    return m, Q[m]

def dijkstra(G, s, t='B'):
    Q=[s]
    p={s:None}
    w=[0]
    d={}
    for i in G:
        d[i]=float('inf')
        Q.append(i)
        w.append(d[i])
    d[s]=0
    S=[]
    n=len(Q)
    while Q:
        u=extract(Q,w)[1]
        S.append(u)
        #w.remove(extract(Q, d, w)[0])
        Q.remove(u)
        for v in G[u]:
            if d[v]>=d[u]+G[u][v]:
                d[v]=d[u]+G[u][v]
                p[v]=u
    return d, p

B='B'
A='A'
D='D'
G='G'
E='E'
C='C'
F='F'
G={B:{A:5, D:1, G:2}, A:{B:5, D:3, E:12, F:5}, D:{B:1, G:1, E:1, A:3}, G:{B:2, D:1, C:2}, C:{G:2, E:1, F:16}, E:{A:12, D:1, C:1, F:2}, F:{A:5, E:2, C:16}}
print "Assuming the start vertex to be B:"
print "Shortest distances", dijkstra(G, B)[0]
print "Parents", dijkstra(G, B)[1]

我期望你的答案是:

Assuming the start vertex to be B:
Shortest distances {'A': 4, 'C': 4, 'B': 0, 'E': 2, 'D': 1, 'G': 2, 'F': 4}
Parents {'A': 'D', 'C': 'G', 'B': None, 'E': 'D', 'D': 'B', 'G': 'D', 'F': 'E'}

然而,我得到的答案是:
Assuming the start vertex to be B:
Shortest distances {'A': 4, 'C': 4, 'B': 0, 'E': 2, 'D': 1, 'G': 2, 'F': 10}
Parents {'A': 'D', 'C': 'G', 'B': None, 'E': 'D', 'D': 'B', 'G': 'D', 'F': 'A'}.

对于节点 F,程序给出了错误的答案。请问有人能告诉我原因吗?


21
请使用有意义的变量名称,这将帮助人们更好地理解你的代码。 - shaktimaan
3
你的代码非常混乱:有两个不同名字的变量G,还有一个没用到的变量S等等。我猜你的代码只能找到不超过两条边的路径,因为你从未向队列中添加任何东西(在Dijkstra算法中应该这样做),但由于代码难以阅读,我无法确定。 - Natalya Ginzburg
13个回答

43

正如其他人所指出的那样,由于没有使用易懂的变量名,因此几乎不可能调试您的代码。

按照维基百科关于Dijkstra算法的文章,可以沿着这些线路(以及其他无数种方式)实现它:

nodes = ('A', 'B', 'C', 'D', 'E', 'F', 'G')
distances = {
    'B': {'A': 5, 'D': 1, 'G': 2},
    'A': {'B': 5, 'D': 3, 'E': 12, 'F' :5},
    'D': {'B': 1, 'G': 1, 'E': 1, 'A': 3},
    'G': {'B': 2, 'D': 1, 'C': 2},
    'C': {'G': 2, 'E': 1, 'F': 16},
    'E': {'A': 12, 'D': 1, 'C': 1, 'F': 2},
    'F': {'A': 5, 'E': 2, 'C': 16}}

unvisited = {node: None for node in nodes} #using None as +inf
visited = {}
current = 'B'
currentDistance = 0
unvisited[current] = currentDistance

while True:
    for neighbour, distance in distances[current].items():
        if neighbour not in unvisited: continue
        newDistance = currentDistance + distance
        if unvisited[neighbour] is None or unvisited[neighbour] > newDistance:
            unvisited[neighbour] = newDistance
    visited[current] = currentDistance
    del unvisited[current]
    if not unvisited: break
    candidates = [node for node in unvisited.items() if node[1]]
    current, currentDistance = sorted(candidates, key = lambda x: x[1])[0]

print(visited)

这段代码比必要的更冗长,我希望通过将您的代码与我的进行比较,您可以发现一些差异。

结果是:

{'E': 2, 'D': 1, 'G': 2, 'F': 4, 'A': 4, 'C': 3, 'B': 0}

2
为什么要使用None作为+inf,而不是使用float('inf')呢? - fferri
4
在这里使用sorted会大大降低时间复杂度,你需要使用一个优先队列。 - trincot
@Hyperboreus 如果路径长度相同,你会如何选择下一个节点? - BrockenDuck
2
如果您需要得到一个列表中的最小项,请不要使用排序(复杂度:O(N * log(N)))。请使用内置的min方法(复杂度:O(N))。在这里,推荐使用: min(candidates, key=lambda x: x[1]) 而不是 sorted(candidates, key = lambda x: x[1])[0] - MarAja
Dijkstra算法能否轻易地被修改以返回最短路径 - jbuddy_13
显示剩余2条评论

20

这个实现仅使用数组和堆数据结构。

import heapq as hq
import math

def dijkstra(G, s):
    n = len(G)
    visited = [False]*n
    weights = [math.inf]*n
    path = [None]*n
    queue = []
    weights[s] = 0
    hq.heappush(queue, (0, s))
    while len(queue) > 0:
        g, u = hq.heappop(queue)
        visited[u] = True
        for v, w in G[u]:
            if not visited[v]:
                f = g + w
                if f < weights[v]:
                    weights[v] = f
                    path[v] = u
                    hq.heappush(queue, (f, v))
    return path, weights

G = [[(1, 6), (3, 7)],
     [(2, 5), (3, 8), (4, -4)],
     [(1, -2), (4, 7)],
     [(2, -3), (4, 9)],
     [(0, 2)]]

print(dijkstra(G, 0))

我希望这能够帮助到某些人,虽然有点晚。


4
非常简洁的解决方案。但是如何将 Dijkstra 算法应用于有负权图? - Jianyu
1
我正在尝试理解这个实现。似乎由hq.heappush(queue, (f, v))产生的冗余副本(因为heappush不会删除具有更高权重的旧v)并不重要,仅仅因为当v再次弹出时,它的所有邻居都已经具有更小的权重,所以额外的副本浪费了一些时间但不会改变结果。这是正确的吗? - Max M
3
Jianyu Dijkstra算法无法处理负值。建议使用Bellman-Ford算法代替。 - Ben J.
gus是什么? - Nic Scozzaro
visited数组是必要的吗?根据优先队列的实现方式,似乎是必要的。我们需要一种方法来更新hq中已存在的条目,而不是为相同的顶点添加新的条目。 - E. Kaufman
显示剩余2条评论

12

我以更冗长的形式写下来,以便初学者读者更容易理解:

def get_parent(pos):
    return (pos + 1) // 2 - 1


def get_children(pos):
    right = (pos + 1) * 2
    left = right - 1
    return left, right


def swap(array, a, b):
    array[a], array[b] = array[b], array[a]


class Heap:

    def __init__(self):
        self._array = []

    def peek(self):
        return self._array[0] if self._array else None

    def _get_smallest_child(self, parent):
        return min([
            it
            for it in get_children(parent)
            if it < len(self._array)
        ], key=lambda it: self._array[it], default=-1)

    def _sift_down(self):
        parent = 0
        smallest = self._get_smallest_child(parent)
        while smallest != -1 and self._array[smallest] < self._array[parent]:
            swap(self._array, smallest, parent)
            parent, smallest = smallest, self._get_smallest_child(smallest)

    def pop(self):
        if not self._array:
            return None
        swap(self._array, 0, len(self._array) - 1)
        node = self._array.pop()
        self._sift_down()
        return node

    def _sift_up(self):
        index = len(self._array) - 1
        parent = get_parent(index)
        while parent != -1 and self._array[index] < self._array[parent]:
            swap(self._array, index, parent)
            index, parent = parent, get_parent(parent)

    def add(self, item):
        self._array.append(item)
        self._sift_up()

    def __bool__(self):
        return bool(self._array)


def backtrack(best_parents, start, end):
    if end not in best_parents:
        return None
    cursor = end
    path = [cursor]
    while cursor in best_parents:
        cursor = best_parents[cursor]
        path.append(cursor)
        if cursor == start:
            return list(reversed(path))
    return None


def dijkstra(weighted_graph, start, end):
    """
    Calculate the shortest path for a directed weighted graph.

    Node can be virtually any hashable datatype.

    :param start: starting node
    :param end: ending node
    :param weighted_graph: {"node1": {"node2": weight, ...}, ...}
    :return: ["START", ... nodes between ..., "END"] or None, if there is no
            path
    """
    distances = {i: float("inf") for i in weighted_graph}
    best_parents = {i: None for i in weighted_graph}

    to_visit = Heap()
    to_visit.add((0, start))
    distances[start] = 0

    visited = set()

    while to_visit:
        src_distance, source = to_visit.pop()
        if src_distance > distances[source]:
            continue
        if source == end:
            break
        visited.add(source)
        for target, distance in weighted_graph[source].items():
            if target in visited:
                continue
            new_dist = distances[source] + weighted_graph[source][target]
            if distances[target] > new_dist:
                distances[target] = new_dist
                best_parents[target] = source
                to_visit.add((new_dist, target))

    return backtrack(best_parents, start, end)

2
你可以用float('inf')替换Infinity()。 - michaelwayman
不需要“visited_nodes”。 - derek
2
Derek,在有向无环图中是正确的。在存在循环的图中,我们需要防止陷入循环。 - Archibald
或者我误解了你?你能提出修改建议吗? - Archibald
2
这并没有回答原问题。而且,为什么不使用 https://docs.python.org/3/library/heapq.html ? - Hugues

10

我需要一个可以返回路径的解决方案,因此我编写了一个简单的类,使用了多个问题/答案中有关Dijkstra的思想:

class Dijkstra:

    def __init__(self, vertices, graph):
        self.vertices = vertices  # ("A", "B", "C" ...)
        self.graph = graph  # {"A": {"B": 1}, "B": {"A": 3, "C": 5} ...}

    def find_route(self, start, end):
        unvisited = {n: float("inf") for n in self.vertices}
        unvisited[start] = 0  # set start vertex to 0
        visited = {}  # list of all visited nodes
        parents = {}  # predecessors
        while unvisited:
            min_vertex = min(unvisited, key=unvisited.get)  # get smallest distance
            for neighbour, _ in self.graph.get(min_vertex, {}).items():
                if neighbour in visited:
                    continue
                new_distance = unvisited[min_vertex] + self.graph[min_vertex].get(neighbour, float("inf"))
                if new_distance < unvisited[neighbour]:
                    unvisited[neighbour] = new_distance
                    parents[neighbour] = min_vertex
            visited[min_vertex] = unvisited[min_vertex]
            unvisited.pop(min_vertex)
            if min_vertex == end:
                break
        return parents, visited

    @staticmethod
    def generate_path(parents, start, end):
        path = [end]
        while True:
            key = parents[path[0]]
            path.insert(0, key)
            if key == start:
                break
        return path

示例图和用法(通过这个巧妙的工具生成绘制): 在此输入图片描述

input_vertices = ("A", "B", "C", "D", "E", "F", "G")
input_graph = {
    "A": {"B": 5, "D": 3, "E": 12, "F": 5},
    "B": {"A": 5, "D": 1, "G": 2},
    "C": {"E": 1, "F": 16, "G": 2},
    "D": {"A": 3, "B": 1, "E": 1, "G": 1},
    "E": {"A": 12, "C": 1, "D": 1, "F": 2},
    "F": {"A": 5, "C": 16, "E": 2},
    "G": {"B": 2, "C": 2, "D": 1}
}
start_vertex = "B"
end_vertex= "C"
dijkstra = Dijkstra(input_vertices, input_graph)
p, v = dijkstra.find_route(start_vertex, end_vertex)
print("Distance from %s to %s is: %.2f" % (start_vertex, end_vertex, v[end_vertex]))
se = dijkstra.generate_path(p, start_vertex, end_vertex)
print("Path from %s to %s is: %s" % (start_vertex, end_vertex, " -> ".join(se)))

输出

Distance from B to C is: 3.00
Path from B to C is: B -> D -> E -> C

这个图表很有用,但不幸的是在线工具选择的颜色和形状并不太好看。 - qwr
您可以编辑所有项或仅选定项的顶点样式(颜色、形状、文本位置、大小)和边缘样式(颜色、线条样式、宽度、文本)(请参见“设置”菜单)。 - Ionut Ticus

8
这不是我的答案——我的教授比我做得更有效率。这是他的方法,显然使用了辅助函数来处理重复的任务。
def dijkstra(graph, source):

    vertices, edges = graph
    dist = dict()
    previous = dict()

    for vertex in vertices:
        dist[vertex] = float("inf")
        previous[vertex] = None

    dist[source] = 0
    Q = set(vertices)

    while len(Q) > 0:
        u = minimum_distance(dist, Q)
        print('Currently considering', u, 'with a distance of', dist[u])
        Q.remove(u)

        if dist[u] == float('inf'):
            break

        n = get_neighbours(graph, u)
        for vertex in n:
            alt = dist[u] + dist_between(graph, u, vertex)
            if alt < dist[vertex]:
                dist[vertex] = alt
                previous[vertex] = u

    return previous

给定一个图

({'A', 'B', 'C', 'D'}, {('A', 'B', 5), ('B', 'A', 5), ('B', 'C', 10), ('B', 'D', 6), ('C', 'D', 2), ('D', 'C', 2)})

运行命令 print(dijkstra(graph, 'A') 可以得到以下结果:

当前考虑到 A,距离为 0

当前考虑到 B,距离为 5

当前考虑到 D,距离为 11

当前考虑到 C,距离为 13

即:

{'C': 'D', 'D': 'B', 'A': None, 'B': 'A'} => 随机顺序


谢谢分享这个,但是你在哪里定义目标呢?我看到它有源。 - grepit
5
函数 minimum_distance 未定义。 - Hugues
1
最难的部分是 minimum_distance,因为 Python 的 heapq 没有内置的降低优先级的方法。 - qwr

7

基于CLRS第2版24.3章实现。

d表示增量,p表示前驱节点。

import heapq

def dijkstra(g, s, t):

    q = []
    d = {k: sys.maxint for k in g.keys()}
    p = {}

    d[s] = 0 
    heapq.heappush(q, (0, s))

    while q:
        last_w, curr_v = heapq.heappop(q)
        for n, n_w in g[curr_v]:
            cand_w = last_w + n_w # equivalent to d[curr_v] + n_w 
            # print d # uncomment to see how deltas are updated
            if cand_w < d[n]:
                d[n] = cand_w
                p[n] = curr_v
                heapq.heappush(q, (cand_w, n))

    print "predecessors: ", p 
    print "delta: ", d 
    return d[t]

def test():

    og = {}
    og["s"] = [("t", 10), ("y", 5)]
    og["t"] = [("y", 2), ("x", 1)]
    og["y"] = [("t", 3), ("x", 9), ("z", 2)]
    og["z"] = [("x", 6), ("s", 7)]
    og["x"] = [("z", 4)]

    assert dijkstra(og, "s", "x") == 9 


if __name__ == "__main__":
    test()

实现假定所有节点都表示为键。例如,如果节点(例如上面的示例中的“x”)未在og中定义为键,则增量d将缺少该键,并且检查if cand_w < d[n]将无法正常工作。

1
这开始是Carlos' answer的小修改,但我最终改变了很多。
该类的构造函数应传递图形。这样可以多次调用get_shortest_path()方法,并在每次调用时使用不同的起始和结束参数,而无需每次都传递图形。
断言展示了如何在循环或非循环图中使用该类。
我个人不需要最短路径的权重,因此我不让get_shortest_path()方法返回它,但如果需要,可以将distances[node]添加到方法的返回语句中。
import heapq


def main():
    showcase_cyclic()
    showcase_acyclic()


def showcase_cyclic():
    cyclic_graph = {
        "a": { "w": 14, "x": 7, "y": 9 },
        "b": { "w": 9, "z": 6 },
        "w": { "a": 14, "b": 9, "y": 2 },
        "x": { "a": 7, "y": 10, "z": 15 },
        "y": { "a": 9, "w": 2, "x": 10, "z": 11 },
        "z": { "b": 6, "x": 15, "y": 11 }
    }

    dijkstra = Dijkstra(cyclic_graph)

    assert dijkstra.get_shortest_path("a", "a") == ['a']
    assert dijkstra.get_shortest_path("a", "b") == ['a', 'y', 'w', 'b']


def showcase_acyclic():
    acyclic_graph = {
        "a": { "b": 1, "c": 2 },
        "b": { "d": 3 },
        "c": { "d": 4 },
        "d": { "e": 5 },
        "e": {}
    }

    dijkstra = Dijkstra(acyclic_graph)

    assert dijkstra.get_shortest_path("a", "a") == ['a']
    assert dijkstra.get_shortest_path("a", "e") == ['a', 'b', 'd', 'e']


class Dijkstra:
    def __init__(self, graph):
        self.graph = graph


    def get_shortest_path(self, start, end):
        distances = { key: float("inf") for key in self.graph.keys() }
        distances[start] = 0

        unvisited = { key for key in self.graph.keys() }

        parents = {}

        node = start

        while node != end:
            unvisited.remove(node)

            for neighbor in self.graph[node].keys():
                if neighbor in unvisited:
                    tentative_distance = distances[node] + self.graph[node][neighbor]
                    recorded_distance = distances[neighbor]

                    if tentative_distance < recorded_distance:
                        distances[neighbor] = tentative_distance
                        parents[neighbor] = node

            node = self.get_closest_node(unvisited, distances)

        return self.get_path(end, parents)


    def get_closest_node(self, unvisited, distances):
        unvisited_min_heap = [(distances[node], node) for node in unvisited]

        # Great explanation of heapify: https://dev59.com/Vmcs5IYBdhLWcg3wjkwY#61446534
        heapq.heapify(unvisited_min_heap)

        _, closest_node = heapq.heappop(unvisited_min_heap)

        return closest_node


    def get_path(self, end, parents):
        path = []

        while True:
            path.append(end)

            if end not in parents:
                return list(reversed(path))

            end = parents[end]


if __name__ == "__main__":
    main()

1
使用heapq模块来实现优先队列,您可以这样做:
from collections import defaultdict
import heapq

graph = {
    'A': [('B', 2), ('C', 1)],
    'B': [('A', 2), ('C', 4), ('D', 3)],
    'C': [('A', 1), ('B', 4), ('E', 2)],
    'E': [('C', 2), ('D', 1), ('F', 4)],
    'D': [('B', 3), ('E', 1), ('F', 2)],
    'F': [('D', 2), ('E', 4)]

}


def dijkstra(graph, start: str):
    result_map = defaultdict(lambda: float('inf'))
    result_map[start] = 0

    visited = set()

    queue = [(0, start)]

    while queue:
        weight, v = heapq.heappop(queue)
        visited.add(v)

        for u, w in graph[v]:
            if u not in visited:
                result_map[u] = min(w + weight, result_map[u])
                heapq.heappush(queue, [w + weight, u])

    return result_map


print(dijkstra(graph, 'A'))

输出:

{'A': 0, 'B': 2, 'C': 1, 'E': 3, 'D': 4, 'F': 6}

1
我在我的博客rebrained.com上将维基百科的描述分解为以下伪代码:
初始状态:
1. 给节点两个属性 - node.visited和node.distance 2. 对于除了起始节点以外的所有节点,将node.distance设置为无穷大 3. 对于所有节点,将node.visited设置为false 4. 将当前节点设置为起始节点。
当前节点循环:
1. 如果当前节点=结束节点,则完成并返回current.distance和路径 2. 对于所有未访问的邻居,计算它们的暂定距离(current.distance + 边缘到邻居的距离)。 3. 如果暂定距离小于邻居的设定距离,则覆盖它。 4. 设置current.isvisited = true。 5. 将当前节点设置为剩余未访问节点中具有最小node.distance的节点。

http://rebrained.com/?p=392

import sys
def shortestpath(graph,start,end,visited=[],distances={},predecessors={}):
    """Find the shortest path btw start & end nodes in a graph"""
    # detect if first time through, set current distance to zero
    if not visited: distances[start]=0
    # if we've found our end node, find the path to it, and return
    if start==end:
        path=[]
        while end != None:
            path.append(end)
            end=predecessors.get(end,None)
        return distances[start], path[::-1]
    # process neighbors as per algorithm, keep track of predecessors
    for neighbor in graph[start]:
        if neighbor not in visited:
            neighbordist = distances.get(neighbor,sys.maxint)
            tentativedist = distances[start] + graph[start][neighbor]
            if tentativedist < neighbordist:
                distances[neighbor] = tentativedist
                predecessors[neighbor]=start
    # neighbors processed, now mark the current node as visited 
    visited.append(start)
    # finds the closest unvisited node to the start 
    unvisiteds = dict((k, distances.get(k,sys.maxint)) for k in graph if k not in visited)
    closestnode = min(unvisiteds, key=unvisiteds.get)
    # now take the closest node and recurse, making it current 
    return shortestpath(graph,closestnode,end,visited,distances,predecessors)
if __name__ == "__main__":
    graph = {'a': {'w': 14, 'x': 7, 'y': 9},
            'b': {'w': 9, 'z': 6},
            'w': {'a': 14, 'b': 9, 'y': 2},
            'x': {'a': 7, 'y': 10, 'z': 15},
            'y': {'a': 9, 'w': 2, 'x': 10, 'z': 11},
            'z': {'b': 6, 'x': 15, 'y': 11}}
    print shortestpath(graph,'a','a')
    print shortestpath(graph,'a','b')
    """
    Expected Result:
        (0, ['a']) 
        (20, ['a', 'y', 'w', 'b'])
        """

0
import sys
import heapq

class Node:

     def __init__(self, name):
        self.name = name
        self.visited = False
        self.adjacenciesList = []
        self.predecessor = None
        self.mindistance = sys.maxsize    

    def __lt__(self, other):
        return self.mindistance < other.mindistance

class Edge:

    def __init__(self, weight, startvertex, endvertex):
        self.weight = weight
        self.startvertex = startvertex
        self.endvertex = endvertex

def calculateshortestpath(vertexlist, startvertex):
    q = []
    startvertex.mindistance = 0
    heapq.heappush(q, startvertex)

    while q:
        actualnode = heapq.heappop(q)
        for edge in actualnode.adjacenciesList:
            tempdist = edge.startvertex.mindistance + edge.weight
            if tempdist < edge.endvertex.mindistance:
                edge.endvertex.mindistance = tempdist
                edge.endvertex.predecessor = edge.startvertex
                heapq.heappush(q,edge.endvertex)
def getshortestpath(targetvertex):
    print("The value of it's minimum distance is: ",targetvertex.mindistance)
    node = targetvertex
    while node:
        print(node.name)
        node = node.predecessor

node1 = Node("A");
node2 = Node("B");
node3 = Node("C");
node4 = Node("D");
node5 = Node("E");
node6 = Node("F");
node7 = Node("G");
node8 = Node("H");

edge1 = Edge(5,node1,node2);
edge2 = Edge(8,node1,node8);
edge3 = Edge(9,node1,node5);
edge4 = Edge(15,node2,node4);
edge5 = Edge(12,node2,node3);
edge6 = Edge(4,node2,node8);
edge7 = Edge(7,node8,node3);
edge8 = Edge(6,node8,node6);
edge9 = Edge(5,node5,node8);
edge10 = Edge(4,node5,node6);
edge11 = Edge(20,node5,node7);
edge12 = Edge(1,node6,node3);
edge13 = Edge(13,node6,node7);
edge14 = Edge(3,node3,node4);
edge15 = Edge(11,node3,node7);
edge16 = Edge(9,node4,node7);

node1.adjacenciesList.append(edge1);
node1.adjacenciesList.append(edge2);
node1.adjacenciesList.append(edge3);
node2.adjacenciesList.append(edge4);
node2.adjacenciesList.append(edge5);
node2.adjacenciesList.append(edge6);
node8.adjacenciesList.append(edge7);
node8.adjacenciesList.append(edge8);
node5.adjacenciesList.append(edge9);
node5.adjacenciesList.append(edge10);
node5.adjacenciesList.append(edge11);
node6.adjacenciesList.append(edge12);
node6.adjacenciesList.append(edge13);
node3.adjacenciesList.append(edge14);
node3.adjacenciesList.append(edge15);
node4.adjacenciesList.append(edge16);

vertexlist = (node1,node2,node3,node4,node5,node6,node7,node8)

calculateshortestpath(vertexlist,node1)
getshortestpath(node7)

1
仅提供代码的答案并不是良好的做法。请添加简要说明您的代码如何解决问题。(阅读有关在SO上提问的文档) - Yannis

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接