在C++中实现大数模乘

8
我想计算三个整数A,B(小于10^12)和C(小于10^15)的(A * B) % C。 我知道:
(A * B) % C = ((A % C) * (B % C)) % C

但是如果A = B = 10^11,那么上面的表达式将导致整数溢出。对于上述情况是否有简单的解决方案,还是我必须使用快速乘法算法。

如果我必须使用快速乘法算法,那么我应该使用哪个算法。

编辑:我在C ++中尝试了以上问题(没有引起溢出,不确定为什么),但答案不应该是吗?

先行致谢。


只有当C足够大时,RHS才会溢出(这就是余数的奇妙之处)。 - user529758
C++中的算术溢出通常是默默发生的 - 没有错误提示,它们只是发生了。当您期望输出为0时,实际上却看到了712049423024128,此时您才会意识到该问题。 - Bernhard Barker
如果你想要快速的东西,恐怕它必须是特定于平台的。你对哪些平台感兴趣? - Marc Glisse
5个回答

17
你可以使用Schrage's 方法来解决此问题。这使您能够在不生成大于某个中间数的情况下,将两个有符号数字az与特定模数m相乘。
它基于模数m的近似因式分解。
m = aq + r 
即。
q = [m / a]

r = m mod a

其中[]表示整数部分。如果r < q0 < z < m − 1,则a(z mod q)r[z / q]都处于0,...,m − 1的范围内。

az mod m = a(z mod q) − r[z / q]

如果这是负数,则加上m

[这种技术经常在线性同余随机数生成器中使用]。


4
此外,您可以将此作为递归算法使用,以包括所有 r >= q。针对乘积 r*[z/q] 重复上述算法,因此新值变为: a2 = r z2 = [r/q] 这保证了 a 的值会减小: a > r = a2 -> a2 < a 最终,当 a <= sqrt(m) 时,则有: a*a <= m -> q >= a > r -> r<q 然后算法停止。 - CrepeGoat

6

考虑到您的公式以及以下变化:

(A + B) mod C = ((A mod C) + (B mod C)) mod C 

您可以使用分治法来开发一种既简单又快速的算法:
#include <iostream>

long bigMod(long  a, long  b, long c) {
    if (a == 0 || b == 0) {
        return 0;
    }
    if (a == 1) {
        return b;
    }
    if (b == 1) {
        return a;
    } 

    // Returns: (a * b/2) mod c
    long a2 = bigMod(a, b / 2, c);

    // Even factor
    if ((b & 1) == 0) {
        // [((a * b/2) mod c) + ((a * b/2) mod c)] mod c
        return (a2 + a2) % c;
    } else {
        // Odd exponent
        // [(a mod c) + ((a * b/2) mod c) + ((a * b/2) mod c)] mod c
        return ((a % c) + (a2 + a2)) % c;
    }
}

int main() { 
    // Use the min(a, b) as the second parameter
    // This prints: 27
    std::cout << bigMod(64545, 58971, 144) << std::endl;
    return 0;
}

这是一个 O(log N) 的算法。


这是进行指数运算,但问题是要进行乘法。你可以在代码中将乘法更改为加法,然后它应该能够正常工作。 - cyon
你也需要使用if (b==0) return 0; - cyon
是的!你说得对,感谢你注意到了(并且没有将答案投票降低尽管它应该这样做)。我已经适当地更新了。 - higuaro
+1 我认为这个算法版本比被接受的答案更易读(虽然更长),就我所知,它们做的是同样的事情。 - cyon
它的功能是正确的,但算法非常慢 :( - user1209304
"a2 + a2" 可能会绝对溢出,而 "(a % c) + (a2 + a2)" 更容易发生溢出。 - BrainStone

4

更新:修正了a % c的高位被设置时出现的错误。(致谢:Kevin Hopps)

如果你注重简单而非快速,那么可以使用以下方法:

typedef unsigned long long u64;

u64 multiplyModulo(u64 a, u64 b, u64 c)
{
    u64 result = 0;
    a %= c;
    b %= c;
    while(b) {
        if(b & 0x1) {
            result += a;
            result %= c;
        }
        b >>= 1;
        if(a < c - a) {
            a <<= 1;
        } else {
            a -= (c - a);
        }
    }
    return result;
}

当“a”设置了高位时,这将产生错误的结果。请参见我下面的帖子。 - Kevin Hopps
第一个加法也可能会溢出,但代码的其余部分是正确的(下面关于左移的评论者在当前版本的代码中是不正确的):如果(result < c - a) result = result + a; 否则 result -= (c - a); - Gregory Morse

1
抱歉,当变量"a"持有一个高位设置的值时,godel9算法将产生一个不正确的结果。这是因为"a <<= 1"会丢失信息。以下是一个已纠正的算法,适用于任何整数类型,有符号或无符号。

template <typename IntType>
IntType add(IntType a, IntType b, IntType c)
    {
    assert(c > 0 && 0 <= a && a < c && 0 <= b && b < c);
    IntType room = (c - 1) - a;
    if (b <= room)
        a += b;
    else
        a = b - room - 1;
    return a;
    }

template <typename IntType>
IntType mod(IntType a, IntType c)
    {
    assert(c > 0);
    IntType q = a / c; // q may be negative
    a -= q * c; // now -c < a && a < c
    if (a < 0)
        a += c;
    return a;
    }

template <typename IntType>
IntType multiplyModulo(IntType a, IntType b, IntType c)
    {
    IntType result = 0;
    a = mod(a, c);
    b = mod(b, c);
    if (b > a)
        std::swap(a, b);
    while (b)
        {
        if (b & 0x1)
            result = add(result, a, c);
        a = add(a, a, c);
        b >>= 1;
        }
    return result;
    }


有什么理由要有加法函数“mod”?为什么不使用运算符“%”? - Hsilgos
@Hsilgos,使用mod而不是operator%的原因是确保结果为非负数。 - Kevin Hopps

0
在这种情况下,A和B是40位数字,C是50位数字,在64位模式下使用64位乘法产生128位结果(实际上是80位),之后将128位被除数除以50位除数以产生50位余数(模)是没有问题的,如果您有内置函数或可以编写汇编代码来使用64位乘法产生128位结果。
根据处理器的不同,通过乘以81位(或更少)常数来实现除以50位常数可能会更快。再次假设64位处理器,需要进行4次乘法和一些加法,然后移位4个乘积的上位比特以产生商。然后使用商乘以50位模数并从80位乘积中减去(得到50位余数)。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接