为什么 x 的三次方比 x 的平方乘以 x 慢?

28
在NumPy中,x*x*x比x**3甚至np.power(x, 3)快一个数量级。
x = np.random.rand(1e6)
%timeit x**3
100 loops, best of 3: 7.07 ms per loop

%timeit x*x*x
10000 loops, best of 3: 163 µs per loop

%timeit np.power(x, 3)
100 loops, best of 3: 7.15 ms per loop

有什么想法可以解释这种行为吗?据我所知,这三种方式产生的输出都是相同的(通过 np.allclose 进行验证)。


整数与浮点数的计算,也许是这个问题吗? - Martijn Pieters
1
@RohitJain 我认为那个链接并不是特别有用。那个问题的被接受答案是“使用numpy”,而且那个问题是关于纯Python代码,而不是NumPy的。 - user395760
1
@delnam 不要看已被接受的答案,看排名最高的答案。 - cmd
@cmd 最高评分的答案基本上是错误的。取幂大致为 O(1),因为x**y被重写为 2**(y*log x)。在现代处理器上,2**alog a都是单浮点指令。 - Jeffrey Sax
6个回答

36
根据这个答案,指数运算的实现存在一些开销,而乘法则没有。然而,朴素的乘法会随着指数的增加越来越慢。以下是一个经验性的演示:
 In [3]: x = np.random.rand(1e6)

 In [15]: %timeit x**2
 100 loops, best of 3: 11.9 ms per loop

 In [16]: %timeit x*x
 100 loops, best of 3: 12.7 ms per loop

 In [17]: %timeit x**3
 10 loops, best of 3: 132 ms per loop

 In [18]: %timeit x*x*x
 10 loops, best of 3: 27.2 ms per loop

 In [19]: %timeit x**4
 10 loops, best of 3: 132 ms per loop

 In [20]: %timeit x*x*x*x
 10 loops, best of 3: 42.4 ms per loop

 In [21]: %timeit x**10
 10 loops, best of 3: 132 ms per loop

 In [22]: %timeit x*x*x*x*x*x*x*x*x*x
 10 loops, best of 3: 137 ms per loop

 In [24]: %timeit x**15
 10 loops, best of 3: 132 ms per loop

 In [25]: %timeit x*x*x*x*x*x*x*x*x*x*x*x*x*x*x
 1 loops, best of 3: 212 ms per loop

请注意,指数计算时间基本保持不变,除了x ** 2这种情况,我怀疑它被特殊处理了,而乘法越来越慢。似乎可以利用这一点来加速整数幂运算...例如:

In [26]: %timeit x**16
10 loops, best of 3: 132 ms per loop

In [27]: %timeit x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x
1 loops, best of 3: 225 ms per loop

In [28]: def tosixteenth(x):
   ....:     x2 = x*x
   ....:     x4 = x2*x2
   ....:     x8 = x4*x4
   ....:     x16 = x8*x8
   ....:     return x16
   ....:

In [29]: %timeit tosixteenth(x)
10 loops, best of 3: 49.5 ms per loop

似乎您可以将此技术应用于任何整数,通过将其拆分为二的幂的和,针对每个二的幂进行如上计算,并相加:

In [93]: %paste
def smartintexp(x, exp):
    result = np.ones(len(x))
    curexp = np.array(x)
    while True:
        if exp%2 == 1:
            result *= curexp
        exp >>= 1
        if not exp: break
        curexp *= curexp
    return result
## -- End pasted text --

In [94]: x
Out[94]:
array([ 0.0163407 ,  0.57694587,  0.47336487, ...,  0.70255032,
        0.62043303,  0.0796748 ])

In [99]: x**21
Out[99]:
array([  3.01080670e-38,   9.63466181e-06,   1.51048544e-07, ...,
         6.02873388e-04,   4.43193256e-05,   8.46721060e-24])

In [100]: smartintexp(x, 21)
Out[100]:
array([  3.01080670e-38,   9.63466181e-06,   1.51048544e-07, ...,
         6.02873388e-04,   4.43193256e-05,   8.46721060e-24])

In [101]: %timeit x**21
10 loops, best of 3: 132 ms per loop

In [102]: %timeit smartintexp(x, 21)
10 loops, best of 3: 70.7 ms per loop

对于2的小次幂,它速度很快:

In [106]: %timeit x**32
10 loops, best of 3: 131 ms per loop

In [107]: %timeit smartintexp(x, 32)
10 loops, best of 3: 57.4 ms per loop

