为什么循环在这里比索引更好?

10

几年前,有人在Active State Recipes发表了一篇文章,为了比较三个python/NumPy函数的功能。这些函数都接受相同的参数并返回相同的结果:一个距离矩阵

其中两个函数来自已发表的资源;它们都是numpy代码的惯用写法,或者至少在我看来是这样的。创建距离矩阵所需的重复计算是由numpy优雅的索引语法驱动的。下面是其中之一的代码:

from numpy.matlib import repmat, repeat

def calcDistanceMatrixFastEuclidean(points):
  numPoints = len(points)
  distMat = sqrt(sum((repmat(points, numPoints, 1) - 
             repeat(points, numPoints, axis=0))**2, axis=1))
  return distMat.reshape((numPoints,numPoints))

第三个创建了距离矩阵的函数使用了单循环(显然是大量循环,因为一个只有1,000个2D点的距离矩阵就有一百万个条目)。乍一看,这个函数看起来像我学习NumPy时编写的代码,我会先编写Python代码,然后逐行转换。

在Active State发布几个月后,对比这三个函数的性能测试结果被发布并在NumPy邮件列表的线程中讨论。

事实上,带有循环的函数明显优于其他两个函数:

from numpy import mat, zeros, newaxis

def calcDistanceMatrixFastEuclidean2(nDimPoints):
  nDimPoints = array(nDimPoints)
  n,m = nDimPoints.shape
  delta = zeros((n,n),'d')
  for d in xrange(m):
    data = nDimPoints[:,d]
    delta += (data - data[:,newaxis])**2
  return sqrt(delta)

一个帖子的参与者(Keir Mierle)提供了一个可能为真实的原因:

我怀疑这样做更快的原因在于它具有更好的局部性,在转移到下一个较小的工作集之前,完全完成对一个相对较小的工作集的计算。一行代码需要重复地将潜在的大型MxN数组拉入处理器。

根据这位发帖者自己的说法,他的评论只是一种怀疑,并且似乎没有进一步讨论。

对于如何解释这些结果,是否有其他想法?

特别是,是否可以从此示例中提取出有用的规则,关于何时循环和何时索引,作为编写NumPy代码的指导?

对于那些不熟悉NumPy或还没有查看过代码的人来说,这种比较并不是基于一种边缘情况--如果是这样,它肯定不会吸引我的注意。相反,这个比较涉及到矩阵计算中的常见任务(即,在给定两个前置条件的情况下创建结果数组),而且每个函数又由最常见的NumPy内置函数组成。

2个回答

11

简短版:第二段代码仅在点的维度上循环(对于3D点,通过for循环只执行3次),因此循环并不多。第二个代码的真正加速在于更好地利用了Numpy的能力,避免在查找点之间的差异时创建一些额外的矩阵。这减少了内存使用和计算量。

详细版: 我认为calcDistanceMatrixFastEuclidean2函数可能会误导您关于其循环的问题。它仅在点的维度上循环。对于1D点,循环只执行一次,对于2D,执行两次,对于3D,执行三次。这实际上并没有太多的循环。

让我们稍微分析一下代码,看看其中一个比另一个快的原因。 calcDistanceMatrixFastEuclidean 我将其称为fast1,而calcDistanceMatrixFastEuclidean2则为fast2

fast1基于Matlab的方式,如repmap函数所示。在这种情况下,repmap函数创建了一个重复的数据数组。然而,如果您查看该函数的代码,它非常低效。它使用了许多Numpy函数(3个reshape和2个repeat)来执行此操作。 repeat函数还用于创建一个包含每个数据项重复多次的原始数据的数组。如果我们的输入数据为[1,2,3],则我们要从[1,2,3,1,2,3,1,2,3]中减去[1,1,1,2,2,2,3,3,3]。Numpy不得不在运行Numpy的C代码之间创建许多额外的矩阵,这是可以避免的。

fast2使用了更多的Numpy的重量级功能,在Numpy调用之间创建了更少的矩阵。fast2沿着点的每个维度循环,进行减法,并保持每个维度之间平方差的运行总和。仅在最后才进行平方根。到目前为止,这可能听起来不像fast1那么高效,但fast2通过使用Numpy的索引避免了做repmat的事情。以简单的1D情况为例。 fast2创建了一个数据的1D数组,并将其从一个2D(N x 1)数据数组中减去。这将在不使用repmatrepeat的情况下创建每个点与所有其他点之间的差异矩阵,从而避免了创建许多额外的数组。这是我的观点中真正的速度差异所在。 fast1在查找点之间的差异时创建了许多额外的中间矩阵(它们的计算成本也很高),而fast2更好地利用了Numpy的能力来避免这些问题。

顺便说一句,这里有一个稍微更快的版本的fast2

