高效计算组合和排列

40

我有一些代码用于计算排列和组合,现在我尝试让它可以处理更大的数字。

我已经找到了一个更好的排列算法,可以避免大量中间结果,但对于组合我仍然认为还有改进空间。

目前,我已经使用特殊情况来反映nCr的对称性,但我仍然希望找到一个更好的算法来避免调用factorial(r),因为这会产生一个不必要的大中间结果。如果没有此优化,最后一次测试需要花费很长时间才能计算出factorial(99000)。

是否有人可以建议更有效的方法来计算组合?

from math import factorial

def product(iterable):
    prod = 1
    for n in iterable:
        prod *= n
    return prod

def npr(n, r):
    """
    Calculate the number of ordered permutations of r items taken from a
    population of size n.

    >>> npr(3, 2)
    6
    >>> npr(100, 20)
    1303995018204712451095685346159820800000
    """
    assert 0 <= r <= n
    return product(range(n - r + 1, n + 1))

def ncr(n, r):
    """
    Calculate the number of unordered combinations of r items taken from a
    population of size n.

    >>> ncr(3, 2)
    3
    >>> ncr(100, 20)
    535983370403809682970
    >>> ncr(100000, 1000) == ncr(100000, 99000)
    True
    """
    assert 0 <= r <= n
    if r > n // 2:
        r = n - r
    return npr(n, r) // factorial(r)

这个问题很久以前就被问过了,但是无论如何...我设计了一个算法,可以计算C(n, m) = n! / (m! (n-m)!),只要结果适合于整数(这可以很容易地成为长整数)。我用Java写的,但是将其翻译成Python或任何其他过程性语言应该很容易:https://dev59.com/T1UL5IYBdhLWcg3wfn_S#50439854(查找`combinations(int n, int m)`) - Walter Tross
14个回答

27

如果 n 与 r 不相差太远,则使用组合的递归定义可能更好,因为 xC0 == 1,所以您只需要进行几次迭代:

适用于此处的递归定义是:

nCr = (n-1)C(r-1) * n/r

可以使用以下列表和尾递归很好地计算:

[(n - r, 0), (n - r + 1, 1), (n - r + 2, 2), ..., (n - 1, r - 1), (n, r)]

在Python中可以轻松生成它(我们省略第一个条目,因为 nC0 = 1):izip(xrange(n - r + 1, n+1), xrange(1, r+1))请注意,这假定 r <= n,您需要检查它们并在必要时交换。此外,为了优化,如果 r < n/2,则设置 r = n - r。

现在,我们只需要使用 reduce 和尾递归应用递归步骤即可。由于 nC0 为 1,所以我们从 1 开始,然后将当前值与列表中的下一个条目相乘如下:

from itertools import izip

reduce(lambda x, y: x * y[0] / y[1], izip(xrange(n - r + 1, n+1), xrange(1, r+1)), 1)

1
对于单个nCr,这种方法更好,但是当你有多个nCr(数量级为N)时,动态规划方法更好,即使它需要长时间的设置,因为它不会在没有必要时溢出到“大数”中。 - JPvdMerwe

20

两个相当简单的建议:

  1. 为了避免溢出,一切都要在对数空间中进行。利用对数函数的性质:log(a * b) = log(a) + log(b), 和 log(a / b) = log(a) - log(b),这使得我们可以轻松处理非常大的阶乘:log(n! / m!) = log(n!) - log(m!)等。

  2. 使用伽玛函数代替阶乘。在 scipy.stats.loggamma中可以找到一个伽玛函数。它是计算对数阶乘比直接求和更高效的方法。loggamma(n) == log(factorial(n - 1)),同样地,gamma(n) == factorial(n - 1)


做事情时使用对数空间是个好建议。不过,“为了精度”你所说的意思我不太确定。使用对数浮点数会导致大数的舍入误差,这不是吗? - Christian Oudard
1
请注意,由于浮点数的精度有限,因此此操作无法得到完全准确的结果。 - starblue
计算小n并不困难。关键是要准确地计算大n,我已经使用了任意精度,因为我正在使用Python长整型。 - Christian Oudard
你如何使用gamma和loggamma函数?两者都不会返回整数,而是返回一个scipy.stats._distn_infrastructure.rv_frozen对象。 - PTTHomps
math.gamma和math.lgamma会产生整数结果。对于scipy.stats函数的作用仍不清楚。 - PTTHomps
显示剩余2条评论

