大数情况下快速计算组合数n选k模p的方法?

49
我所说的“大n”是指数百万级别。p是质数。
我尝试过http://apps.topcoder.com/wiki/display/tc/SRM+467,但这个函数似乎不正确(我用144选6模5测试它时,它给了我0,而应该给我2)。
我尝试过http://online-judge.uva.es/board/viewtopic.php?f=22&t=42690,但我没有完全理解它。
我还编写了一个记忆化递归函数,使用逻辑(combinations(n-1, k-1, p)%p + combinations(n-1, k, p)%p),但由于n很大,它会导致堆栈溢出问题。
我尝试过Lucas定理,但它似乎要么很慢,要么不准确。
我所尝试做的就是创建一个快速/准确的用于大n的n选k mod p。如果有人能帮我展示一个好的实现,我将非常感激。谢谢。
如请求所示,以下是会因n太大而导致堆栈溢出的记忆化版本:
std::map<std::pair<long long, long long>, long long> memo;

long long combinations(long long n, long long k, long long p){
   if (n  < k) return 0;
   if (0 == n) return 0;
   if (0 == k) return 1;
   if (n == k) return 1;
   if (1 == k) return n;

   map<std::pair<long long, long long>, long long>::iterator it;

   if((it = memo.find(std::make_pair(n, k))) != memo.end()) {
        return it->second;
   }
   else
   {
        long long value = (combinations(n-1, k-1,p)%p + combinations(n-1, k,p)%p)%p;
        memo.insert(std::make_pair(std::make_pair(n, k), value));
        return value;
   }  
}

2
你需要知道确切的余数还是只需知道该数字是否可以被p整除?(n个中选k个模p等于0) - vidstige
组合函数返回什么(为什么需要三个参数)? - Ivaylo Strandjev
1
combinations函数需要三个参数,因为它正在查找(n choose k) mod p。 - John Smith
所以你需要计算组合数(n, k)%p吗? - Ivaylo Strandjev
TopCoder上的解决方案适用于p > n。 - Sameer
显示剩余3条评论
3个回答

61

那么,这就是如何解决您的问题。

当然,您知道公式:

comb(n,k) = n!/(k!*(n-k)!) = (n*(n-1)*...(n-k+1))/k! 

