How to compute the nth root of a very large integer

33

I need a way to compute the nth root of a long integer in Python.

I tried pow(m, 1.0/n), but it doesn't work:

OverflowError: long int too large to convert to float

Any ideas?

By long integer I mean REALLY long integers like:

11968003966030964356885611480383408833172346450467339251 196093144141045683463085291115677488411620264826942334897996389 485046262847265769280883237649461122479734279424416861834396522 819159219215308460065265520143082728303864638821979329804885526 557893649662037092457130509980883789368448042961108430809620626 059287437887495827369474189818588006905358793385574832590121472 680866521970802708379837148646191567765584039175249171110593159 305029014037881475265618958103073425958633163441030267478942720 703134493880117805010891574606323700178176718412858948243785754 898788359757528163558061136758276299059029113119763557411729353 915848889261125855717014320045292143759177464380434854573300054 940683350937992500211758727939459249163046465047204851616590276 724564411037216844005877918224201569391107769029955591465502737 961776799311859881060956465198859727495735498887960494256488224 613682478900505821893815926193600121890632


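As David hints, pow(n, 1/3) will give the cube root (i.e., the 3rd root) of n. - Brian
5
No, it won't, because 1/3 equals 0 in Python versions before 3. - Matthew Schinckel
(But that isn't what the OP wants anyway.) - Matthew Schinckel
Py3 has no integer limit... ints can grow without bound until memory runs out. I tested this on my installation. Here is a solution. - user3917838
12 Answers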

29

If it's a REALLY big number, you could use a binary search.

def find_invpow(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    high = 1
    while high ** n <= x:
        high *= 2
    low = high // 2  # integer division, so the bounds stay ints in Python 3
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1
For example:
>>> x = 237734537465873465
>>> n = 5
>>> y = find_invpow(x,n)
>>> y
2986
>>> y**n <= x <= (y+1)**n
True
>>>
>>> x = 119680039660309643568856114803834088331723464504673392511960931441>
>>> n = 45
>>> y = find_invpow(x,n)
>>> y
227661383982863143360L
>>> y**n <= x < (y+1)**n
True
>>> find_invpow(y**n,n) == y
True
>>>

3
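It still fails if the number is greater than about 10**1000. Changing mid = (low + high) // 2 to mid = int((low + high) // 2) + 1 fixes it. - Attila O.
3
Downvoted because of a bug. I tried find_invpow(64, 3) and got 3, whereas the correct answer is 4. - Elias Zamaria
1
If you want to end up with an integer, the line low = high/2 should be changed to low = high // 2 - drhagen
Binary search seems like a rather slow approach. Prefer Newton's iterations, which double the number of exact digits on every step. To get a very accurate starting value, use floating point. - user1196549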
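
The Newton iteration mentioned in the last comment can be sketched with pure integer arithmetic. This is only a minimal sketch under stated assumptions: the name integer_nth_root is hypothetical, and instead of the floating-point starting value the comment suggests, it starts from a power of two derived from bit_length(), which is known to lie at or above the true root.

```python
def integer_nth_root(x, n):
    """Return the largest integer r with r ** n <= x (x >= 0, n >= 1)."""
    if x < 0 or n < 1:
        raise ValueError("expected x >= 0 and n >= 1")
    if x < 2 or n == 1:
        return x
    # Starting value: 2 ** ceil(bits / n), which is >= the true root.
    r = 1 << -(-x.bit_length() // n)
    while True:
        # Integer Newton step for f(r) = r**n - x.
        nxt = ((n - 1) * r + x // r ** (n - 1)) // n
        if nxt >= r:
            # The sequence stopped decreasing, so r is the floor root.
            return r
        r = nxt

print(integer_nth_root(64, 3))
```

Starting from above the root makes the iterates decrease monotonically, so the loop can stop as soon as a step fails to decrease.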

17

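Gmpy is a C-coded Python extension module that wraps the GMP library to provide Python code with fast multiprecision arithmetic (integer, rational, and float), random number generation, advanced number-theoretical functions, and more.

It includes a root function:

x.root(n): returns a 2-element tuple (y, m), where y is the (possibly truncated) n-th root of x; m, an ordinary Python int, is 1 if the root is exact (x==y**n), else 0. n must be an ordinary Python int, >=0.

For example, the 20th root: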

>>> import gmpy
>>> i0=11968003966030964356885611480383408833172346450467339251 
>>> m0=gmpy.mpz(i0)
>>> m0
mpz(11968003966030964356885611480383408833172346450467339251L)
>>> m0.root(20)
(mpz(567), 0)

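Using gmpy2 results in the following error: 'mpz' object has no attribute 'root'. - Zelphir Kaltstahl
1
gmpy2 uses a new mpfr type based on the MPFR library. gmpy2.root(x,n) -> mpfr returns the n-th root of x. The result is always an 'mpfr'. - gimel
2
@Zelphir gmpy2 has the gmpy2.iroot function for computing integer roots. - jiakai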
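
For readers on gmpy2, the gmpy2.iroot call named in the comment above might be used roughly as follows. This is a sketch, not part of the original answer: the wrapper name iroot and the pure-Python fallback are assumptions added so the snippet still runs when gmpy2 is not installed.

```python
try:
    import gmpy2  # third-party package; assumed to be installed separately
except ImportError:
    gmpy2 = None

def iroot(x, n):
    """Return (root, exact) for a non-negative int x, like gmpy2.iroot."""
    if gmpy2 is not None:
        root, exact = gmpy2.iroot(gmpy2.mpz(x), n)
        return int(root), bool(exact)
    # Fallback (assumption, not gmpy2 code): plain bisection on Python ints.
    lo, hi = 0, 1
    while hi ** n <= x:
        hi *= 2
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** n <= x:
            lo = mid
        else:
            hi = mid
    return lo, lo ** n == x

print(iroot(11968003966030964356885611480383408833172346450467339251, 20))
```

The second element plays the same role as m in the older x.root(n) tuple: it reports whether the root is exact.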

9

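If you're looking for something standard and quick to write, with high precision, I would use decimal and set the precision (getcontext().prec) to at least the length of x.

Code (Python 3.0)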

from decimal import *

x =   '11968003966030964356885611480383408833172346450467339251\
196093144141045683463085291115677488411620264826942334897996389\
485046262847265769280883237649461122479734279424416861834396522\
819159219215308460065265520143082728303864638821979329804885526\
557893649662037092457130509980883789368448042961108430809620626\
059287437887495827369474189818588006905358793385574832590121472\
680866521970802708379837148646191567765584039175249171110593159\
305029014037881475265618958103073425958633163441030267478942720\
703134493880117805010891574606323700178176718412858948243785754\
898788359757528163558061136758276299059029113119763557411729353\
915848889261125855717014320045292143759177464380434854573300054\
940683350937992500211758727939459249163046465047204851616590276\
724564411037216844005877918224201569391107769029955591465502737\
961776799311859881060956465198859727495735498887960494256488224\
613682478900505821893815926193600121890632'

minprec = 27
if len(x) > minprec: getcontext().prec = len(x)
else:                getcontext().prec = minprec

x = Decimal(x)
power = Decimal(1)/Decimal(3)

answer = x**power
ranswer = answer.quantize(Decimal('1.'), rounding=ROUND_UP)

diff = x - ranswer**Decimal(3)
if diff == Decimal(0):
    print("x is the cubic number of", ranswer)
else:
    print("x has a cubic root of ", answer)

Answer

x is the cubic number of 228739187861856353290568639617255215830231334114514523493181096276535406707619622159719944036700456144859737227246037981077199788136588570141900477426804900885328956669636985517099785027459017044337235675487994631296527067058736942742097287850418176190327742484882965377218610139128882473918261696612098418
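
Because decimal's power operation is not guaranteed to be correctly rounded, one cautious variation is to treat the decimal result as a guess and confirm it with exact integer comparisons. The following is a minimal sketch, assuming a hypothetical helper name decimal_nth_root and an arbitrary margin of 10 guard digits:

```python
from decimal import Decimal, localcontext

def decimal_nth_root(x, n):
    """Floor of the n-th root of a non-negative int x, via decimal."""
    with localcontext() as ctx:
        # Precision: every digit of x plus a few guard digits (assumed margin).
        ctx.prec = len(str(x)) + 10
        guess = int(Decimal(x) ** (Decimal(1) / Decimal(n)))
    # The guess may be off by one; correct it with exact integer arithmetic.
    while guess ** n > x:
        guess -= 1
    while (guess + 1) ** n <= x:
        guess += 1
    return guess

r = 10 ** 40 + 3
print(decimal_nth_root(r ** 3, 3) == r)
```

The two correction loops normally run zero or one time, so the cost is dominated by the decimal power.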


8
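You can make it run slightly faster by dropping the first while loop and instead setting low to 10 ** (len(str(x)) / n) and high to low * 10. Probably better still is to replace len(str(x)) with the bit length and use a bit shift. Based on my tests, I estimate a 5% speedup from the first change and a 25% speedup from the second. If the ints are big enough, this might matter (and the speedups may vary). Don't trust my code without testing it carefully. I did some basic testing, but I may have missed an edge case. Also, these speedups vary with the number chosen.

If the actual data you're using is much bigger than what you posted here, this change may be worthwhile.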

from timeit import Timer

def find_invpow(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    high = 1
    while high ** n <= x:  # "<=" (not "<") so exact powers such as 64 = 4**3 work
        high *= 2
    low = high // 2  # integer division keeps the bounds as ints
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1

def find_invpowAlt(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    # x has d = len(str(x)) digits, so its root has (d - 1) // n + 1 digits.
    # The earlier 10 ** (len(str(x)) / n) could start above the root.
    low = 10 ** ((len(str(x)) - 1) // n)
    high = low * 10

    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1

x = 237734537465873465
n = 5
tests = 10000

print("Norm", Timer('find_invpow(x,n)', 'from __main__ import find_invpow, x,n').timeit(number=tests))
print("Alt", Timer('find_invpowAlt(x,n)', 'from __main__ import find_invpowAlt, x,n').timeit(number=tests))

Norm 0.626754999161

Alt 0.566340923309
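
The bit-length variant suggested in this answer is described but not shown. A minimal sketch of how it might look, assuming a hypothetical function name find_invpow_bits and a plain bisection loop:

```python
def find_invpow_bits(x, n):
    """Floor of the n-th root of x >= 0, bracketing the root by bit length."""
    if x < 2:
        return x
    # x has b bits, so 2**(b-1) <= x < 2**b, and the root lies in
    # [2**((b-1)//n), 2**((b-1)//n + 1)).
    low = 1 << ((x.bit_length() - 1) // n)
    high = low << 1
    # Invariant: low ** n <= x < high ** n
    while high - low > 1:
        mid = (low + high) // 2
        if mid ** n <= x:
            low = mid
        else:
            high = mid
    return low

print(find_invpow_bits(237734537465873465, 5))
```

Using shifts avoids converting the integer to a decimal string, which is itself expensive for very large values.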


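Your second function, find_invpowAlt, gives a wildly wrong answer for x=118997879821732370764604711647724283139870175351576755860556891902958645241483485254092600557474860904935286687480039428945219115513349647465379580432922136155355040992635166676363150438436216219094913514982415747153956476970303302126880391024128871557664284712411567099374094385902892603751471822837746770111, n=3. - tzs
If you want to end up with an integer, the line low = high/2 should be changed to low = high // 2 - drhagen
I deliberately used low = high/2 because I copied that answer directly from Markus's code in order to benchmark it against my improvement. I'll note that Markus has since updated his code to correct a bug, and tzs has replied that my code has a bug as well. However, I no longer remember the details of the problem or of my solution, so I no longer feel capable of resolving these defects. - Brian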

3

Oh, for numbers that big, you would use the decimal module.

ns: your number as a string

ns = "11968003966030964356885611480383408833172346450467339251196093144141045683463085291115677488411620264826942334897996389485046262847265769280883237649461122479734279424416861834396522819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509980883789368448042961108430809620626059287437887495827369474189818588006905358793385574832590121472680866521970802708379837148646191567765584039175249171110593159305029014037881475265618958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858948243785754898788359757528163558061136758276299059029113119763557411729353915848889261125855717014320045292143759177464380434854573300054940683350937992500211758727939459249163046465047204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311859881060956465198859727495735498887960494256488224613682478900505821893815926193600121890632"
from decimal import Decimal
d = Decimal(ns)
one_third = Decimal("0.3333333333333333")
print(d ** one_third)

Answer: 2.287391878618402702753613056E+305

TZ pointed out that this isn't accurate, and he's right. Here's my test.

from decimal import Decimal

def nth_root(num_decimal, n_integer):
    exponent = Decimal("1.0") / Decimal(n_integer)
    return num_decimal ** exponent

def test():
    ns = "11968003966030964356885611480383408833172346450467339251196093144141045683463085291115677488411620264826942334897996389485046262847265769280883237649461122479734279424416861834396522819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509980883789368448042961108430809620626059287437887495827369474189818588006905358793385574832590121472680866521970802708379837148646191567765584039175249171110593159305029014037881475265618958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858948243785754898788359757528163558061136758276299059029113119763557411729353915848889261125855717014320045292143759177464380434854573300054940683350937992500211758727939459249163046465047204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311859881060956465198859727495735498887960494256488224613682478900505821893815926193600121890632"
    nd = Decimal(ns)
    cube_root = nth_root(nd, 3)
    print((cube_root ** Decimal("3.0")) - nd)

if __name__ == "__main__":
    test()

The result is off by about 10**891.
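
An error of that size is what one would expect from the default decimal context, which keeps only 28 significant digits. The sketch below illustrates the effect on a smaller made-up value (the number 10 ** 50 + 7 and the 10 guard digits are arbitrary choices, and the snippet assumes the default context has not been modified):

```python
from decimal import Decimal, localcontext

root = 10 ** 50 + 7          # a 51-digit integer
cube = root ** 3             # its exact cube

# Default context: 28 significant digits, so the low digits are lost.
approx = Decimal(cube) ** (Decimal(1) / Decimal(3))
print(int(approx) == root)   # False

# Enough precision to hold every digit of the input, plus guard digits.
with localcontext() as ctx:
    ctx.prec = len(str(cube)) + 10
    better = Decimal(cube) ** (Decimal(1) / Decimal(3))
    recovered = int(better.to_integral_value())
print(recovered == root)     # True
```

Note that the exponent Decimal(1) / Decimal(3) is itself rounded to the context precision, so it also has to be computed inside the high-precision context.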


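Hmm, this might work, but it isn't accurate. Using your terms, if answer = d ** one_third, what should (answer ** 3 - d) be? - tzot
Decimal floating point is precise enough for your needs... my 0.333 string was just for brevity. - Jim Carroll
tz... you're right. It's way off... oh well. The Newton's method above sure is neat, though! - Jim Carroll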

2
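I can offer four methods for solving your task. The first is based on binary search. The second is based on Newton's method. The third is based on the shifting n-th root algorithm. The fourth is what I call the chord-tangent method, described in the picture included here.
Binary search is already implemented in many of the answers here. I just present my own view of it and my own implementation.
As an alternative I also implemented an optimized binary search method (marked Opt). When computing the k-th root, it starts from the range [hi / 2, hi), where hi equals 2^(num_bit_length / k).
Newton's method is new here; as far as I can see, no other answer implements it. It is usually considered faster than binary search, although the timings from my code show no speedup. So this method is included just for reference/interest.
The shifting method is 30-50% faster than the optimized binary search, and it should be faster still if implemented in C++, because C++ has fast 64-bit arithmetic, which this method partially uses.
Chord-tangent method:

[Image: sketch of the chord-tangent method]

The chord-tangent method was invented by me on paper (see the picture above); it is inspired by Newton's method and improves on it. Basically, I draw a chord and a tangent line and find where each intersects the horizontal line y=n; these two intersections give lower-bound and upper-bound approximations of the root solution (x0,n), where n=x0 ^ k. This method turned out to be the fastest: for the 8192-bit case, all other methods need more than 2000 iterations, while this one needs just 8. So it is 200-300 times faster than the previous fastest method (shifting).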
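
To make the geometry concrete, a single chord-tangent step could be sketched as follows. This is a simplified illustration, not the author's full implementation: the function name chord_tangent_step is hypothetical, it assumes k >= 2, lo >= 1 and lo ** k <= n < hi ** k, and it clamps the upper bound with min(), which is an added safeguard.

```python
def chord_tangent_step(n, k, lo, hi):
    """One bracketing step; assumes k >= 2, lo >= 1, lo**k <= n < hi**k."""
    y_lo, y_hi = lo ** k, hi ** k
    # The chord through (lo, y_lo) and (hi, y_hi) meets y = n at or left
    # of the root, because x**k is convex: this is the new lower bound.
    new_lo = lo + (n - y_lo) * (hi - lo) // (y_hi - y_lo)
    # The tangent at new_lo meets y = n at or right of the root; the "+ 1"
    # makes up for floor division so the bound stays strictly above it.
    new_hi = new_lo + (n - new_lo ** k) // (k * new_lo ** (k - 1)) + 1
    return new_lo, min(new_hi, hi)

n, k = 10 ** 60 + 12345, 3
bits = n.bit_length()
lo = 1 << ((bits - 1) // k)      # lo ** k <= n
hi = lo << 1                     # hi ** k > n
new_lo, new_hi = chord_tangent_step(n, k, lo, hi)
print("bracket width before:", hi - lo)
print("bracket width after: ", new_hi - new_lo)
```

Repeating the step shrinks the bracket quickly once the bounds are close to the root. A single step is not guaranteed to make progress in every configuration, so a robust loop would likely need a bisection fallback.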

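As an example, I generate a really huge random 8192-bit integer and measure how long each method takes to find its cube root.

In the test() function you can see that I pass k = 3 as the root's power (cube root); you can pass any power instead of 3.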


def binary_search(begin, end, f, *, niter = [0]):
    while begin < end:
        niter[0] += 1
        mid = (begin + end) >> 1
        if f(mid):
            begin = mid + 1
        else:
            end = mid
    return begin

def binary_search_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Binary_search_algorithm
    niter = [0]
    res = binary_search(0, n + 1, lambda root: root ** k < n, niter = niter)
    if verbose:
        print('Binary Search iterations:', niter[0])
    return res

def binary_search_opt_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Binary_search_algorithm
    niter = [0]
    hi = 1 << (n.bit_length() // k - 1)
    while hi ** k <= n:
        niter[0] += 1
        hi <<= 1
    res = binary_search(hi >> 1, hi, lambda root: root ** k < n, niter = niter)
    if verbose:
        print('Binary Search Opt iterations:', niter[0])
    return res

def newton_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Newton%27s_method
    f = lambda x: x ** k - n
    df = lambda x: k * x ** (k - 1)
    x, px, niter = n, 2 * n, [0]
    while abs(px - x) > 1:
        niter[0] += 1
        px = x
        x -= f(x) // df(x)
    if verbose:
        print('Newton Method iterations:', niter[0])
    mini, minv = None, None
    for i in range(-2, 3):
        v = abs(f(x + i))
        if minv is None or v < minv:
            mini, minv = i, v
    return x + mini

def shifting_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Shifting_nth_root_algorithm
    B_bits = 64
    
    r, y = 0, 0
    B = 1 << B_bits
    Bk_bits = B_bits * k
    Bk_mask = (1 << Bk_bits) - 1
    niter = [0]
    
    for i in range((n.bit_length() + Bk_bits - 1) // Bk_bits - 1, -1, -1):
        alpha = (n >> (i * Bk_bits)) & Bk_mask
        B_y = y << B_bits
        Bk_yk = (y ** k) << Bk_bits
        Bk_r_alpha = (r << Bk_bits) + alpha
        Bk_yk_Bk_r_alpha = Bk_yk + Bk_r_alpha
        beta = binary_search(1, B, lambda beta: (B_y + beta) ** k <= Bk_yk_Bk_r_alpha, niter = niter) - 1
        y, r = B_y + beta, Bk_r_alpha - ((B_y + beta) ** k - Bk_yk)

    if verbose:
        print('Shifting Method iterations:', niter[0])

    return y

def chord_tangent_kth_root(n, k, *, verbose = False):
    niter = [0]
    hi = 1 << (n.bit_length() // k - 1)
    while hi ** k <= n:
        niter[0] += 1
        hi <<= 1
    f = lambda x: x ** k
    df = lambda x: k * x ** (k - 1)
    # https://istack.dev59.com/et9O0.webp
    x_begin, x_end = hi >> 1, hi
    y_begin, y_end = f(x_begin), f(x_end)
    for icycle in range(1 << 30):
        if x_end - x_begin <= 1:
            break
        niter[0] += 1
        if 0: # Do Binary Search step if needed
            x_mid = (x_begin + x_end) >> 1
            y_mid = f(x_mid)
            if y_mid > n:
                x_end, y_end = x_mid, y_mid
            else:
                x_begin, y_begin = x_mid, y_mid
        # (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
        x_n = x_begin + (n - y_begin) * (x_end - x_begin) // (y_end - y_begin)
        y_n = f(x_n)
        tangent_x = x_n + (n - y_n) // df(x_n) + 1
        
        chord_x = x_n + (n - y_n) * (x_end - x_n) // (y_end - y_n)
        
        assert chord_x <= tangent_x, (chord_x, tangent_x)
        x_begin, x_end = chord_x, tangent_x
        y_begin, y_end = f(x_begin), f(x_end)
        assert y_begin <= n, (chord_x, y_begin, n, n - y_begin)
        assert y_end > n, (icycle, tangent_x - binary_search_kth_root(n, k), y_end, n, y_end - n)
    if verbose:
        print('Chord Tangent Method iterations:', niter[0])
    return x_begin

def test():
    import random, timeit
    
    nruns = 3
    bits = 8192
    n = random.randrange(1 << (bits - 1), 1 << bits)
    
    a = binary_search_kth_root(n, 3, verbose = True)
    b = binary_search_opt_kth_root(n, 3, verbose = True)
    c = newton_kth_root(n, 3, verbose = True)
    d = shifting_kth_root(n, 3, verbose = True)
    e = chord_tangent_kth_root(n, 3, verbose = True)
    assert abs(a - b) <= 0 and abs(a - c) <= 1 and abs(a - d) <= 1 and abs(a - e) <= 1, (a - b, a - c, a - d, a - e)

    print()
    print('Binary Search timing:', round(timeit.timeit(lambda: binary_search_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Binary Search Opt timing:', round(timeit.timeit(lambda: binary_search_opt_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Newton Method timing:', round(timeit.timeit(lambda: newton_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Shifting Method timing:', round(timeit.timeit(lambda: shifting_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Chord Tangent Method timing:', round(timeit.timeit(lambda: chord_tangent_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')

if __name__ == '__main__':
    test()

Output:

Binary Search iterations: 8192
Binary Search Opt iterations: 2732
Newton Method iterations: 9348
Shifting Method iterations: 2752
Chord Tangent Method iterations: 8

Binary Search timing: 0.506 sec
Binary Search Opt timing: 0.05 sec
Newton Method timing: 2.09 sec
Shifting Method timing: 0.03 sec
Chord Tangent Method timing: 0.001 sec

Why do you give niter a default value? - schuelermine

2

1
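I came up with my own answer, which takes @Mahmoud Kassem's idea, simplifies the code, and makes it more reusable:
import decimal

def cube_root(x):
    return decimal.Decimal(x) ** (decimal.Decimal(1) / decimal.Decimal(3))

I tested it in Python 3.5.1 and Python 2.7.8, and it seems to work fine.
The result will have as many digits as the decimal context the function runs in specifies, which by default is 28 significant digits. According to the documentation for the power function in the decimal module, "The result is well-defined but only “almost always correctly-rounded”." If you need a more precise result, you can do it as follows: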
with decimal.localcontext() as context:
    context.prec = 50
    print(cube_root(42))

0
In older versions of Python, 1/3 is equal to 0. In Python 3.0, 1/3 is equal to 0.33333333333 (and 1//3 is equal to 0).
So either change your code to use 1/3.0 or switch to Python 3.0.
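
A quick check of the division behaviour described above, run under Python 3. The sketch also shows that a correct float exponent alone does not help with the integers in the question, because the base still has to be converted to a float (the value 10 ** 400 is just an illustrative number beyond float range):

```python
# Python 3 semantics: "/" is true division, "//" is floor division.
print(1 / 3)    # 0.3333333333333333
print(1 // 3)   # 0

# pow() with a float exponent converts the integer base to float first,
# which fails for integers beyond the float range (about 1.8e308).
overflowed = False
try:
    pow(10 ** 400, 1 / 3)
except OverflowError as exc:
    overflowed = True
    print("OverflowError:", exc)
```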

As far as I can tell, there is no indication that it was using Python 2. - Solomon Ucko
@SolomonUcko The question and this answer were both posted within a week of Python 3's release. - Brian

0
You can use SymPy for this.
In [29]: n = 1196800396603096435688561148038340883317234645046733925119609314414104568346308529111567
    ...: 74884116202648269423348979963894850462628472657692808832376494611224797342794244168618343965
    ...: 22819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509
    ...: 98088378936844804296110843080962062605928743788749582736947418981858800690535879338557483259
    ...: 01214726808665219708027083798371486461915677655840391752491711105931593050290140378814752656
    ...: 18958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858
    ...: 94824378575489878835975752816355806113675827629905902911311976355741172935391584888926112585
    ...: 57170143200452921437591774643804348545733000549406833509379925002117587279394592491630464650
    ...: 47204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311
    ...: 85988106095646519885972749573549888796049425648822461368247890050582189381592619360012189063
    ...: 2

In [30]: import sympy

In [31]: r = sympy.Integer(n) ** sympy.Rational(1, 3)

In [32]: print(r)
228739187861856353290568639617255215830231334114514523493181096276535406707619622159719944036700456144859737227246037981077199788136588570141900477426804900885328956669636985517099785027459017044337235675487994631296527067058736942742097287850418176190327742484882965377218610139128882473918261696612098418

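This just uses SymPy's general simplification of symbolic expressions. The above gives an exact result when possible; otherwise it gives an expression such as: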
In [33]: r2 = sympy.Integer(10) ** sympy.Rational(1, 3)

In [34]: r2
Out[34]: 
3 ____
╲╱ 10 

In [35]: r2.evalf(50)
Out[35]: 2.1544346900318837217592935665193504952593449421921

SymPy also has a dedicated function for computing integer n-th roots:
In [38]: sympy.integer_nthroot(n, 3)
Out[38]: 
(228739187861856353290568639617255215830231334114514523493181096276535406707619622159719944036700456144859737227246037981077199788136588570141900477426804900885328956669636985517099785027459017044337235675487994631296527067058736942742097287850418176190327742484882965377218610139128882473918261696612098418,
 True)

Here True indicates that the result is an exact cube root. When no exact root exists, the function returns the floor of the cube root.
In [39]: sympy.integer_nthroot(10, 3)
Out[39]: (2, False)

In other words, 2 is the largest integer such that 2**3 <= 10, but False indicates that 2**3 != 10.