10

在 scipy 中有一个函数可以实现这个功能,但是还没有被提及: scipy.special.comb。根据一些快速的时间测试结果看起来非常高效(comb(100000, 1000, 1) == comb(100000, 99000, 1) 大约需要 0.004 秒)。

[虽然这个特定的问题似乎是关于算法的,但是问题is there a math ncr function in python 被标记为重复...]


10

在Python 3.7之前:

def prod(items, start=1):
    for item in items:
        start *= item
    return start


def perm(n, k):
    if not 0 <= k <= n:
        raise ValueError(
            'Values must be non-negative and n >= k in perm(n, k)')
    else:
        return prod(range(n - k + 1, n + 1))


def comb(n, k):
    if not 0 <= k <= n:
        raise ValueError(
            'Values must be non-negative and n >= k in comb(n, k)')
    else:
        k = k if k < n - k else n - k
        return prod(range(n - k + 1, n + 1)) // math.factorial(k)

Python 3.8+:


有趣的是,一些手动实现的组合函数可能比 math.comb() 更快:

def math_comb(n, k):
    return math.comb(n, k)


def comb_perm(n, k):
    k = k if k < n - k else n - k
    return math.perm(n, k) // math.factorial(k)


def comb(n, k):
    k = k if k < n - k else n - k
    return prod(range(n - k + 1, n + 1)) // math.factorial(k)


def comb_other(n, k):
    k = k if k > n - k else n - k
    return prod(range(n - k + 1, n + 1)) // math.factorial(k)


def comb_reduce(n, k):
    k = k if k < n - k else n - k
    return functools.reduce(
        lambda x, y: x * y[0] // y[1],
        zip(range(n - k + 1, n + 1), range(1, k + 1)),
        1)


def comb_iter(n, k):
    k = k if k < n - k else n - k
    result = 1
    for i in range(1, k + 1):
        result = result * (n - i + 1) // i
    return result


def comb_iterdiv(n, k):
    k = k if k < n - k else n - k
    result = divider = 1
    for i in range(1, k + 1):
        result *= (n - i + 1)
        divider *= i
    return result // divider


def comb_fact(n, k):
    k = k if k < n - k else n - k
    return math.factorial(n) // math.factorial(n - k) // math.factorial(k)

bm

经过基准测试表明,使用math.perm()math.factorial()实现的comb_perm()比大多数情况下使用math.comb()更快,本测试计算固定n=256并逐渐增加k(直到k = n // 2)的计算时间。

请注意,comb_reduce()非常缓慢,本质上与@wich's answer的方法相同,而comb_iter()也相对较慢,本质上与@ZXX's answer的方法相同。

部分分析在此处(由于Python版本的Colab -- 3.7 -- 不支持comb_math()comb_perm(),因此未包含在内)。


我持不同意见。无论在哪个宇宙中,一个乘法循环随着迭代次数的增加是不可能变得更快的 :-) 它的复杂度必须像 O(1/ln(n))这样才能比理论上的量子计算机更快 :-))))))你有检查过结果吗?使用大整数数值库了吗?你的prod(...)会在瞬间穿透MAX_INT => 环绕。可能会产生波动曲线。 (n-k+1)/ k保持int,不会失去任何位。当它到达穿孔点时,它会翻转为浮点数。也许Python的[*]在您穿透MAX_INT后会更快? :-) 因为它停止了? :-) - ZXX
2
@ZXX 或许图表内容不够清晰,抱歉。无论如何,我很确定我从未写过或暗示过你所说的内容。如果你指的是 comb_other() 随着更大的输入变得更快,那是因为 kn - k 被交换以显示昂贵计算发生的位置。你可以很容易地自行检查所有这些函数是否达到相同的数值,远远超过 int64 结果阈值(而且 Python 有内置的大整数支持,我认为可以安全地假设 math.comb() 给出了正确的结果)。 - norok2
Python 3.11 优化了 math.comb,现在可能会更快。 - Kelly Bundy
好主意,我会在有时间的时候添加Py3.11的基准测试。 - norok2
最好还是在基准测试中包含更大的案例。原帖中提到他们有大量数据,并且使用了ncr(100000, 1000)的doctest。而且即使是math.comb(100000, 50000)现在也只需要大约0.2秒。顺便说一下,我不记得math.comb是针对大数优化还是小数优化。无论如何,我怀疑它在大数上表现更好。 - Kelly Bundy

