一种高效的算法来计算任意大整数的整数平方根(isqrt)

6

注意

如果需要 Erlang 或者 C / C++ 的解决方案,请参考下面的 Trial 4


维基百科文章

整数平方根

  • 这里可以找到“整数平方根”的定义。

计算平方根的方法

  • 这里有一种使用“位运算”的算法。

[试用1:使用库函数]

代码

isqrt(N) when erlang:is_integer(N), N >= 0 ->
    erlang:trunc(math:sqrt(N)).

问题

这个实现使用了C语言库的sqrt()函数,因此不能处理任意大的整数(注意返回结果与输入不匹配。正确答案应该是12345678901234567890):

Erlang R16B03 (erts-5.10.4) [source] [64-bit] [smp:8:8] [async-threads:10] [hipe] [kernel-poll:false]

Eshell V5.10.4  (abort with ^G)
1> erlang:trunc(math:sqrt(12345678901234567890 * 12345678901234567890)).
12345678901234567168
2> 

[试验2:只使用Bigint +]

代码

isqrt2(N) when erlang:is_integer(N), N >= 0 ->
    isqrt2(N, 0, 3, 0).

isqrt2(N, I, _, Result) when I >= N ->
    Result;

isqrt2(N, I, Times, Result) ->
    isqrt2(N, I + Times, Times + 2, Result + 1).

描述

该实现基于以下观察:

isqrt(0) = 0   # <--- One 0
isqrt(1) = 1   # <-+
isqrt(2) = 1   #   |- Three 1's
isqrt(3) = 1   # <-+
isqrt(4) = 2   # <-+
isqrt(5) = 2   #   |
isqrt(6) = 2   #   |- Five 2's
isqrt(7) = 2   #   |
isqrt(8) = 2   # <-+
isqrt(9) = 3   # <-+
isqrt(10) = 3  #   |
isqrt(11) = 3  #   |
isqrt(12) = 3  #   |- Seven 3's
isqrt(13) = 3  #   |
isqrt(14) = 3  #   |
isqrt(15) = 3  # <-+
isqrt(16) = 4  # <--- Nine 4's
...

问题

这个实现仅涉及大整数加法,因此我希望它能运行得很快。然而,当我输入 1111111111111111111111111111111111111111 * 1111111111111111111111111111111111111111 时,在我的(非常快的)机器上似乎要运行很久。


[ 第三次尝试:仅使用大整数的二分查找 +1-1div 2 ]

代码

变体 1(我的原始实现)

isqrt3(N) when erlang:is_integer(N), N >= 0 ->
    isqrt3(N, 1, N).

isqrt3(_N, Low, High) when High =:= Low + 1 ->
    Low;

isqrt3(N, Low, High) ->
    Mid = (Low + High) div 2,
    MidSqr = Mid * Mid,
    if
        %% This also catches N = 0 or 1
        MidSqr =:= N ->
            Mid;
        MidSqr < N ->
            isqrt3(N, Mid, High);
        MidSqr > N ->
            isqrt3(N, Low, Mid)
    end.

