Numpy浮点数提升的不一致性

4
我了解到在 x = f() * g()中,首先执行f(),然后执行g(),然后将结果相乘,并最终赋值给x。但下面的内容似乎与此相矛盾:
import numpy as np

print(np.sqrt(2).dtype)
print((np.array([1.], dtype='float32') * np.array([.5], dtype='float64')).dtype)
print((np.array([1.], dtype='float32') * np.sqrt(2)).dtype)

>>> float64
>>> float64
>>> float32

根据我的所有先前经验,Numpy会升级到更高的dtype,但这里不是。如果我们分配单独的数组并在之后进行乘法运算,则行为相同。我想Numpy使用了除了dtype以外的一些隐藏属性,而不是覆盖Python执行。

它是如何工作的?

1个回答

3
在第二行中:
print((np.array([1.], dtype='float32') * np.array([.5], dtype='float64')).dtype)

你正在对两个数组进行乘法运算,因此它会提升 dtype(正如你所期望的那样)。

但是,在第三行中:

print((np.array([1.], dtype='float32') * np.sqrt(2)).dtype)

你正在将一个标量乘以一个数组,它会保留数组的数据类型。请注意,如果你用一个数组替换标量,像这样:
print((np.array([1.], dtype='float32') * np.sqrt([2])).dtype)

你将会再次获得你所期望的 float64 类型。


我明白了,所以数组的 dtype 始终优先于标量(0-dim)的 dtype - OverLordGoldDragon
1
是的,它确实会: "当标量和数组同时使用时,数组的类型优先"。 - OverLordGoldDragon
请注意, numba @njit 不遵循此约定。 - OverLordGoldDragon
1
@OverLordGoldDragon 谢谢您的留言。那是一个有趣的细节。 - Ehsan
事实上,Python浮点数也会强制转换为float64,即使每个其他标量和ndim数组都是float32,例如2. * x... 我更希望它遵循Numpy的规则。 - OverLordGoldDragon

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接