使用Cython运行距离计算比NumPy慢

3

我正在尝试学习Cython; 但是,我一定做错了什么。这个小测试代码的运行速度比我的矢量化numpy版本慢了约50倍。请问有人能告诉我为什么我的Cython比Python慢吗?谢谢。

该代码计算了在R^3中一个点loc与一个点数组points之间的距离。

import numpy as np
cimport numpy as np
import cython
cimport cython

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc):
    cdef unsigned int i
    cdef unsigned int L = points.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
    for i in xrange(0,L):
        d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2]  - loc[2])**2)
    return d

这是与之进行比较的numpy代码。
from numpy import *
N = 1e6
points = random.uniform(0,1,(N,3))
loc = random.uniform(0,1,(3))

def distMeasureNumpy(points,loc):
    d = points - loc
    d = sqrt(sum(d*d,axis=1))
    return d

使用numpy/python版本大约需要44毫秒,而cython版本需要大约2秒钟。我在Mac OSX上运行Python 2.7。我使用IPython的%timeit命令来计算这两个函数的时间。


我没有看到Cython版本有任何明显的问题(而且我很惊讶它的速度如此之慢)。然而,你不会用Cython打败一个正确向量化的numpy表达式。Cython最适合(并且通常非常好)于无法向量化的操作。另外,通过使用d = np.hypot(*d.T),你可能会稍微加快你的numpy版本的速度。 - Joe Kington
你运行过 cython -a your_code.pyx 并查看了 your_code.html 吗?这是一种检查Cython生成的C代码并找出有多少被转换为C,以及还有多少在Python级别工作的便捷方式。 - Warren Weckesser
2个回答

6

调用Python函数np.sqrt会影响程序性能,因为它计算的是标量浮点数的平方根,所以应该使用C语言数学库中的sqrt函数。下面是修改后的代码:

import numpy as np
cimport numpy as np
import cython
cimport cython

from libc.math cimport sqrt

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points,
                      np.ndarray[DTYPE_t, ndim=1] loc):
    cdef unsigned int i
    cdef unsigned int L = points.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
    for i in xrange(0,L):
        d[i] = sqrt((points[i,0] - loc[0])**2 +
                    (points[i,1] - loc[1])**2 +
                    (points[i,2] - loc[2])**2)
    return d

以下是性能改进的演示。您原始的代码在模块check_speed_original中,修改后的版本在check_speed中:
In [11]: import check_speed_original

In [12]: import check_speed

设置测试数据:

In [13]: N = 10**6

In [14]: points = random.uniform(0,1,(N,3))

In [15]: loc = random.uniform(0,1,(3,))

原始版本在我的电脑上需要1.26秒:

In [16]: %timeit check_speed_original.distMeasureCython(points, loc)
1 loops, best of 3: 1.26 s per loop

修改后的版本需要4.47毫秒:
In [17]: %timeit check_speed.distMeasureCython(points, loc)
100 loops, best of 3: 4.47 ms per loop

如果有人担心结果可能不同:

In [18]: d1 = check_speed.distMeasureCython(points, loc)

In [19]: d2 = check_speed_original.distMeasureCython(points, loc)

In [20]: np.all(d1 == d2)
Out[20]: True

它有效了!谢谢。我也得到了你上面提到的运行时间。关于HTML的提示,谢谢。这是我的第一个问题,现在已经解决了,我应该关闭它吗?感谢WW的帮助。 - plancherel

3

如前所述,代码中的numpy.sqrt调用是问题所在。然而,我认为不需要使用cdef extern,因为Cython已经提供了这些基本的C/C++库(请参阅文档)。因此,您只需像这样导入它:

    from libc.math cimport sqrt

只是为了摆脱额外的开销。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接