(请见http://zh.wikipedia.org/wiki/二项式系数#计算二项式系数的值)

你知道如何计算分子:

long long res = 1;
for (long long i = n; i > n- k; --i) {
  res = (res * i) % p;
}

现在,由于p是质数,与p互质的每个整数的倒数都是明确定义的,即可以找到a-1。可以使用费马小定理ap-1=1(mod p) => a*ap-2=1(mod p)来完成这个过程,因此a-1=ap-2。现在,您需要做的就是实现快速幂(例如使用二进制方法):

long long degree(long long a, long long k, long long p) {
  long long res = 1;
  long long cur = a;

  while (k) {
    if (k % 2) {
      res = (res * cur) % p;
    }
    k /= 2;
    cur = (cur * cur) % p;
  }
  return res;
}

现在,您可以将分母加到我们的结果中:

long long res = 1;
for (long long i = 1; i <= k; ++i) {
  res = (res * degree(i, p- 2)) % p;
}

请注意我在所有地方都使用long long以避免类型溢出。当然,您不需要执行k次指数运算-您可以计算k!(mod p),然后仅进行一次除法:

long long denom = 1;
for (long long i = 1; i <= k; ++i) {
  denom = (denom * i) % p;
}
res = (res * degree(denom, p- 2)) % p;

编辑:根据@dbaupp的评论,如果k≥p,则k!模p将等于0,并且(k!)^-1将不被定义。为了避免这种情况,请首先计算n * (n-1) ... (n-k + 1)和k!中p的次数并比较它们的值:

int get_degree(long long n, long long p) { // returns the degree with which p is in n!
  int degree_num = 0;
  long long u = p;
  long long temp = n;

  while (u <= temp) {
    degree_num += temp / u;
    u *= p;
  }
  return degree_num;
}

long long combinations(int n, int k, long long p) {
  int num_degree = get_degree(n, p) - get_degree(n - k, p);
  int den_degree = get_degree(k, p);

  if (num_degree > den_degree) {
    return 0;
  }
  long long res = 1;
  for (long long i = n; i > n - k; --i) {
    long long ti = i;
    while(ti % p == 0) {
      ti /= p;
    }
    res = (res * ti) % p;
  }
  for (long long i = 1; i <= k; ++i) {
    long long ti = i;
    while(ti % p == 0) {
      ti /= p;
    }
    res = (res * degree(ti, p-2, p)) % p;
  }
  return res;
}

编辑:上述解决方案还可以添加一项优化——我们可以计算k!(mod p)而不是计算k!中每个倍数的倒数,然后再计算该数字的倒数。因此,我们只需要支付对指数进行对数运算的代价一次即可。当然,我们必须丢弃每个倍数的p约数。我们只需使用以下代码更改最后的循环:

long long denom = 1;
for (long long i = 1; i <= k; ++i) {
  long long ti = i;
  while(ti % p == 0) {
    ti /= p;
  }
  denom = (denom * ti) % p;
}
res = (res * degree(denom, p-2, p)) % p;

你只是计算 n*(n-1)*...*(n-k+1) * (k!)^-1 吗?只有当 k < p 时,这才有定义,否则 k! == 0,不存在逆元。 - huon
我觉得“计算p的度并进行约分”的部分并不是微不足道的。至少,要高效地完成并不容易。 - huon
这似乎与我在第一个链接中展示的实现类似(如何使144选择6 mod 5无效等)。 - John Smith
我已经更新了我的帖子,请再次阅读。对于错误我感到抱歉。 - Ivaylo Strandjev
@IvayloStrandjev 我刚刚用这个解决了一个组合问题。最后的优化实际上很有用,解决了TLE问题。:) 非常感谢! - Kshitij Banerjee
显示剩余4条评论

15
对于大的k,我们可以通过利用两个基本事实来显著减少工作量:
  1. 如果p是一个质数,则n!的质因数分解中p的指数由(n - s_p(n)) / (p-1)给出,其中s_p(n)是在以p为底的表示下n的各位数字之和(所以对于p = 2,它是popcount)。因此,在choose(n,k)的质因数分解中p的指数是(s_p(k) + s_p(n-k) - s_p(n)) / (p-1),特别地,当且仅当在以p为底时加法k + (n-k)没有进位时,它为零(指数是进位的数量)。

  2. 威尔逊定理:如果p是一个质数,则(p-1)! ≡ (-1) (mod p)

n!的质因数分解中p的指数通常通过以下方式计算:

long long factorial_exponent(long long n, long long p)
{
    long long ex = 0;
    do
    {
        n /= p;
        ex += n;
    }while(n > 0);
    return ex;
}

检查choose(n,k)是否可以被p整除并不是必须的,但最好先这样做,因为这通常是情况,这样做会更省事:

long long choose_mod(long long n, long long k, long long p)
{
    // We deal with the trivial cases first
    if (k < 0 || n < k) return 0;
    if (k == 0 || k == n) return 1;
    // Now check whether choose(n,k) is divisible by p
    if (factorial_exponent(n) > factorial_exponent(k) + factorial_exponent(n-k)) return 0;
    // If it's not divisible, do the generic work
    return choose_mod_one(n,k,p);
}

现在让我们更仔细地看一下n!。我们将≤ n的数字分为p的倍数和与p互质的数字。有:
n = q*p + r, 0 ≤ r < p

p的倍数贡献p^q * q!。与p互质的数字为(j*p + k), 1 ≤ k < p,其中0 ≤ j < q,以及(q*p + k), 1 ≤ k ≤ r的乘积。

对于与p互质的数字,我们只关心它们在模p下的贡献。每个完整的序列j*p + k, 1 ≤ k < p在模p下都等于(p-1)!,因此它们总共产生了一个贡献(-1)^q在模p下。最后(可能)不完整的序列在模p下等于r!