9
如果您不需要纯Python的解决方案,gmpy2 可能会有所帮助(gmpy2.comb非常快)。

1
谢谢你提供的参考,这是一个非常好的实际解决方案。但对于我来说,这更像是一个学习项目,所以我更关心算法而不是实际结果。 - Christian Oudard
3
对于那些在回答发布数年后查看此答案的人,gmpy现在被称为gmpy2。 - Bill Bell

7
更高效的nCr解决方案 - 在空间和精度方面都更好。
中间值 (res) 保证始终为 int 类型,且不大于结果。空间复杂度为 O(1) (没有列表、没有压缩、没有栈),时间复杂度为 O(r) - 恰好需要 r 次乘除运算。
def ncr(n, r):
    r = min(r, n-r)
    if r == 0: return 1
    res = 1
    for k in range(1,r+1):
        res = res*(n-k+1)/k
    return res

6
from scipy import misc
misc.comb(n, k)

应该允许您计算组合


5
如果你正在计算N选K(这就是我认为你使用ncr在做的事情),那么有一种动态规划的解决方案可能会更快。这将避免阶乘,而且如果需要可以保留表格以备后用。
以下是教学链接:

http://www.csc.liv.ac.uk/~ped/teachadmin/algor/dyprog.html

抱歉,我不确定如何更好地解决你的第一个问题。

编辑:这是模拟。有一些非常滑稽的错位错误,因此它肯定需要进行更多的清理。

import sys
n = int(sys.argv[1])+2#100
k = int(sys.argv[2])+1#20
table = [[0]*(n+2)]*(n+2)

for i in range(1,n):
    table[i][i] = 1
for i in range(1,n):
    for j in range(1,n-i):
        x = i+j
        if j == 1: table[x][j] = 1
        else: table[x][j] = table[x-1][j-1] + table[x-1][j]

print table[n][k]

1
据我所见,这个实现似乎是O(n^2),而我提出的尾递归则是O(n)。 - wich
似乎使用了不同的递归定义。这里 n 选 k = n-1 选 k-1 + n-1 选 k,而我使用了 n 选 k = n-1 选 k-1 * n/k。 - wich
确实,情况就是这样。我很快会编辑这篇文章,包括一个Python的快速模拟算法。你的算法明显更快。我会把我的帖子留在这里,以防Gorgapor有一些奇特的机器,需要几个小时才能完成乘法。>.> - agorenst
1
这可能是O(N^2),但它预先计算了所有nCr组合对的组合,因此如果您将经常使用许多不同值的nCr,则这将更快,因为查找是O(1),并且不容易溢出。对于一个值,O(N)算法更好。 - JPvdMerwe

4
如果您的问题不需要知道排列或组合的精确数量,那么您可以使用斯特林公式来近似阶乘。
这将导致类似于以下代码:
import math

def stirling(n):
    # http://en.wikipedia.org/wiki/Stirling%27s_approximation
    return math.sqrt(2*math.pi*n)*(n/math.e)**n

def npr(n,r):
    return (stirling(n)/stirling(n-r) if n>20 else
            math.factorial(n)/math.factorial(n-r))

def ncr(n,r):    
    return (stirling(n)/stirling(r)/stirling(n-r) if n>20 else
            math.factorial(n)/math.factorial(r)/math.factorial(n-r))

print(npr(3,2))
# 6
print(npr(100,20))
# 1.30426670868e+39
print(ncr(3,2))
# 3
print(ncr(100,20))
# 5.38333246453e+20

阶乘的主要问题在于结果的大小,而不是计算时间。此外,这里的结果值比浮点数能够准确表示的值要大得多。 - Christian Oudard

0
您可以输入两个整数,并导入math库来找到它们的阶乘,然后应用nCr公式。
import math
n,r=[int(_)for _ in raw_input().split()]
f=math.factorial
print f(n)/f(r)/f(n-r)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接