两个列表的最大路径和

8

我的问题是关于 Codewars 上这个 kata:this。该函数接受两个已排序且元素不重复的列表作为参数。这些列表可能有共同的项,也可能没有。任务是找到最大路径和。在寻找和的过程中,如果有任何相同的项,您可以选择将您的路径更改为另一个列表。

给定的示例如下:

list1 = [0, 2, 3, 7, 10, 12]
list2 = [1, 5, 7, 8]
0->2->3->7->10->12 => 34
0->2->3->7->8      => 20
1->5->7->8         => 21
1->5->7->10->12    => 35 (maximum path)

我完成了这道题目,但我的代码未能达到性能要求,因此我遇到了执行超时的问题。我该怎么办?
这是我的解决方案:
def max_sum_path(l1:list, l2:list):
    common_items = list(set(l1).intersection(l2))
    if not common_items:
        return max(sum(l1), sum(l2))
    common_items.sort()
    s = 0
    new_start1 = 0
    new_start2 = 0
    s1 = 0
    s2 = 0
    for item in common_items:
        s1 = sum(itertools.islice(l1, new_start1, l1.index(item)))
        s2 = sum(itertools.islice(l2, new_start2, l2.index(item)))
        new_start1 = l1.index(item)
        new_start2 = l2.index(item)
        s += max(s1, s2)
    s1 = sum(itertools.islice(l1, new_start1, len(l1)))
    s2 = sum(itertools.islice(l2, new_start2, len(l2)))
    s += max(s1, s2)
    return s

我会首先通过几次调用 time.time 来找出哪一部分占用了大部分时间。 - TheEagle
类似这样的:start = time.time(); #执行需要很长时间的任务; print(time.time() - start) - 你将会得到两次 time.time 调用之间的毫秒数。 - TheEagle
当列表长度为619和1352时,第二段代码的for循环花费了0.01697850227355957毫秒。当我尝试对大型列表使用它们时,我注意到第一个返回错误的总和。但第二个可以正常工作。但我仍然不知道如何减少长输入的时间。 - Ümit Kara
0 <= len(l1), len(l2) <= 70000。可能是因为输入非常大,但有些人已经完成了。 - Ümit Kara
我在讨论区找到了这个:你的解决方案的最大时间复杂度可以是O(len(l1) + len(l2))。不过我不知道如何获取你代码的时间复杂度...需要有其他人来做。 - TheEagle
显示剩余4条评论
5个回答

5

你的算法实际上很快,只是实现得比较慢。

导致它总共需要O(n²)的时间的两个因素是:

  • l1.index(item) 总是从列表的开头开始搜索。应该是 l1.index(item, new_start1)
  • itertools.islice(l1, new_start1, ...) 创建了一个l1的迭代器,并且在到达所需元素之前遍历了前面的new_start1个元素。所以应该使用普通的列表切片。

然后排序只需要O(n log n),而其他部分只需要O(n)。此外,排序的O(n log n)速度很快,可能甚至比允许的输入及更大的输入中的O(n)部分花费更少的时间。

以下是重写版本,接受时间约为6秒,与其他答案中的解决方案一样。

def max_sum_path(l1:list, l2:list):
    common_items = list(set(l1).intersection(l2))
    if not common_items:
        return max(sum(l1), sum(l2))
    common_items.sort()
    s = 0
    new_start1 = 0
    new_start2 = 0
    s1 = 0
    s2 = 0
    for item in common_items:
        next_start1 = l1.index(item, new_start1)  # changed
        next_start2 = l2.index(item, new_start2)  # changed
        s1 = sum(l1[new_start1 : next_start1])    # changed
        s2 = sum(l2[new_start2 : next_start2])    # changed
        new_start1 = next_start1                  # changed
        new_start2 = next_start2                  # changed
        s += max(s1, s2)
    s1 = sum(l1[new_start1:])                     # changed
    s2 = sum(l2[new_start2:])                     # changed
    s += max(s1, s2)
    return s

你可以使用迭代器而不是索引。下面是重新编写的解决方案,使用迭代器,大约在6秒内也被接受:

def max_sum_path(l1:list, l2:list):
    common_items = sorted(set(l1) & set(l2))
    s = 0
    it1 = iter(l1)
    it2 = iter(l2)
    for item in common_items:
        s1 = sum(iter(it1.__next__, item))
        s2 = sum(iter(it2.__next__, item))
        s += max(s1, s2) + item
    s1 = sum(it1)
    s2 = sum(it2)
    s += max(s1, s2)
    return s