变体2(修改上面的代码,使边界与Mid+1或Mid-1一起移动,参考Vikram Bhat的答案

isqrt3a(N) when erlang:is_integer(N), N >= 0 ->
    isqrt3a(N, 1, N).

isqrt3a(N, Low, High) when Low >= High ->
    HighSqr = High * High,
    if
        HighSqr > N ->
            High - 1;
        HighSqr =< N ->
            High
    end;

isqrt3a(N, Low, High) ->
    Mid = (Low + High) div 2,
    MidSqr = Mid * Mid,
    if
        %% This also catches N = 0 or 1
        MidSqr =:= N ->
            Mid;
        MidSqr < N ->
            isqrt3a(N, Mid + 1, High);
        MidSqr > N ->
            isqrt3a(N, Low, Mid - 1)
    end.

问题

现在它可以快速求解79位数字(即1111111111111111111111111111111111111111 * 1111111111111111111111111111111111111111),结果会立即显示。但是,在我的机器上解决一百万(1,000,000)个61位数字(即从10000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000001000000)需要60秒(+-2秒)。我希望能更快地完成。


[试验4:仅使用大整数+div的牛顿法]

代码

isqrt4(0) -> 0;

isqrt4(N) when erlang:is_integer(N), N >= 0 ->
    isqrt4(N, N).

isqrt4(N, Xk) ->
    Xk1 = (Xk + N div Xk) div 2,
    if
        Xk1 >= Xk ->
            Xk;
        Xk1 < Xk ->
            isqrt4(N, Xk1)
    end.

使用C/C++编写代码(仅供参考)

递归变体

#include <stdint.h>

uint32_t isqrt_impl(
    uint64_t const n,
    uint64_t const xk)
{
    uint64_t const xk1 = (xk + n / xk) / 2;
    return (xk1 >= xk) ? xk : isqrt_impl(n, xk1);
}

uint32_t isqrt(uint64_t const n)
{
    if (n == 0) return 0;
    if (n == 18446744073709551615ULL) return 4294967295U;
    return isqrt_impl(n, n);
}

迭代变体

#include <stdint.h>

uint32_t isqrt_iterative(uint64_t const n)
{
    uint64_t xk = n;
    if (n == 0) return 0;
    if (n == 18446744073709551615ULL) return 4294967295U;
    do
    {
        uint64_t const xk1 = (xk + n / xk) / 2;
        if (xk1 >= xk)
        {
            return xk;
        }
        else
        {
            xk = xk1;
        }
    } while (1);
}

问题

在我的机器上,Erlang代码可以在40秒内解决一百万个61位数字(+ - 1秒),因此比Trial 3更快。它能更快吗?


关于我的机器

处理器: 3.4 GHz Intel Core i7

内存: 32 GB 1600 MHz DDR3

操作系统: Mac OS X Version 10.9.1


相关问题

Python中的整数平方根

  • 用户448810的答案使用了"牛顿法"。我不确定是否可以使用"整数除法"来进行除法运算。稍后我会更新尝试。 [UPDATE(2015-01-11): 可以这样做]

  • 数学的答案涉及使用一个第三方的Python包,这对我来说并不是很好,因为我主要是想用Erlang内置的工具来解决问题。

  • DSM的答案似乎很有趣。我真的不明白正在发生什么,但它似乎涉及"位运算",所以对我来说并不是很合适。

元整数平方根中的无限递归

  • 这个问题是关于C++的,AraK(提问者)的算法看起来与上面的Trial 2的算法类似。

在Python中,// 是整数除法,因此对于 x^2 - n = 0 的牛顿迭代法是可行的。 - Blender
@Blender:我知道Python中的//是什么意思,因此我怀疑它是否有效,因为维基百科上发布的算法(以及任何与“牛顿法”有关的内容)都适用于实数(请注意浮点数的脚注)。 - Siu Ching Pong -Asuka Kenji-
使用整数除法没有任何区别。 序列 {x_n} 具有下界为isqrt(n)的递减趋势。 - Blender
@Blender:直接应用维基百科中的算法,并将除法替换为整数除法,不会得到正确的结果。由于维基百科的终止条件是abs(x [k + 1] - x [k])<1,而我们正在进行整数运算,为了满足条件,x [k + 1]必须等于x[k]。因此,当算法应用于输入3时:x0 = 3,x1 =(3 + 3 div 3)div 2 = 2,x2 =(2 + 3 div 2)div 2 = 1,x3 =(1 + 3 div 1)div 2 = 2,这导致无限循环。为了纠正这个问题,在x[k+1] >= x[k]时终止(因为你提到的原因)。 - Siu Ching Pong -Asuka Kenji-
啊,你说得对。我看了第一个答案中的代码,它使用 x_{n_1} >= x_n 作为停止条件。我认为维基百科的例子可以更新以考虑整数和浮点数除法。 - Blender
显示剩余3条评论
1个回答

2
以下是类似二分查找的算法,不需要浮点数除法,只需要整数乘法(比牛顿迭代法慢):
low = 1;

/* More efficient bound

high = pow(10,log10(target)/2+1);

*/


high = target


while(low<high) {

 mid = (low+high)/2;
 currsq = mid*mid;

 if(currsq==target) {
    return(mid);
 }

 if(currsq<target) {

      if((mid+1)*(mid+1)>target) {
             return(mid);
      }    
      low =  mid+1;
  }

 else {

     high = mid-1;
 }

}

这个程序需要 O(logN) 次迭代,因此即使对于非常大的数字也不会无限运行。
如果需要计算 targetlog10 值:
acc = target

log10 = 0;

while(acc>0) {

  log10 = log10 + 1;
  acc = acc/10;
}

注意:acc/10是整数除法。

编辑:

有效边界:sqrt(n)的位数约为n的一半,因此您可以传递high = 10^(log10(N)/2+1)low = 10^(log10(N)/2-1)以获得更紧密的边界,并且应该提供2倍的速度提升。

评估边界:

bound = 1;
acc = N;
count = 0;
while(acc>0) {

 acc = acc/10;

 if(count%2==0) {

    bound = bound*10;
 }

 count++;

}

high = bound*10;
low = bound/10;
isqrt(N,low,high);

我之前已经实现过这个。我会检查这个或牛顿法哪个更好。问题在于乘法是在大整数上进行的,而不是在机器整数上进行的,因此这是一个昂贵的操作。 - Siu Ching Pong -Asuka Kenji-
@AsukaKenji-SiuChingPong- 整数乘法大约需要O(logN^2)的时间,因此这个算法将会是O(logN^3),对于你的应用来说仍然相当快速。 - Vikram Bhat
请问一下 high 的初始值是什么?在您最初的回答中,它是 target,然后您进行了两次编辑,变成了 pow(10,log(target)/2+1)。为什么会这样呢?此外,由于 target 是一个大整数(比 long long double 能够容纳的还要大),我担心您需要像 ipow()ilog() 这样的函数来完成这个任务(这需要另一个问题),因为 C <math.h> 库无法处理它。 - Siu Ching Pong -Asuka Kenji-
好的,请让我澄清我的问题...我的意思是,你是如何得出初始值 high = pow(10,log10(N)/2+1) 的?这意味着:当 N = 10 时,high = 10^(1.5) = 31;当 N = 100 时,high = 10^(2) = 100;当 N = 1000 时,high = 10^(2.5) = 316;当 N = 10000 时,high = 10^(3) = 1000;等等。为什么? - Siu Ching Pong -Asuka Kenji-
@AsukaKenji-SiuChingPong- 所有的除法都是整数除法,因此例如当target = 10^(2N)时,high = 10^(N+1),我们知道sqrt(target) = 10^N,因此通过这个界限更紧密,因此可以减少许多乘法。 - Vikram Bhat
显示剩余3条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接