因此,如果我们写成:

n   = a*p + A
k   = b*p + B
n-k = c*p + C

我们获得了。
choose(n,k) = p^a * a!/ (p^b * b! * p^c * c!) * cop(a,A) / (cop(b,B) * cop(c,C))

其中cop(m,r)是所有与p互质且≤ m*p + r的数字的乘积。

有两种可能性,a = b + cA = B + C,或者a = b + c + 1A = B + C - p

在我们的计算中,我们事先消除了第二种可能性,但这并非必需。

在第一种情况下,p的显式幂会被消除,我们剩下的是

choose(n,k) = a! / (b! * c!) * cop(a,A) / (cop(b,B) * cop(c,C))
            = choose(a,b) * cop(a,A) / (cop(b,B) * cop(c,C))

任何能够整除 choose(n,k)p 的幂都来自于 choose(a,b),在我们的情况下,由于我们已经在之前排除了这些情况,所以不会有这种情况出现。虽然 cop(a,A) / (cop(b,B) * cop(c,C)) 不一定是整数(例如考虑 choose(19,9) (mod 5)),但是当考虑模 p 的表达式时,cop(m,r) 可以简化为 (-1)^m * r!,因此,由于 a = b + c(-1) 相互抵消,我们只剩下
choose(n,k) ≡ choose(a,b) * choose(A,B) (mod p)

在第二种情况中,我们发现。
choose(n,k) = choose(a,b) * p * cop(a,A)/ (cop(b,B) * cop(c,C))

因为 a = b + c + 1。最后一位产生的进位表示 A < B,所以对于模数 p

p * cop(a,A) / (cop(b,B) * cop(c,C)) ≡ 0 = choose(A,B)

(我们可以用模反元素替换除法,或者将其视为有理数的同余关系,这意味着分子可以被p整除)。无论如何,我们再次发现

choose(n,k) ≡ choose(a,b) * choose(A,B) (mod p)

现在我们可以对choose(a,b)部分进行递归。 示例:
choose(144,6) (mod 5)
144 = 28 * 5 + 4
  6 =  1 * 5 + 1
choose(144,6) ≡ choose(28,1) * choose(4,1) (mod 5)
              ≡ choose(3,1) * choose(4,1) (mod 5)
              ≡ 3 * 4 = 12 ≡ 2 (mod 5)