我会将最后四行代码合并为一行,保持原样以便于比较。


4

这可以在单次遍历中以O(n)的时间复杂度和O(1)的空间复杂度完成。你只需要两个指针并行遍历两个数组并使用两个路径值。

你要增加指向较小元素的指针,并将其值添加到它的路径中。当你找到一个公共元素时,你会将其添加到两个路径中,然后将两个路径设置为最大值。

def max_sum_path(l1, l2):
    path1 = 0
    path2 = 0
    i = 0
    j = 0
    while i < len(l1) and j < len(l2):
        if l1[i] < l2[j]:
            path1 += l1[i]
            i += 1
        elif l2[j] < l1[i]:
            path2 += l2[j]
            j += 1
        else:
            # Same element in both paths
            path1 += l1[i]
            path2 += l1[i]
            path1 = max(path1, path2)
            path2 = path1
            i += 1
            j += 1
    while i < len(l1):
        path1 += l1[i]
        i += 1
    while j < len(l2):
        path2 += l2[j]
        j += 1
    return max(path1, path2)

4
该问题要求“达到线性时间复杂度”,这是一个非常明显的提示,说明像嵌套循环这样的东西行不通(在这里,索引是嵌套的O(n)循环,而当输入列表之间有许多重复值时,sort()是O(n log(n)))。 此答案展示了如何缓存重复的.index调用,并使用来自最后一块的起始偏移量将复杂度降低。
正如链接的答案所述,itertools.islice在这里不适用,因为它从列表的开头遍历。相反,使用本地切片。这个与上面对index的修改结合起来,总体上给出了线性对数复杂度,对于大多数输入是线性的。

为了提供背景信息,这是我的方法,与你的方法并没有太大区别,尽管我缓存索引并避免排序。

我首先将问题描述为一个有向无环图,想法是搜索最大路径和:

       +---> [0, 2, 3] ---+            +---> [10, 12]
[0] ---|                  |---> [7] ---|
       +---> [1, 5] ------+            +---> [8]

我们可以把每个节点的值加起来,以便更清楚地了解:
     +---> 5 ---+          +---> 22
0 ---|          |---> 7 ---|
     +---> 6 ---+          +---> 8

上面的图表显示,在唯一性约束条件下,贪心解决方案将是最优的。例如,从根节点开始,我们只能选择 5 或 6 的路径才能到达 7。其中较大的 6 是保证在最大权重路径中的一部分,因此我们选择它。
现在,问题只是如何实现这个逻辑。回到列表,以下是一个更加实质性的输入,包含格式和注释,以帮助激发一种方法:
[1, 2, 4, 7, 8,    10,         14, 15    ]
[      4,    8, 9,     11, 12,     15, 90]
       ^     ^                      ^
       |     |                      |

这说明了链接索引如何排列。我们的目标是在链接之间迭代每个块,取两个子列表和中的较大值:
[1, 2, 4, 7, 8,    10,         14, 15    ]
[      4,    8, 9,     11, 12,     15, 90]
 ^~~^     ^     ^~~~~~~~~~~~~~~~^      ^^
  0       1             2               3  <-- chunk number

以上输入的预期结果应为 3 + 4 + 7 + 8 + 32 + 15 + 90 = 159,取所有链接值加上顶部列表子列表和块0和1以及底部列表的块2和3的总和。
以下是一个相当冗长但希望易于理解的实现; 您可以访问线程以查看更优雅的解决方案:
def max_sum_path(a, b):
    b_idxes = {k: i for i, k in enumerate(b)}
    link_to_a = {}
    link_to_b = {}
    
    for i, e in enumerate(a):
        if e in b_idxes:
            link_to_a[e] = i
            link_to_b[e] = b_idxes[e]
    
    total = 0
    start_a = 0
    start_b = 0
    
    for link in link_to_a: # dicts assumed sorted, Python 3.6+
        end_a = link_to_a[link]
        end_b = link_to_b[link]
        total += max(sum(a[start_a:end_a]), sum(b[start_b:end_b])) + link
        start_a = end_a + 1
        start_b = end_b + 1
        
    return total + max(sum(a[start_a:]), sum(b[start_b:]))

