在numba nopython函数中计算阶乘的最快方法

4

我有一个函数,想要使用 numba 进行编译,但是需要在该函数内部计算阶乘。不幸的是,numba 不支持 math.factorial

import math
import numba as nb

@nb.njit
def factorial1(x):
    return math.factorial(x)

factorial1(10)
# UntypedAttributeError: Failed at nopython (nopython frontend)

我发现它支持math.gamma(可用于计算阶乘),但与实际的math.gamma函数相反,它不返回表示“整数值”的浮点数:

@nb.njit
def factorial2(x):
    return math.gamma(x+1)

factorial2(10)
# 3628799.9999999995  <-- not exact

math.gamma(11)
# 3628800.0  <-- exact

相比于math.factorial,它的速度较慢:

%timeit factorial2(10)
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit math.factorial(10)
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

所以我决定定义自己的函数:

@nb.njit
def factorial3(x):
    n = 1
    for i in range(2, x+1):
        n *= i
    return n

factorial3(10)
# 3628800

%timeit factorial3(10)
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

虽然它的速度比math.factorial慢,但比基于math.gamma的numba函数快,并且值是“精确”的。

所以我正在寻找一种最快的方法,在一个nopython numba函数内计算正整数(<=20;为避免溢出)的factorial


2
如果你只关心整数0..20的阶乘,那么查找表可能值得检查以提高速度。 - High Performance Mark
啊,我之前的评论中写成了“your”,实际应该是“you're”。或者,“如果你唯一关心的是…”。 - High Performance Mark
你可以尝试在numba中重新实现Python的方法-它会进行一些额外的步骤以特定的方式排序乘法-https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1275 - chrisb
1个回答

3

对于小于等于20的值,Python使用了一个查找表,正如评论中建议的那样。 https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452

LOOKUP_TABLE = np.array([
    1, 1, 2, 6, 24, 120, 720, 5040, 40320,
    362880, 3628800, 39916800, 479001600,
    6227020800, 87178291200, 1307674368000,
    20922789888000, 355687428096000, 6402373705728000,
    121645100408832000, 2432902008176640000], dtype='int64')

@nb.jit
def fast_factorial(n):
    if n > 20:
        raise ValueError
    return LOOKUP_TABLE[n]

如果使用 Python 调用,由于 Numba 的调度开销,速度会略慢于 Python 版本。

In [58]: %timeit math.factorial(10)
10000000 loops, best of 3: 79.4 ns per loop

In [59]: %timeit fast_factorial(10)
10000000 loops, best of 3: 173 ns per loop

但是在另一个Numba函数内调用可以更快。

def loop_python():
    for i in range(10000):
        for n in range(21):
            math.factorial(n)

@nb.njit
def loop_numba():
    for i in range(10000):
        for n in range(21):
            fast_factorial(n)

In [65]: %timeit loop_python()
10 loops, best of 3: 36.7 ms per loop

In [66]: %timeit loop_numba()
10000000 loops, best of 3: 73.6 ns per loop

1
Numba会进行激进的循环优化,因此如果您不保存fast_factorial的结果,它甚至不会执行循环。 - MSeifert

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接