但随着指数的增大而变得越来越慢:

In [97]: %timeit x**63
10 loops, best of 3: 133 ms per loop

In [98]: %timeit smartintexp(x, 63)
10 loops, best of 3: 110 ms per loop

对于最坏情况并不更快:

In [115]: %timeit x**511
10 loops, best of 3: 135 ms per loop

In [114]: %timeit smartintexp(x, 511)
10 loops, best of 3: 192 ms per loop

8
你刚刚发现了平方取幂法... - Jaime
1
@Jaime:确实(我已经知道这个存在了),我想知道为什么numpy不以这种方式处理整数指数,直到某个特定的大小...这似乎是一个非常容易提速的方法。 - Claudiu
1
@Claudiu 一个可能的原因是,几乎任何类型的浮点算术重新排序或重新组合都会以微妙的方式改变结果,对于相当多的用例来说,这是不可接受的。请参见https://dev59.com/Fmw15IYBdhLWcg3wu-M9 - user395760
1
@delnan:啊,也许是这样。也许pow有标准的期望,如果你想用另一种方式(比如平方指数)来实现它,你可以自己实现(就像我在这里做的那样)。 - Claudiu
Python 2.7.5 平方 5. 后的结果是偶数;由于类型强制转换,整数 5 (int) 比浮点数 5. (float) 稍微慢一些。 - Dima Tisnek

7
作为一条提示,如果您正在计算幂并担心速度问题:
x = np.random.rand(5e7)

%timeit x*x*x
1 loops, best of 3: 522 ms per loop

%timeit np.einsum('i,i,i->i',x,x,x)
1 loops, best of 3: 288 ms per loop

为什么einsum更快仍然是一个问题(我的提问)的悬而未决。虽然这可能是由于einsum能够使用SSE2,而numpy的ufuncs直到1.8版本才能使用。

就地计算甚至更快:

def calc_power(arr):
    for x in xrange(arr.shape[0]):
        arr[x]=arr[x]*arr[x]*arr[x]
numba_power = autojit(calc_power)

%timeit numba_power(x)
10 loops, best of 3: 51.5 ms per loop

%timeit np.einsum('i,i,i->i',x,x,x,out=x)
10 loops, best of 3: 111 ms per loop

%timeit np.power(x,3,out=x)
1 loops, best of 3: 609 ms per loop

这非常有帮助,谢谢! - uhoh

3

我认为这是因为x ** y必须处理通用情况,其中xy都是浮点数。 从数学上讲,我们可以写成x ** y = exp(y * log(x))。根据您的示例,我发现

x = np.random.rand(1e6)
%timeit x**3
10 loops, best of 3: 178 ms per loop

%timeit np.exp(3*np.log(x))
10 loops, best of 3: 176 ms per loop

我没有检查过实际的numpy代码,但它必须在内部执行类似以下的操作。


1
这是因为在Python中,幂运算是作为浮点操作执行的(对于使用C的numpy也是如此)。
在C语言中,pow函数提供了3种方法:

double pow (double x, double y)

long powl (long double x, long double y)

float powf (float x, float y)

这些都是浮点运算。

如果x是浮点数,那么在两种情况下都会进行浮点运算。您可以进一步解释您的答案。 - cmd

0
根据规范

两个参数的 pow(x, y) 形式等同于使用幂运算符:x**y。

参数必须具有数字类型。对于混合操作数类型,二元算术运算符的强制转换规则适用。

换句话说:由于 x 是浮点数,指数从整数转换为浮点数,并执行通用浮点幂运算。在内部,这通常被重写为:
x**y = 2**(y*lg(x))

2**alg a(以2为底的对数a)是现代处理器上的单个指令,但仍需要比几次乘法更长的时间。


-1
timeit np.multiply(np.multiply(x,x),x)

times the same as x*x*x. 我猜测 np.multiply 使用了快速的Fortran线性代数包,例如BLAS。我从另一个问题中得知,numpy.dot 在某些情况下使用BLAS。


我必须收回之前的说法。 np.dot(x,x)np.sum(x*x) 快3倍。因此,使用np.multiply并不能始终保持与BLAS的速度优势。


使用我的numpy(时间会因机器和可用库而异)

np.power(x,3.1)
np.exp(3.1*np.log(x))

需要大约相同的时间,但是

np.power(x,3)

速度是2倍。虽然不如x*x*x快,但仍比一般的幂函数快。因此它正在利用整数幂的某些优势。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接