1
当然,复杂度和Codewars解决方案实际测试的内容是有区别的。再次强调,如果看起来很复杂但CW测试套件并没有真正将代码推向像我提到的那样的边缘情况,那么我们就算自己很幸运。但如果超时并且没有任何信息,最安全的做法是假设这些情况正在被测试,并尝试确保它确实是线性的。 再次查看OP的代码,它并不像我想象的那么糟糕(例如“sums / slices”是可以接受的),因此我会修改我的帖子。谢谢。 - ggorlen
1
他们的复杂性之所以“不是问题”,是因为CW测试的设计方式——实际上是运气好,他们并没有真正推动代码。我可以轻松地设计出大型测试,其中包含高密度的重复项,以确保真正的线性解决方案。 - ggorlen
1
好的,我似乎无法引起任何性能问题,即使输入大量重复项以尝试获得最坏情况。基本上,当我看到超时时,我没有烦恼OP的代码,但事实证明缓存.index(),使用偏移量.index()和使用本地切片基本上提供了合理的解决方案。赞扬你没有把孩子和洗澡水一起扔掉 - 我高估了排序的影响。答案已更新以涵盖所有这些内容。 - ggorlen
1
我想知道它们都是不同的数字是否会使排序运行更快。此外,这是本地代码,所以不应该是一个大问题。有时线性对数法可能会成为真正的瓶颈,但显然不是在这里。 - ggorlen
1
我正在试图以分支预测和Timsort的实现细节的角度来思考——其中一件事是,如果你在提交时使用print(common_items),你会看到它经常是排序好的,或者接近排序好的。一般来说,当接近排序好时(例如插入排序),排序是线性的。来自Wikipedia的描述是:“...它被设计成在切片接近排序好或由两个或多个排序序列连接在一起的情况下非常快。”因此,我认为我关于避免排序仍然具有理论上的优势。 - ggorlen
显示剩余10条评论

3

一旦您知道两个列表共享的项目,就可以分别迭代每个列表以总结共享项目之间的项目,从而构建部分总和列表。这些列表对于输入列表都具有相同的长度,因为共享项目的数量是相同的。

然后,可以通过在共享值之间的每个伸展区间中取两个列表的最大值来找到最大路径和:

def max_sum_path(l1, l2):
    shared_items = set(l1) & set(l2)
    
    def partial_sums(lst):
        result = []
        partial_sum = 0
        for item in lst:
            partial_sum += item
            if item in shared_items:
                result.append(partial_sum)
                partial_sum = 0
        result.append(partial_sum)
        return result
            
    return sum(map(max, partial_sums(l1), 
                        partial_sums(l2)))

时间复杂度:我们只需一次迭代处理每个列表(针对部分和较短的列表的迭代在这里是无关紧要的),因此此代码的时间复杂度与输入列表的长度成线性关系。然而,正如您和Kelly Bundy所指出的, 您自己的算法实际上也具有相同的时间复杂度,除了排序共同项部分,但该部分似乎对于给定的测试用例并不太相关。

因此,总的结论是,如果你的目标仅是使你的代码足够快以通过某些测试用例,那么更好的方法是对实际实现进行分析以找到时间瓶颈,而不必担心理论最坏情况。


1
也可以返回 sum(map(max, partial_sums1, partial_sums2)) - Kelly Bundy
1
我认为排序不是问题,因为它只需执行一次,而且其隐藏的常数相当快。使用或不使用sorted(shared_items),该过程在该网站上大约需要5.8秒(点击ATTEMPT按钮)。 - Kelly Bundy
2
我的解决问题的方法是这样的: 找到共同的项目。如果没有,则返回l1,l2的最大和。 如果有任何一个,则一直进行并求和,直到它们。 在共同的项目中选择最大值并添加到总和中。 完成后选择部分总和的最大值。 我认为问题在于“sum(itertools.islice(l1, new_start1, l1.index(item)))”。 相反,我应该通过线性迭代并为每个列表添加来完成它。 - Ümit Kara
感谢@KellyBundy指出了OP代码中的真正问题。我已经点赞并编辑了自己的回答以提及您的观点,并按照您的建议更改了返回语句。 - Arne
感谢@ÜmitKara提供的额外信息。这帮助我编辑了我的答案。 - Arne
显示剩余2条评论

3

基准测试

Discourse选项卡上,您可以单击“显示Kata测试用例”(一旦您解决了Kata),以查看它们的测试用例生成器。我使用这个来基准测试迄今为止发布的解决方案以及我自己的解决方案。进行了数十轮测试,因为测试用例非常随机,导致运行时间波动很大。在每一轮中,所有生成的测试用例都会提供给所有的解决方案(因此,在每一轮中,所有的解决方案都得到了相同的测试用例)。

Codewars的测试用例

同时还有Kelly Bundy对于排序公共值集合的最坏情况:

输入图像说明

代码如下。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接