找出三个数相加小于给定阈值的总数量

4

我正在解决一些练习问题,但降低复杂度的过程中遇到了麻烦。给定一个不同整数数组a[]和一个阈值T,需要找到三个数i,j,k,满足条件:a[i] < a[j] < a[k]a[i] + a[j] + a[k] <= T。我已经用下面这个Python脚本将复杂度从O(n^3)降至O(n^2 log n),现在想知道能否进一步优化。

import sys
import bisect

first_line = sys.stdin.readline().strip().split(' ')
num_numbers = int(first_line[0])
threshold = int(first_line[1])
count = 0

if num_numbers < 3:
    print count
else:
    numbers = sys.stdin.readline().strip().split(' ')
    numbers = map(int, numbers)
    numbers.sort()

    for i in xrange(num_numbers - 2):
        for j in xrange(i+1, num_numbers - 1):
            k_1 = threshold - (numbers[i] + numbers[j])
            if k_1 < numbers[j]:
                break
            else:
                cross_thresh = bisect.bisect(numbers,k_1) - (j+1)
                if cross_thresh > 0:
                    count += cross_thresh

    print count

在上面的例子中,第一行输入仅提供数字数量和阈值。下一行是完整的列表。如果列表少于3个,则不存在三元组,因此我们返回0。否则,我们读取整个整数列表,对其进行排序,然后按如下方式处理它们:我们遍历每个元素和(使i < j),并计算不会破坏的最高值。然后我们找到第一个违反此条件的列表中的元素的索引,并将所有介于j和s之间的元素添加到计数中。对于包含30,000个元素的列表,这需要大约7分钟才能运行。有没有办法让它更快?

1
不必每次在整个范围上执行二分查找,因为你知道对于每个 j > lastJ 的迭代,你只需要在范围 [0, last_cross_thresh] 中进行搜索。虽然我相当确定这不会在实践中改善渐近复杂度或运行时间。 - j_random_hacker
1
此外,您可以通过在列表中查找两个最小元素(在O(n)时间内)来预先过滤列表,然后消除任何使得两个最小元素的和与该元素的和超过T的元素(也是一个O(n)扫描)。即使您的整体算法是O(n^3)或其他复杂度,这将为您提供一个较小的工作集...预先对列表进行排序(O(n lg n))可能会使其更简单... - twalberg
或者,您可以通过预计算一个大小为(T + 1)的数组来将log(n)因子转换为加法T项(使其成为伪多项式算法),该数组记录每个0 <= i <= T的元素数。这也需要O(T)空间。 - j_random_hacker
@twalberg的想法稍微加强一下就可以轻松地实现:你只需要在if cross_thresh > 0:的末尾添加一个else <break out of innermost loop>子句即可。 - j_random_hacker
2个回答

3
你正在为每个(i,j)对执行二分查找,以找到相应的k值。因此时间复杂度为O(n^2 log(n))。
我可以建议一个算法,其最坏情况时间复杂度为O(n^2)。
假设列表从左到右排序,并且元素从1到n编号。然后伪代码如下:
for i = 1 to n - 2:
    j = i + 1
    find maximal k with binary search
    while j < k:
        j = j + 1
        find maximal k with linear search to the left, starting from last k position

这个算法的最坏情况时间复杂度为O(n^2)而不是O(n^3)的原因是因为位置k单调递减。因此,即使进行线性扫描,您也不会为每个(i,j)对花费O(n)的时间。相反,您将花费O(n)的总时间来扫描每个不同的i值以查找k


3

O(n^2)版本使用Python实现(基于wookie919的答案):

def triplets(N, T):
    N = sorted(N)
    result = 0

    for i in xrange(len(N)-2):
        k = len(N)-1
        for j in xrange(i+1, len(N)-1):
            while k>=0 and N[i]+N[j]+N[k]>T:
                k-=1
            result += max(k, j)-j

    return result

import random
sample = random.sample(xrange(1000000), 30000)
print triplets(sample, 500000)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接