choose(12349,789) ≡ choose(2469,157) * choose(4,4)
                  ≡ choose(493,31) * choose(4,2) * choose(4,4
                  ≡ choose(98,6) * choose(3,1) * choose(4,2) * choose(4,4)
                  ≡ choose(19,1) * choose(3,1) * choose(3,1) * choose(4,2) * choose(4,4)
                  ≡ 4 * 3 * 3 * 1 * 1 = 36 ≡ 1 (mod 5)

现在是实施阶段:
// Preconditions: 0 <= k <= n; p > 1 prime
long long choose_mod_one(long long n, long long k, long long p)
{
    // For small k, no recursion is necessary
    if (k < p) return choose_mod_two(n,k,p);
    long long q_n, r_n, q_k, r_k, choose;
    q_n = n / p;
    r_n = n % p;
    q_k = k / p;
    r_k = k % p;
    choose = choose_mod_two(r_n, r_k, p);
    // If the exponent of p in choose(n,k) isn't determined to be 0
    // before the calculation gets serious, short-cut here:
    /* if (choose == 0) return 0; */
    choose *= choose_mod_one(q_n, q_k, p);
    return choose % p;
}

// Preconditions: 0 <= k <= min(n,p-1); p > 1 prime
long long choose_mod_two(long long n, long long k, long long p)
{
    // reduce n modulo p
    n %= p;
    // Trivial checks
    if (n < k) return 0;
    if (k == 0 || k == n) return 1;
    // Now 0 < k < n, save a bit of work if k > n/2
    if (k > n/2) k = n-k;
    // calculate numerator and denominator modulo p
    long long num = n, den = 1;
    for(n = n-1; k > 1; --n, --k)
    {
        num = (num * n) % p;
        den = (den * k) % p;
    }
    // Invert denominator modulo p
    den = invert_mod(den,p);
    return (num * den) % p;
}

要计算模反元素,你可以使用费马(所谓的小)定理:
如果 p 是质数且 a 不被 p 整除,则 a^(p-1) ≡ 1 (mod p)。
然后计算 a^(p-2) (mod p)即可得到其逆元,或者使用适用于更广范围参数的方法,扩展欧几里得算法或连分数拓展,这些方法可以为任意互质(正)整数对提供模反元素。
long long invert_mod(long long k, long long m)
{
    if (m == 0) return (k == 1 || k == -1) ? k : 0;
    if (m < 0) m = -m;
    k %= m;
    if (k < 0) k += m;
    int neg = 1;
    long long p1 = 1, p2 = 0, k1 = k, m1 = m, q, r, temp;
    while(k1 > 0) {
        q = m1 / k1;
        r = m1 % k1;
        temp = q*p1 + p2;
        p2 = p1;
        p1 = temp;
        m1 = k1;
        k1 = r;
        neg = !neg;
    }
    return neg ? m - p2 : p2;
}

像计算 a^(p-2) (mod p) 一样,这是一个 O(log p) 算法,对于某些输入来说它会快得多(实际上是 O(min(log k, log p)),所以对于小的 k 和大的 p,速度会更快),但对于其他情况则会慢一些。

总体而言,我们需要计算最多 O(log_p k) 个二项式系数模 p,其中每个二项式系数最多需要 O(p) 次操作,从而得到总复杂度为 O(p*log_p k) 的操作次数。 当 k 明显大于 p 时,这比 O(k) 的解决方案要好得多。对于 k <= p,它会带来一些额外开销,但仍然可以归结为 O(k) 的解决方案。


你能否发布一份算法摘要?对我来说,跟随这些步骤有点困难。 - nhahtdh
你能给我一点提示,你在哪里遇到了困难吗?如果我不必完全猜测可能对于那些无法读懂我的思维的人来说哪些部分可能有问题,那么这将更容易做到。 - Daniel Fischer
看起来你正在第一部分通过Lucas定理的结果运行一个循环(伪装成递归函数),并在第二部分使用乘法逆元计算nCk mod p?(这正是我正在寻找的)。当p较小时,Lucas定理将处理好这种情况。 - nhahtdh
是的,就是这样(当我写下这个关系时,我不知道有人费心制作了一个定理,因此没有提到Lucas大师;现在我知道了,我应该添加一个参考文献)。 - Daniel Fischer

0
如果你需要计算多次,有一种更快的方法。我将发布Python代码,因为它可能是最容易转换成另一种语言的,尽管我会在最后放置C++代码。

只需计算一次

暴力破解:

def choose(n, k, m):
    ans = 1
    for i in range(k): ans *= (n-i)
    for i in range(k): ans //= i
    return ans % m

但是计算可能会涉及到非常大的数字,因此我们可以使用模算术技巧:

(a * b) mod m = (a mod m) * (b mod m) mod m

(a / (b*c)) mod m = (a mod m) / ((b mod m) * (c mod m) mod m)

(a / b) mod m = (a mod m) * (b mod m)^-1

请注意最后一个方程式末尾的^-1。这是bm的乘法逆元。它基本上意味着((b mod m) * (b mod m)^-1) mod m = 1,就像(非零)整数的a * a^-1 = a * 1/a = 1一样。

这可以通过几种方式计算,其中之一是扩展欧几里得算法:

def multinv(n, m):
    ''' Multiplicative inverse of n mod m '''
    if m == 1: return 0
    m0, y, x = m, 0, 1

    while n > 1:
        y, x = x - n//m*y, y
        m, n = n%m, m
    
    return x+m0 if x < 0 else x

请注意,另一种方法——指数运算——仅在m为质数时有效。如果是质数,您可以这样做:
def powmod(b, e, m):
    ''' b^e mod m '''
    # Note: If you use python, there's a built-in pow(b, e, m) that's probably faster
    # But that's not in C++, so you can convert this instead:
    P = 1
    while e:
        if  e&1: P = P * b % m
        e >>= 1; b = b * b % m
    return P

def multinv(n, m):
    ''' Multiplicative inverse of n mod m, only if m is prime '''
    return powmod(n, m-2, m)
    

但请注意,扩展欧几里得算法往往仍然运行更快,即使它们在技术上具有相同的时间复杂度O(log m),因为它具有较低的常数因子。

现在是完整代码:

def multinv(n, m):
    ''' Multiplicative inverse of n mod m in log(m) '''
    if m == 1: return 0
    m0, y, x = m, 0, 1

    while n > 1:
        y, x = x - n//m*y, y
        m, n = n%m, m
    
    return x+m0 if x < 0 else x


def choose(n, k, m):
    num = den = 1
    for i in range(k): num = num * (n-i) % m
    for i in range(k): den = den * i % m
    return num * multinv(den, m)

多次查询

我们可以分别计算分子和分母,然后再将它们组合起来。但是请注意,我们为分子计算的乘积是 n * (n-1) * (n-2) * (n-3) ... * (n-k+1)。如果您曾经学过一些叫做“前缀和”的东西,这太相似了。那么让我们应用它。

预先计算 fact[i] = i! mod m,其中 i 的最大值为 n,可能为 1e7(一千万)。然后,分子为 (fact[n] * fact[n-k]^-1) mod m,分母为 fact[k]。因此,我们可以计算出 choose(n, k, m) = fact[n] * multinv(fact[n-k], m) % m * multinv(fact[k], m) % m

Python 代码:

MAXN = 1000 # Increase if necessary
MOD = 10**9+7 # A common mod that's used, change if necessary

fact = [1]
for i in range(1, MAXN+1):
    fact.append(fact[-1] * i % MOD)

def multinv(n, m):
    ''' Multiplicative inverse of n mod m in log(m) '''
    if m == 1: return 0
    m0, y, x = m, 0, 1

    while n > 1:
        y, x = x - n//m*y, y
        m, n = n%m, m
    
    return x+m0 if x < 0 else x


def choose(n, k, m):
    return fact[n] * multinv(fact[n-k] * fact[k] % m, m) % m

C++ 代码:

#include <iostream>
using namespace std;

const int MAXN = 1000; // Increase if necessary
const int MOD = 1e9+7; // A common mod that's used, change if necessary

int fact[MAXN+1];

int multinv(int n, int m) {
    /* Multiplicative inverse of n mod m in log(m) */
    if (m == 1) return 0;
    int m0 = m, y = 0, x = 1, t;

    while (n > 1) {
        t = y;
        y = x - n/m*y;
        x = t;
        
        t = m;
        m = n%m;
        n = t;
    }
    
    return x<0 ? x+m0 : x;
}

int choose(int n, int k, int m) {
    return (long long) fact[n]
         * multinv((long long) fact[n-k] * fact[k] % m, m) % m;
}

int main() {
    fact[0] = 1;
    for (int i = 1; i <= MAXN; i++) {
        fact[i] = (long long) fact[i-1] * i % MOD;
    }

    cout << choose(4, 2, MOD) << '\n';
    cout << choose(1e6, 1e3, MOD) << '\n';
}

请注意,我将转换为 long long 以避免溢出。

1
谢谢!我觉得这很有用。但是在最新的Python版本中,调用multinv()函数时缺少了最后一个“m”参数。 - Christofer Ohlsson
添加C++代码对于不懂Python的人来说是很好的。 - murage kibicho
被踩了。大部分 Python 代码根本不起作用。例如,for i in range(k): ans //= i 立即产生 ZeroDivisionError。我不确定 C++ 代码是否更好。 - jcsahnwaldt Reinstate Monica

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接