查找两个节点之间路径数量的更快算法

3

我正在尝试在Python在线评测系统中回答一个问题,但是我的程序超出了时间限制和内存限制。这个问题要求计算从起始节点到结束节点的所有路径数量。完整的问题说明可以在这里查看。

以下是我的代码:

import sys
lines = sys.stdin.read().strip().split('\n')
n = int(lines[0])
dict1 = {}

for i in xrange(1, n+1):
    dict1[i] = []

for i in xrange(1, len(lines) - 1):
    numbers = map(int, lines[i].split())
    num1 = numbers[0]
    num2 = numbers[1]
    dict1[num2].append(num1)

def pathfinder(start, graph, count):
    new = []
    if start == []:
        return count
    for i in start:
        numList = graph[i]
        for j in numList:
            if j == 1:
                count += 1
            else:
                new.append(j)

    return pathfinder(new, graph, count)   

print pathfinder([n], dict1, 0)

代码的功能是从最后一个节点开始,通过探索所有相邻节点向上工作到顶部。我基本上制作了一种广度优先搜索算法,但它占用太多空间和时间。我该如何改进这个代码使其更有效率?我的方法是否错误,应该如何修复?
3个回答

3

由于该图是无环的,我们可以立即看到拓扑排序是 1, 2, ..., n。因此,我们可以使用动态规划来解决 最长路径问题。在列表 paths 中,元素 paths[i] 存储从 1i 有多少条路径。更新很简单——对于每条边 (i,j),其中 i 是来自我们的拓扑顺序,我们执行 paths[j] += path[i]

from collections import defaultdict

graph = defaultdict(list)
n = int(input())
while True:
    tokens = input().split()
    a, b = int(tokens[0]), int(tokens[1])
    if a == 0:
        break
    graph[a].append(b)

paths = [0] * (n+1)
paths[1] = 1
for i in range(1, n+1):
    for j in graph[i]:
        paths[j] += paths[i]
print(paths[n])

请注意,您正在实现的实际上不是BFS算法,因为您没有标记已访问的顶点,导致您的起始位置start会变得不成比例。
测试该图。
for i in range(1, n+1):
    dict1[i] = list(range(i-1, 0, -1))

如果你打印出start的大小,你会发现对于给定的n,它所得到的最大值恰好随着二项式系数(n, floor(n/2))增长,即约为4^n/sqrt(n)。还要注意的是,BFS并不是你想要的方法,因为它无法计算路径的数量。

代码可以运行,但我想确认一下我的理解是否正确。对于每个节点,计数等于它被父节点访问的次数,以及被其所有父节点访问的次数,一直到1?这需要O(顶点+边)的时间来完成吗? - Bob Marshall
@BobMarshall 确切地说! - sve

1
import sys
from collections import defaultdict

def build_matrix(filename, x):
    # A[i] stores number of paths from node x to node i.

    # O(n) to build parents_of_node
    parents_of_node = defaultdict(list)
    with open(filename) as infile:
        num_nodes = int(infile.readline())
        A = [0] * (num_nodes + 1)  # A[0] is dummy variable. Not used.
        for line in infile:
            if line == "0 0":
                break

            u, v = map(int, line.strip().split())
            parents_of_node[v].append(u)

            # Initialize all direct descendants of x to 1
            if u == x:
                A[v] = 1

    # Number of paths from x to i = sum(number of paths from x to parent of i)
    for i in xrange(1, num_nodes + 1):  # O(n)
        A[i] += sum(A[p] for p in parents_of_node[i])  # O(max fan-in of graph), assuming O(1) for accessing dict.

    # Total time complexity to build A is O(n * (max_fan-in of graph))
    return A


def main():
    filename = sys.argv[1]

    x = 1  # Find number of paths from x
    y = 4  # to y

    A = build_matrix(filename, x)
    print(A[y])

0

SO倾向于提供完整的答案,这些答案可以独立于外部资源进行评估。 - Brian Cain

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接