def calcDistanceMatrixFastEuclidean3(nDimPoints):
  nDimPoints = array(nDimPoints)
  n,m = nDimPoints.shape
  data = nDimPoints[:,0]
  delta = (data - data[:,newaxis])**2
  for d in xrange(1,m):
    data = nDimPoints[:,d]
    delta += (data - data[:,newaxis])**2
  return sqrt(delta)

区别在于我们不再将delta创建为一个零矩阵。


1

dis 仅供娱乐:

dis.dis(calcDistanceMatrixFastEuclidean)

  2           0 LOAD_GLOBAL              0 (len)
              3 LOAD_FAST                0 (points)
              6 CALL_FUNCTION            1
              9 STORE_FAST               1 (numPoints)

  3          12 LOAD_GLOBAL              1 (sqrt)
             15 LOAD_GLOBAL              2 (sum)
             18 LOAD_GLOBAL              3 (repmat)
             21 LOAD_FAST                0 (points)
             24 LOAD_FAST                1 (numPoints)
             27 LOAD_CONST               1 (1)
             30 CALL_FUNCTION            3

  4          33 LOAD_GLOBAL              4 (repeat)
             36 LOAD_FAST                0 (points)
             39 LOAD_FAST                1 (numPoints)
             42 LOAD_CONST               2 ('axis')
             45 LOAD_CONST               3 (0)
             48 CALL_FUNCTION          258
             51 BINARY_SUBTRACT
             52 LOAD_CONST               4 (2)
             55 BINARY_POWER
             56 LOAD_CONST               2 ('axis')
             59 LOAD_CONST               1 (1)
             62 CALL_FUNCTION          257
             65 CALL_FUNCTION            1
             68 STORE_FAST               2 (distMat)

  5          71 LOAD_FAST                2 (distMat)
             74 LOAD_ATTR                5 (reshape)
             77 LOAD_FAST                1 (numPoints)
             80 LOAD_FAST                1 (numPoints)
             83 BUILD_TUPLE              2
             86 CALL_FUNCTION            1
             89 RETURN_VALUE

dis.dis(calcDistanceMatrixFastEuclidean2)

  2           0 LOAD_GLOBAL              0 (array)
              3 LOAD_FAST                0 (nDimPoints)
              6 CALL_FUNCTION            1
              9 STORE_FAST               0 (nDimPoints)

  3          12 LOAD_FAST                0 (nDimPoints)
             15 LOAD_ATTR                1 (shape)
             18 UNPACK_SEQUENCE          2
             21 STORE_FAST               1 (n)
             24 STORE_FAST               2 (m)

  4          27 LOAD_GLOBAL              2 (zeros)
             30 LOAD_FAST                1 (n)
             33 LOAD_FAST                1 (n)
             36 BUILD_TUPLE              2
             39 LOAD_CONST               1 ('d')
             42 CALL_FUNCTION            2
             45 STORE_FAST               3 (delta)

  5          48 SETUP_LOOP              76 (to 127)
             51 LOAD_GLOBAL              3 (xrange)
             54 LOAD_FAST                2 (m)
             57 CALL_FUNCTION            1
             60 GET_ITER
        >>   61 FOR_ITER                62 (to 126)
             64 STORE_FAST               4 (d)

  6          67 LOAD_FAST                0 (nDimPoints)
             70 LOAD_CONST               0 (None)
             73 LOAD_CONST               0 (None)
             76 BUILD_SLICE              2
             79 LOAD_FAST                4 (d)
             82 BUILD_TUPLE              2
             85 BINARY_SUBSCR
             86 STORE_FAST               5 (data)

  7          89 LOAD_FAST                3 (delta)
             92 LOAD_FAST                5 (data)
             95 LOAD_FAST                5 (data)
             98 LOAD_CONST               0 (None)
            101 LOAD_CONST               0 (None)
            104 BUILD_SLICE              2
            107 LOAD_GLOBAL              4 (newaxis)
            110 BUILD_TUPLE              2
            113 BINARY_SUBSCR
            114 BINARY_SUBTRACT
            115 LOAD_CONST               2 (2)
            118 BINARY_POWER
            119 INPLACE_ADD
            120 STORE_FAST               3 (delta)
            123 JUMP_ABSOLUTE           61
        >>  126 POP_BLOCK

  8     >>  127 LOAD_GLOBAL              5 (sqrt)
            130 LOAD_FAST                3 (delta)
            133 CALL_FUNCTION            1
            136 RETURN_VALUE

我不是dis的专家,但看起来你需要更多地关注第一个函数调用的功能,才能知道它们为什么需要花费一些时间。Python也有一个性能分析工具,cProfile


1
如果你正在使用cProfile,我建议使用RunSnakeRun来查看结果。 - detly
我注意到Python优化的诀窍似乎通常是让Python解释器尽可能少地执行Python指令。 - Omnifarious

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接