模幂运算函数中的整数溢出问题

3

我正在编写一个 powMod 函数,这个函数我需要频繁使用。起点是一个自定义的 pow 函数:

// Compute power using multiplication and square.
// pow (*) (^2) 1 x n = x^n
let pow mul sq one x n =   
    let rec loop x' n' acc =
       match n' with
       | 0 -> acc
       | _ -> let q = n'/2
              let r = n'%2
              let x2 = sq x'
              if r = 0 then
                 loop x2 q acc
              else
                 loop x2 q (mul x' acc)
    loop x n one

检查了我的输入范围后,我选择了int64,因为它足够大以表示输出,并且我可以避免使用bigint进行昂贵的计算:

let mulMod m a b = (a*b)%m
let squareMod m a = mulMod m a a
let powMod m = pow (mulMod m) (squareMod m) 1L

我假设模数(m)大于乘数(ab),并且函数仅适用于非负数。 对于大多数情况,powMod函数是正确的;但是,mulMod函数存在问题,其中a*b可能超出int64范围,但(a*b)%m不会。 下面的示例演示了溢出问题:
let a = (pown 2L 40) - 1L
let b = (pown 2L 32) - 1L
let p = powMod a b 2 // p = -8589934591L -- wrong

有没有办法避免 int64 溢出而不使用 bigint 类型?


尝试在进行乘法运算之前找到一种对每个数进行取模的方法?或者,这可能行不通,但如果你能够通过一个公共除数进行除法运算,然后再进行乘法运算,那么你可能可以做到。不幸的是,无论你做什么都会影响性能。 - James Black
这并没有帮助。我已经假设乘数比模数小了。 - pad
你可以使用具有相当大精度的十进制(decimal)类型,但它只比bigint略快 :| - David Grenier
5个回答

2
根据维基百科,以下公式是等价的。您的代码使用的是第一个,改为使用第二个应该可以解决溢出问题。
c = (a x b) mod(m)
c = (a x (b mod(m))) mod(m) 

希望这能有所帮助。
根据您下方的评论 - 如果a <= m和b <= m,并且m > sqrt(maxint64),那么我不确定在不使用更大的存储空间的情况下是否可能有解决方案。对于大值的m,b mod m将返回b,因此使用上述等效性公式没有任何好处。
好消息是,您应该能够将更改限制在单行代码上,并将值重新输入到64位之内[因为我们知道(a*b)%c不应溢出],然后继续进行计算。这将把执行性能的开销限制在尽可能小的代码部分内。

如果 b >= m,这可能有助于解决问题。但它肯定不能在一般情况下解决问题。 - svick
我假设用户可以添加条件来查找特殊情况 - 如果a> b,则交换a和b,以使b始终> a等。 - Gerald P. Wright
不是的。我假设 a<=mb<=m(请参见我的问题),但仍会发生溢出。在我的情况下,m相当大,我不能假设 m < sqrt(maxint64) - pad

2
我了解f#的知识非常有限,但我认为您可以应用以下事实:
如果b是奇数且存在n满足 = 2n + 1
a * b mod(m) = 2 * a * n + a mod(m)
             = 2 * (a*n mod(m)) + a mod(m)

如果b是偶数,同样地,您可以在an上重复此操作多次,直到得到适合于int64的乘积。如果m > maxint64/2,仍可能发生溢出。


+1 对于好的建议。至少它减少了溢出的可能性。 - pad

2
您遇到的问题是所有中间计算都隐含mod 264,而通常事实并非如此。您正在计算的是:

a·b mod m = (a·b mod 264) mod m

我无法想到一种简单的方法来使用仅64位数字进行正确的计算,但您不必一直升级到大整数; 如果ab最多有64位,则它们的完整乘积最多有128位,因此可以使用两个64位整数(这里作为自定义结构捆绑)跟踪乘积:
// bit width of a uint64, needed for mod calculation
let width = 
    let rec loop w = function
    | 0uL -> w
    | n -> loop (w+1) (n >>> 1)
    loop 0

[<Struct; CustomComparison; CustomEquality>]
type UInt128 =
    val hi : uint64
    val lo : uint64
    new (hi,lo) = { lo = lo; hi = hi }
    new (lo) = { lo = lo; hi = 0uL }
    static member (+)(x:UInt128, y:UInt128) =
        if x.lo > 0xffffffffuL - y.lo then
            UInt128(x.hi + y.hi + 1uL, x.lo + y.lo)
        else
            UInt128(x.hi + y.hi, x.lo + y.lo)
    static member (-)(x:UInt128, y:UInt128) =
        if y.lo > x.lo then
            UInt128(x.hi - y.hi - 1uL, x.lo - y.lo)
        else
            UInt128(x.hi - y.hi, x.lo - y.lo)

    static member ( * )(x:UInt128, y:UInt128) =
        let a1 = ((x.lo &&& 0xffffffffuL) * (y.lo &&& 0xffffffffuL)) >>> 32
        let a2 =  (x.lo &&& 0xffffffffuL) * (y.lo >>> 32)
        let a3 =  (x.lo >>> 32) * (y.lo &&& 0xffffffffuL)
        let sum = ((a1 + a2 + a3) >>> 32) + (x.lo >>> 32) * (y.lo >>> 32)
        let sum =
            if a2 > 0xffffffffffffffffuL - a1 || a1 + a2 > 0xffffffffffffffffuL - a3 then
                0x100000000uL + sum
            else
                sum
        UInt128(x.hi * y.lo + x.lo * y.hi + sum, x.lo * y.lo)

    static member (>>>)(x:UInt128, n) =
        UInt128(x.hi >>> n, x.lo >>> n)

    static member (<<<)(x:UInt128, n) =
        UInt128((x.hi <<< n) + (x.lo >>> (64 - n)), x.lo <<< n)

    interface System.IComparable with
        member x.CompareTo(y) =
            match y with
            | :? UInt128 as y ->
                match x.hi.CompareTo(y.hi) with
                | 0 -> x.lo.CompareTo(y.lo)
                | n -> n

    override x.Equals(y) = 
        match y with
        | :? UInt128 as y -> x.hi = y.hi && x.lo = y.lo
        | _ -> false

    override x.GetHashCode() = x.hi.GetHashCode() + x.lo.GetHashCode() * 7

    (* calculate mod via long-division *)
    static member (%)(x:UInt128, d) =
        let rec reduce (r:UInt128) d' =
            if r.hi = 0uL then r.lo % d
            else
                let r' = if r < d' then r else r - d'
                reduce r' (d' >>> 1)
        let shift = width x.hi + (64 - width d)
        reduce x (UInt128(0uL,d) <<< shift)

