NumPy性能:uint8与float,乘法与除法的比较?

19

我刚刚注意到,我的一个脚本的执行时间只需将乘法更改为除法,就能将时间减少近一半。

为了调查这个问题,我编写了一个小例子:

import numpy as np                                                                                                                                                                                
import timeit

# uint8 array
arr1 = np.random.randint(0, high=256, size=(100, 100), dtype=np.uint8)

# float32 array
arr2 = np.random.rand(100, 100).astype(np.float32)
arr2 *= 255.0


def arrmult(a):
    """ 
    mult, read-write iterator
    """
    b = a.copy()
    for item in np.nditer(b, op_flags=["readwrite"]):
        item[...] = (item + 5) * 0.5

def arrmult2(a):
    """ 
    mult, index iterator
    """
    b = a.copy()
    for i, j in np.ndindex(b.shape):
        b[i, j] = (b[i, j] + 5) * 0.5

def arrmult3(a):
    """
    mult, vectorized
    """
    b = a.copy()
    b = (b + 5) * 0.5

def arrdiv(a):
    """ 
    div, read-write iterator 
    """
    b = a.copy()
    for item in np.nditer(b, op_flags=["readwrite"]):
        item[...] = (item + 5) / 2

def arrdiv2(a):
    """ 
    div, index iterator
    """
    b = a.copy()
    for i, j in np.ndindex(b.shape):
           b[i, j] = (b[i, j] + 5)  / 2                                                                                 

def arrdiv3(a):                                                                                                     
    """                                                                                                             
    div, vectorized                                                                                                 
    """                                                                                                             
    b = a.copy()                                                                                                    
    b = (b + 5) / 2                                                                                               




def print_time(name, t):                                                                                            
    print("{: <10}: {: >6.4f}s".format(name, t))                                                                    

timeit_iterations = 100                                                                                             

print("uint8 arrays")                                                                                               
print_time("arrmult", timeit.timeit("arrmult(arr1)", "from __main__ import arrmult, arr1", number=timeit_iterations))
print_time("arrmult2", timeit.timeit("arrmult2(arr1)", "from __main__ import arrmult2, arr1", number=timeit_iterations))
print_time("arrmult3", timeit.timeit("arrmult3(arr1)", "from __main__ import arrmult3, arr1", number=timeit_iterations))
print_time("arrdiv", timeit.timeit("arrdiv(arr1)", "from __main__ import arrdiv, arr1", number=timeit_iterations))  
print_time("arrdiv2", timeit.timeit("arrdiv2(arr1)", "from __main__ import arrdiv2, arr1", number=timeit_iterations))
print_time("arrdiv3", timeit.timeit("arrdiv3(arr1)", "from __main__ import arrdiv3, arr1", number=timeit_iterations))

print("\nfloat32 arrays")                                                                                           
print_time("arrmult", timeit.timeit("arrmult(arr2)", "from __main__ import arrmult, arr2", number=timeit_iterations))
print_time("arrmult2", timeit.timeit("arrmult2(arr2)", "from __main__ import arrmult2, arr2", number=timeit_iterations))
print_time("arrmult3", timeit.timeit("arrmult3(arr2)", "from __main__ import arrmult3, arr2", number=timeit_iterations))
print_time("arrdiv", timeit.timeit("arrdiv(arr2)", "from __main__ import arrdiv, arr2", number=timeit_iterations))  
print_time("arrdiv2", timeit.timeit("arrdiv2(arr2)", "from __main__ import arrdiv2, arr2", number=timeit_iterations))
print_time("arrdiv3", timeit.timeit("arrdiv3(arr2)", "from __main__ import arrdiv3, arr2", number=timeit_iterations))

这将打印以下时间:

uint8 arrays
arrmult   : 2.2004s
arrmult2  : 3.0589s
arrmult3  : 0.0014s
arrdiv    : 1.1540s
arrdiv2   : 2.0780s
arrdiv3   : 0.0027s

float32 arrays
arrmult   : 1.2708s
arrmult2  : 2.4120s
arrmult3  : 0.0009s
arrdiv    : 1.5771s
arrdiv2   : 2.3843s
arrdiv3   : 0.0009s

