模数乘法函数:在模数下计算两个整数相乘

3

我在一个米勒-拉宾素性测试的代码中发现了这个模乘函数。 这个函数用于消除计算 ( a * b ) % m 时发生的整数溢出。

我需要一些帮助来理解这里发生了什么。为什么这样做有效?那个数字字面量 0x8000000000000000ULL 的意义是什么?

unsigned long long mul_mod(unsigned long long a, unsigned long long b, unsigned long long m) {
    unsigned long long d = 0, mp2 = m >> 1;
    if (a >= m) a %= m;
    if (b >= m) b %= m;
    for (int i = 0; i < 64; i++)
    {
        d = (d > mp2) ? (d << 1) - m : d << 1;
        if (a & 0x8000000000000000ULL)
            d += b;
        if (d >= m) d -= m;
        a <<= 1;
    }
    return d;
}

1
小注:在0x8000000000000000ULL中的LL是不必要的。 - chux - Reinstate Monica
1
我们将位向左移动,直到高位变为1-然后我们知道不能再移动了,否则我们会移走我们需要的位。 - Jerry Jeremiah
2
作者加上一些注释就会很有帮助的经典例子。 - chux - Reinstate Monica
虽然这很有趣且可移植,但如果您的工具链提供了更大的整数类型,您可能会选择使用它。 - Bob__
1个回答

4

这段代码目前出现在模运算维基百科页面上,仅适用于最多63位的参数--请参见底部。

概述

计算一个普通乘法a * b的一种方法是添加左移的b的副本--每个1位对应一个。这类似于我们在学校做长乘法的方式,但更简化:由于我们只需要将每个b的副本乘以1或0,所以我们只需要执行两个操作:添加左移的b的副本(当相应的a的位为1时),或者不执行任何操作(当它为0时)。

这段代码实现了类似的功能。然而,为了避免溢出(大多数情况下;见下文),它不是将每个 b 的副本进行移位并加入总和,而是将一个未移位的 b 加入总和,并依赖稍后在总和上执行的左移来将其移位到正确的位置。可以将这些移位视为作用于迄今为止添加到总和的所有求和项上。例如,第一次循环迭代检查 a 的最高位,即第 63 位,是否为 1(这就是 a & 0x8000000000000000ULL 的作用),如果是,则向总和添加一个未移位的 b;到循环完成时,前一行代码将总和 d 左移了 63 次。
这种方法的主要优点在于我们总是在两个已知小于m的数(即b和d)之间相加,因此处理模余环绕非常便宜: 我们知道 b + d < 2 * m,所以为了确保我们到目前为止的总和仍然小于m,只需检查是否 b + d < m,如果不是,则减去m。如果我们改用移位后相加的方法,那么每一位都需要进行 % 取模运算,这与除法一样昂贵 - 通常比减法昂贵得多。
模算术的一个特性是,每当我们想对某个数m执行一系列算术运算时,将它们全部使用普通算术操作后,最后取模m的余数,总是产生与每个中间结果取模m的余数相同的结果(前提是没有溢出发生)。

代码

在循环体的第一行之前,我们有不变式 d < m 和 b < m.
该行
d = (d > mp2) ? (d << 1) - m : d << 1;

这是一种小心翼翼的方法,将总数d左移1位,同时保持其在范围0 .. m内并避免溢出。我们不是先进行移位,然后测试结果是否大于或等于m,而是测试它当前是否严格高于RoundDown(m/2) -- 因为如果是这样,在加倍后,它肯定会严格高于2 * RoundDown(m/2) >= m - 1,因此需要减去m才能回到范围内。请注意,即使(d << 1) - m中的(d << 1)可能会溢出并且失去d的最高位,但这并不会对减法结果的最低64位产生影响,而这些是我们感兴趣的唯一位。 (还要注意,如果d == m/2恰好,那么之后我们得到d == m,它略微超出范围--但将测试从d > mp2更改为d >= mp2以修复此问题会破坏m为奇数且d == RoundDown(m/2)的情况,因此我们必须接受这种情况。这无关紧要,因为它将在下面得到修复。)

为什么不直接写d <<= 1; if (d >= m) d -= m;呢?假设在无限精度算术中,d << 1 >= m,因此我们应该执行减法,但是d的高位为1,d << 1的其余部分小于m:在这种情况下,初始移位将丢失高位,if将无法执行。

限制输入为63位或更少

上述边界情况仅在 d 的最高位为 1 时发生,而这只会在 m 的最高位也为 1 时发生(因为我们保持不变式 d < m)。因此,代码似乎很小心地确保即使在非常大的 m 值下也能正常工作。不幸的是,它还是可能在其他地方溢出,导致某些设置了最高位的输入产生错误答案。例如,当 a = 3b = 0x7FFFFFFFFFFFFFFFULLm = 0xFFFFFFFFFFFFFFFFULL 时,正确答案应该是 0x7FFFFFFFFFFFFFFEULL,但代码将返回 0x7FFFFFFFFFFFFFFDULL(查看正确答案的简单方法是使用交换后的 ab 值重新运行)。具体来说,每当行 d += b 溢出并留下被截断的 d 小于 m 时,就会发生这种行为,导致错误地跳过减法运算。

如果这种行为已经有文档记录(就像维基百科页面上一样),那么这只是一种限制,而不是错误。

移除限制

如果我们替换以下这些代码行

    if (a & 0x8000000000000000ULL)
        d += b;
    if (d >= m) d -= m;

使用

    unsigned long long x = -(a >> 63) & b;
    if (d >= m - x) d -= m;
    d += x;

这段代码将适用于所有输入, 包括那些具有设置顶部位的输入。第一行晦涩的代码只是一种无条件的(因此通常更快)编写方式。

    unsigned long long x = (a & 0x8000000000000000ULL) ? b : 0;

测试 d >= m - x 在修改 d 之前执行 -- 它就像旧的 d >= m 测试,但是从两边减去了 b (当 a 的最高位为1时) 或者 0 (否则)。这个测试检查了一下在添加 x 后,d 是否会变成 m 或更大。我们知道右侧的 m - x 永远不会下溢,因为 x 最大只能是 b,而我们已经在函数顶部确定了 b < m


维基百科指出,该算法适用于“不大于63位的无符号整数”,在这种情况下,高位不会被设置,因此错误不会发生。 - interjay

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接