let mulMod m a b =
    UInt128(a) * UInt128(b) % m

(* squareMod, powMod basically as before: *)
let squareMod m a = mulMod m a a  
let powMod m = pow (mulMod m) (squareMod m) 1uL  

let a = (pown 2uL 40) - 1uL  
let b = (pown 2uL 32) - 1uL  
let p = powMod a b 2

话虽如此,既然bigint可以给你正确的答案,为什么不直接使用bigint进行中间计算,最后再转换为long(考虑到m的范围,这是一种无损转换)?我猜测相对于维护自己的数学例程而言,使用bigint的性能惩罚应该是可以接受的。


+1 你说得对。在mulMod中检测溢出情况并使用bigint处理这些情况是正确的方法。顺便说一下,你自定义的UInt128很不错。 - pad

0

我对 F# 并不是很熟悉,但类似于以下的伪代码:

pmod a b n // a^b % n
pmod a 0 n = 1
pmod a 1 n = a%n
pmod a b n = match b%2
   | 0 -> ((pmod (a) (b/2) n) ^ 2) % n
   | 1 -> ((pmod (a) (b-1) n) * a ) % n

pmod 还不行,但应该作为辅助函数使用。

PowMod a b n = pmod (a%n) b n

你可以看到,只要结果的平方大于uint64,那么这个结果就是错误的,因此n必须适合于uint32。


这是另一种编写powMod的方式,但问题仍然存在。 - pad

0

这不是对你的问题的回答,但是以这种方式编写函数将使其更通用和方便使用,而且似乎更有效率:

let inline pow x n =
    let zero = LanguagePrimitives.GenericZero
    let rec loop x acc = function
        | n when n = zero -> acc
        | n ->
            let q = n >>> 1
            let acc = if n = (q <<< 1) then acc else x * acc
            loop (x * x) acc q
    loop x LanguagePrimitives.GenericOne n;;

for x = 0 to 1000000 do
    pow 3UL 31UL |> ignore

另外,我想无符号长整型可能不够用?

编辑:对于大整数,在执行更少的乘法运算时,以下算法比上面的算法快3倍-这可能有助于您选择支持大整数的方案:

let inline pow2 x n =
    let zero = LanguagePrimitives.GenericZero
    let one = LanguagePrimitives.GenericOne
    let rec loop x data = function
        | c when c <<< 1 <= n ->
            let c = c <<< 1
            let x = x * x
            loop x (Map.add -c x data) c
        | c -> reduce x data (n - c)
    and reduce acc data = function
        | c when c = zero -> acc
        | c ->
            let next, value = data |> Seq.pick (fun (KeyValue (n, v)) -> if -n <= c then Some (-n, v) else None)
            reduce (acc * value) data (c - next)
    loop x (Map [-1, x]) one;;

for x = 10000 downto 9000 do
    pow2 7I x |> ignore

在模式匹配中要小心,我所做的第一件事是从你的代码中移除了 ',结果出现了一个错误,因为最后一个情况是 _ 而不是 n,并且没有隐藏顶层定义。 - David Grenier

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接