我正在尝试学习Cython; 但是,我一定做错了什么。这个小测试代码的运行速度比我的矢量化numpy版本慢了约50倍。请问有人能告诉我为什么我的Cython比Python慢吗?谢谢。
该代码计算了在R^3中一个点loc与一个点数组points之间的距离。
import numpy as np
cimport numpy as np
import cython
cimport cython
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc):
cdef unsigned int i
cdef unsigned int L = points.shape[0]
cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
for i in xrange(0,L):
d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2] - loc[2])**2)
return d
这是与之进行比较的numpy代码。
from numpy import *
N = 1e6
points = random.uniform(0,1,(N,3))
loc = random.uniform(0,1,(3))
def distMeasureNumpy(points,loc):
d = points - loc
d = sqrt(sum(d*d,axis=1))
return d
使用numpy/python版本大约需要44毫秒,而cython版本需要大约2秒钟。我在Mac OSX上运行Python 2.7。我使用IPython的%timeit命令来计算这两个函数的时间。
d = np.hypot(*d.T)
,你可能会稍微加快你的numpy版本的速度。 - Joe Kingtoncython -a your_code.pyx
并查看了your_code.html
吗?这是一种检查Cython生成的C代码并找出有多少被转换为C,以及还有多少在Python级别工作的便捷方式。 - Warren Weckesser