我一直认为乘法比除法计算更便宜。然而,对于uint8来说,除法似乎几乎是两倍有效的。这是否与* 0.5需要在浮点数中计算乘法,然后将结果强制转换回整数有关?
至少对于浮点数,乘法似乎比除法快。这通常成立吗?
为什么uint8中的乘法比float32中的乘法更昂贵?我认为8位无符号整数应该比32位浮点数计算速度快得多?
能否解释一下这个问题?
编辑:为了获得更多数据,我已经包括了向量化函数(如建议),并添加了索引迭代器。向量化函数要快得多,因此不能真正进行比较。但是,如果将timeit_iterations设置得更高一些,针对向量化函数,结果表明对于uint8float32,乘法都更快。我想这会更加混淆?!
也许乘法实际上总是比除法更快,但是for循环中的主要性能泄漏不是算术操作,而是循环本身。尽管这并不能解释为什么循环对不同的操作表现不同。
编辑2:正如@jotasi所述,我们正在寻找关于divisionmultiplication以及int(或uint8)与float(或float32)的全面解释。此外,解释向量化方法和迭代器的不同趋势也很有趣,因为在向量化情况下,除法似乎较慢,而在迭代器情况下则更快。

如果我没有忽略什么愚蠢的东西,它会变得更奇怪。用 b = (b+5) * 0.5b = (b+5) / 2 替换 for 循环会导致除法变得更慢。 - jotasi
你的代码中有一个小错别字。在arrdiv3函数中,你应该将b = (b + 5) / 0.5改为b = (b + 5) / 2 - jotasi
你说得对!谢谢,我已经修复了。 - daniel451
我还调整了时间。奇怪的是,float32div3 的时间没有改变,而 uint8 则增加到了 2.7E-3。我猜向量化版本的时间已经太低,无法提供精确的测量结果了?! - daniel451
可能是这样。但我也用更大的数组检查了它们,并得到了相同的趋势(除法较慢)。也许你可以强调一下,提供关于向量化与迭代器、整型与浮点型的全面解释会很好。 - jotasi
实际上,那不是我想表达的意思,但是迭代器的除法速度更快,向量化方法的速度较慢,这是趋势上的差异。 - jotasi
4个回答

13
问题在于你的假设,即你测量除法或乘法所需的时间,而这并不正确。你正在测量除法或乘法所需的开销。
要解释每个影响的确切代码都需要查看代码,因为其可能会因版本而异。这个答案只能给出一个想法,告诉你需要考虑什么。
问题在于在Python中,一个简单的整数(int)实际上并不简单:它是一个真正的对象,必须在垃圾回收器中注册,并且随着其值的增长而增加大小-对此你必须付出代价:例如,对于8位整数,需要24字节的内存! Python浮点数也是如此。
另一方面,NumPy数组由没有开销的简单c语言风格的整数/浮点数组成,你可以节省大量内存,但是在访问NumPy数组元素时需要付出代价。 a [i] 的意思是:必须构造Python整数,将其注册到垃圾回收器中,然后才能使用它——有很多开销。
考虑以下代码:
li1=[x%256 for x in xrange(10**4)]
arr1=np.array(li1, np.uint8)

def arrmult(a):    
    for i in xrange(len(a)):
        a[i]*=5;

arrmult(li1)arrmult(arr1) 快 25%,因为列表中的整数已经是 Python-ints,不需要创建它们! 计算时间的绝大部分都用于对象的创建 - 其他所有内容几乎可以被忽略。


让我们来看看您的代码,首先是乘法:

def arrmult2(a):
    ...
    b[i, j] = (b[i, j] + 5) * 0.5

对于uint8,必须执行以下操作(为简单起见,不考虑+5):

  1. 创建一个python-int
  2. 将其转换为float(创建python-float),以便能够进行浮点数乘法运算
  3. 然后将其转换回python-int或/和uint8

对于float32,需要处理的内容更少(乘法的成本不高): 1. 创建一个python-float 2. 将其转换为float32。

因此,使用float版本应该更快,实际也确实如此。


