Numpy:dot(a,b)和(a*b)。sum() 的区别

16

对于一维numpy数组,这两个表达式在理论上应该产生相同的结果:

(a*b).sum()/a.sum()
dot(a, b)/a.sum()

后者使用dot(),速度更快。但哪一个更准确?为什么?

以下是一些背景。

我想使用numpy计算样本的加权方差。 我在另一个答案中找到了dot()表达式,并有评论指出它应该更准确。然而,在那里没有任何解释。


1
这是加权平均吧?你可能只想使用 np.average - user2357112
我认为“数值精确”部分是指从值中减去平均值,而不是使用“点积”。 - user2357112
1个回答

10

Numpy dot是调用BLAS库的例程之一,该库在编译时链接(或构建自己的库)。这个重要性在于,BLAS库可以利用乘加操作(通常是融合乘加运算),从而限制计算执行的舍入次数。

接下来看这个例子:

>>> a=np.ones(1000,dtype=np.float128)+1E-14 
>>> (a*a).sum()  
1000.0000000000199948
>>> np.dot(a,a)
1000.0000000000199948

不是完全准确,但足够接近。

>>> a=np.ones(1000,dtype=np.float64)+1E-14
>>> np.dot(a,a)
1000.0000000000176  #off by 2.3948e-12
>>> (a*a).sum()
1000.0000000000059  #off by 1.40948e-11

np.dot(a, a)使用的浮点数舍入次数大约是(a*a).sum()的一半,因此前者更加精确。

Nvidia的一本书提供了以下四位精度的示例。 rn代表最接近4个数字的四舍五入:

x = 1.0008
x2 = 1.00160064                    #    true value
rn(x2 − 1) = 1.6006 × 10−4         #    fused multiply-add
rn(rn(x2) − 1) = 1.6000 × 10−4     #    multiply, then add

当然,在十进制的情况下,浮点数不会舍入到小数点后第16位,但你明白我的意思。

np.dot(a,a)放在上述标记中,加上一些伪代码:

out=0
for x in a:
    out=rn(x*x+out)   #Fused multiply add

(a*a).sum() 时:

arr=np.zeros(a.shape[0])   
for x in range(len(arr)):
    arr[x]=rn(a[x]*a[x])

out=0
for x in arr:
    out=rn(x+out)

从这个例子中可以看出,与np.dot(a,a)相比,使用(a*a).sum()将使数字舍入两次。这些微小的差异加起来可能会略微改变答案。可以在这里找到更多示例。


5
如果numpy正在使用用户机器上优化的BLAS,并且他的处理器具有FMA,则不应基于“blas can do that…”做出太多假设。 - user1220978
是的,即使对于英特尔常春藤桥架构,a+b*c也会编译成mulss后跟addss - Vladimir F Героям слава
你是否添加了足够的编译器选项?(例如使用gcc的-march=core-avx2 -mavx -mfma等选项) - user1220978
哼,那很有趣,我本以为 FMA 比它更普遍 - 然而市场似乎正在朝着这个方向发展。 - Daniel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接