Efficiently counting the tuples in a Cartesian product whose sum exceeds a given number

3

I have the following working Python 3 code:

import itertools

loops = 10
results = [4, 2.75, 2.75, 1.5, 1.5, 1.5, 0]
threshold = loops * 2
cartesian_product = itertools.product(results, repeat=loops)

good, bad = 0, 0

for e in cartesian_product:
    if (sum(e) >= threshold):
        good += 1
    else:
        bad += 1

print('Ratio of good vs total is {0:.3f}%'.format(100 * good / (good + bad)))

If I increase loops to a larger number (>15), the program takes too long to finish.

Is there a more efficient method or algorithm for computing the ratio?


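There are a lot of repeated lists in the iteration. Has that been taken into account? - bzimor
Yes, there may be repeated lists. - Panayotis
You could use a list comprehension to get a list of sums, convert it to a numpy array, use numpy where to get an array of the indices above the threshold, and finally use len() to get the number of sums above/below the threshold. (Typing on my phone...) - user7345804
Your code runs too slowly on my computer because of cartesian_product = itertools...; the answers in this post look like they would help. - user7345804
Cross-posted: https://cs.stackexchange.com/q/77321/755, https://dev59.com/EaLia4cB1Zd3GeqPrfgo. Please do not post the same question on multiple sites. Each community should have an honest shot at answering without anybody's time being wasted. - D.W.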
1 Answer

4
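Here is a solution. The idea is to compute all the possible sums you can obtain from the values with n loops, keeping a count of how many ways each distinct sum can be reached, and lumping together every sum that is at or above the threshold.
We can then generate all the possible sums for n+1 loops by adding each value to the previous sums. The hope is that the number of distinct sums does not grow too much, because we keep adding the same values over and over and we merge all sums at or above the threshold into a single entry.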
from collections import Counter

def all_sums(values, threshold, previous_sums=None):
    """
    values must be sorted in ascending order and non-negative
    previous_sums is a Counter of previously obtained possible sums

    Returns a Counter of all possible sums of values and the previous sums
    """
    if not previous_sums:
        previous_sums = Counter({0:1})

    new = Counter()
    for existing_sum, ex_sum_count in sorted(previous_sums.items()):
        for index, val in enumerate(values):
            total = existing_sum + val
            if total < threshold:
                # With the current value, we have found ex_sum_count
                # ways to obtain that total
                new.update({total: ex_sum_count})
            else:
                # We don't need the exact sum, as anything we could
                # later add to it will be over the threshold.
                # We count them under the value = threshold
                # As 'values' is sorted, all subsequent values will also give 
                # a sum over the threshold
                values_left = len(values) - index
                new.update({threshold: values_left * ex_sum_count})
                break
    return new


def count_sums(values, threshold, repeat):
    """
    values must be sorted!

    Recursively calculates the possible sums of 'repeat' values,
    counting together all values over 'threshold'
    """
    if repeat == 1:
        return all_sums(values, threshold, previous_sums=None)
    else:
        return all_sums(values, threshold, previous_sums=count_sums(values, threshold, repeat=repeat-1))
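
The recursion above goes one level deep per loop, so a very large repeat could run into Python's recursion limit. As a minimal sketch of the same counting idea written as a plain loop, with a brute-force cross-check on small inputs: the names count_at_least and brute_force are hypothetical and not part of the answer's code, and the sketch assumes repeat >= 1 and non-negative values.

```python
from collections import Counter
from itertools import product


def count_at_least(values, threshold, repeat):
    """Count tuples of product(values, repeat=repeat) with sum >= threshold.

    Assumes repeat >= 1 and that every value is non-negative, so a partial
    sum that has reached the threshold can never fall back below it.
    """
    below = Counter({0: 1})  # partial sum below threshold -> number of ways
    reached = 0              # partial tuples already at or above threshold
    for _ in range(repeat):
        new_below = Counter()
        # Each tuple already at or above threshold stays there for every value.
        new_reached = reached * len(values)
        for partial, ways in below.items():
            for val in values:
                total = partial + val
                if total >= threshold:
                    new_reached += ways
                else:
                    new_below[total] += ways
        below, reached = new_below, new_reached
    return reached


def brute_force(values, threshold, repeat):
    """Reference count obtained by enumerating the full Cartesian product."""
    return sum(1 for e in product(values, repeat=repeat) if sum(e) >= threshold)


results = [4, 2.75, 2.75, 1.5, 1.5, 1.5, 0]
for loops in range(1, 6):
    # Both counts must agree while the product is still small enough to list.
    assert count_at_least(results, loops * 2, loops) == brute_force(results, loops * 2, loops)
```

This version does not need values to be sorted, because it tests every value instead of breaking out early; that costs a few extra comparisons per partial sum.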

Let's try it on your example:

loops = 10
results = [4, 2.75, 2.75, 1.5, 1.5, 1.5, 0]
threshold = loops * 2

values = sorted(results)

sums = count_sums(values, threshold, repeat=loops)
print(sums)
# Counter({20: 137401794, 19.75: 16737840, 18.25: 14016240, 18.5: 13034520, 19.5: 12904920,
# 17.0: 12349260, 15.75: 8573040, 17.25: 8048160, 15.5: 6509160, 16.75: 6395760, 14.25: 5171040,
# 18.0: 5037480, 14.5: 4461480, 16: 3739980, 18.75: 3283020, 19.25: 3220800, 13.0: 3061800, 
# 14.0: 2069550, 12.75: 1927800, 15.25: 1708560, 13.25: 1574640, 17.5: 1391670, 11.5: 1326780,
# 11.75: 1224720, 14.75: 1182660, 16.5: 1109640, 10.25: 612360, 17.75: 569520, 11.25: 453600, 
# 16.25: 444060, 12.5: 400680, 10.0: 374220, 12: 295365, 13.75: 265104, 10.5: 262440, 19.0: 229950,
# 13.5: 204390, 8.75: 204120, 15.0: 192609, 9.0: 153090, 8.5: 68040, 9.75: 65520, 7.5: 61236, 
# 7.25: 45360, 11.0: 44940, 12.25: 21840, 6.0: 17010, 7.0: 7560, 5.75: 6480, 8.25: 5280, 4.5: 3240,
# 9.5: 2520, 10.75: 720, 4.25: 540, 5.5: 450, 3.0: 405, 6.75: 180, 8: 45, 1.5: 30, 2.75: 20, 4: 10, 0: 1})
number_of_sums = len(results) ** loops
# 282475249
good = sums[threshold]
# 137401794
bad = number_of_sums - good
# 145073455
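
To get back to the ratio the question prints, the two counts can be plugged into the same format string. A small sketch, with the numbers copied from the output above rather than recomputed:

```python
loops = 10
number_of_sums = 7 ** loops  # len(results) ** loops, with 7 values
good = 137401794             # sums[threshold] as reported above

ratio = 100 * good / number_of_sums
print('Ratio of good vs total is {0:.3f}%'.format(ratio))
```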

I timed it: it takes about 9 ms on my rather old machine. And with some other data, 10 values and 20 loops:
loops = 20
results = [4, 2.75, 2.45, 1.5, 1.3, 0.73, 0.12, 1.4, 1.5, 0]
threshold = loops * 2
values = sorted(results)

sums = count_sums(values, threshold, repeat=loops)
number_of_sums = len(results) ** loops
good = sums[threshold]
bad = number_of_sums - good
print(good)
print(bad)
# 5440943363190360728
# 94559056636809639272

which I obtain in less than 12 seconds.
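
One caveat that the answer does not discuss: values such as 2.45 or 0.73 have no exact binary floating-point representation, so float sums that should be equal may end up under slightly different keys, and a sum that lands exactly on the threshold could in principle be misclassified. A hedged sketch of the same counting idea with exact rational arithmetic follows. The name count_at_least_exact is hypothetical, and the sketch assumes repeat >= 1, non-negative values, and that str(v) gives the decimal the caller intended.

```python
from collections import Counter
from fractions import Fraction


def count_at_least_exact(values, threshold, repeat):
    """Count tuples with sum >= threshold using exact Fractions.

    Assumes repeat >= 1, non-negative values, and that str(v) is the
    intended decimal (for example 2.45 means exactly 245/100).
    """
    exact_values = [Fraction(str(v)) for v in values]
    limit = Fraction(str(threshold))
    below = Counter({Fraction(0): 1})  # exact partial sum -> number of ways
    reached = 0                        # tuples already at or above the limit
    for _ in range(repeat):
        new_below = Counter()
        new_reached = reached * len(exact_values)
        for partial, ways in below.items():
            for val in exact_values:
                total = partial + val
                if total >= limit:
                    new_reached += ways
                else:
                    new_below[total] += ways
        below, reached = new_below, new_reached
    return reached


# As decimals, 0.1 + 0.7 is exactly 0.8, so 3 of the 4 pairs reach the threshold.
print(count_at_least_exact([0.1, 0.7], 0.8, 2))
```

Fractions are slower than floats, so this trades speed for exactness. Scaling the values to integers (for example multiplying by 100) would be another way to get the same effect.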