现在让我们看一下除法:

def arrdiv2(a):
    ...
    b[i, j] = (b[i, j] + 5)  / 2 
这里的陷阱是所有操作都是整数操作。因此,与乘法相比,无需将其转换为Python浮点数,因此在乘法的情况下我们有较少的开销。对于 unint8,除法比乘法“更快”。
但是,对于 float32,除法和乘法的速度/慢度相同,因为在这种情况下几乎没有改变 - 我们仍然需要创建一个 Python 浮点数。
现在来看矢量化版本:它们使用c风格的“原始”float32s / uint8s,无需转换(及其成本!)到对应的Python对象。为了获得有意义的结果,您应该增加迭代次数(现在运行时间太短,无法确定)。
  1. 对于float32,除法和乘法的运行时间可能相同,因为我希望numpy通过将除以2替换为乘以0.5(但要确定必须查看代码)。
  2. uint8的乘法应该更慢,因为每个uint8整数在乘以0.5之前必须转换为浮点数,然后在转换回uint8之后再进行乘法。
  3. 对于uint8情况,numpy无法通过将除以2替换为乘以0.5,因为它是一种整数除法。整数除法在很多架构上比浮点乘法慢,这是最慢的矢量化操作。
PS:我不会过多地考虑乘法与除法的成本 - 有太多其他因素可能对性能产生更大的影响。例如,创建不必要的临时对象或者如果numpy数组很大并且无法适应缓存,那么内存访问将是瓶颈 - 你根本看不出乘法和除法之间的区别。

5

这篇回答仅涉及矢量操作,因为其他操作缓慢的原因已由ead回答。

很多“优化”都基于旧硬件。这些优化适用于旧硬件的假设在新硬件上不再成立。

流水线和除法

除法很慢。除法操作由数个单元组成,每个单元必须依次执行一次计算。这就是使除法变慢的原因。

但是,在一个浮点处理单元(FPU)[大多数现代CPU上都有] 中,有专门为除法指令设计的单元排列成“流水线”。一旦一个单元完成,该单元在其余操作中就不再需要。如果您有多个除法操作,则可以让这些无事可做的单元开始下一个除法操作。因此,尽管每个操作很慢,但FPU实际上可以实现高吞吐量的除法操作。流水线与向量化不同,但结果大多相同——当您有大量相同的操作要执行时,吞吐量更高。

将流水线理解为交通流量。将三条车道以30英里/小时的速度运动与一条车道以90英里/小时的速度运动进行比较。慢速交通肯定是单独更慢,但三车道仍具有相同的吞吐量。


-1

这是因为你将一个整数乘以一个浮点数并将结果存储为整数。 尝试使用不同的整数或浮点数值进行arr_mult和arr_div测试。特别是,比较乘以“2”和乘以“2.0”的结果。


在你的最后一行,你是不是想说“除以'2'和除以'2.0'”? - mtrw
这仍然没有回答为什么向量化版本不显示相同的行为... - jotasi
@mtrw 我确实是指“乘法”,尽管除法也可以看到这种效果。而且我应该写成“2.0”,而不是“2.”。 - Daniel Sk
@jotasi 我不确定,但考虑到整数向量化除法实际上似乎比整数向量化除以浮点数要慢,我猜测硬件会在内部进行类型转换和浮点运算,然后巧妙地转换回来,确保得到与常规整数算术相同的结果。 - Daniel Sk

-2

这通常是在“热身”之前需要更长时间的第一个操作(例如,内存分配、缓存)。

使用除法和乘法的相反顺序可以看到相同的效果:

>>> print_time("arrdiv", timeit.timeit("arrdiv(arr2)", "from __main__ import arrdiv, arr2", number=timeit_iterations))
>>> print_time("arrmult", timeit.timeit("arrmult(arr2)", "from __main__ import arrmult, arr2", number=timeit_iterations))

arrdiv:  3.2630s
arrmult:  2.5873s

实际上,这只是解释了 np.float32 数组的微小差异,其中除法已经稍微慢了一些。如果您尝试使用 np.uint8 数组进行相同的操作,则除法仍然快两倍。 - jotasi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接