我正在尝试找到获取numpy的“where”语句在2D numpy数组上功能的最快方法;即检索满足条件的索引。与我使用过的其他语言(例如IDL,Matlab)相比,它明显要慢得多。
我已经将函数cython化,通过嵌套for循环遍历数组。速度提高了近一个数量级,但如果可能的话,我想进一步提高性能。
我已经将函数cython化,通过嵌套for循环遍历数组。速度提高了近一个数量级,但如果可能的话,我想进一步提高性能。
TEST.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0
t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
我的cython_where.pyx程序:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
assert data.dtype == DTYPE1
cdef int xmax = data.shape[0]
cdef int ymax = data.shape[1]
cdef unsigned int x, y
cdef int count = 0
cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
for x in xrange(xmax):
for y in xrange(ymax):
if(data[x,y] == val):
xind[count] = x
yind[count] = y
count += 1
return tuple([xind[0:count],yind[0:count]]),count
TEST.py的输出为:
cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
我也尝试过numpy的argwhere
,它的速度与where
差不多。我对numpy和cython都很新,如果你有任何其他增加性能的想法,我非常乐意听取!
np.nonzero
代码(这是where
使用的)使用np.count_nonzero
来分配结果数组。因此,它最终会遍历两次数组,但计数迭代非常快。 - hpauljhttps://github.com/numpy/numpy/blob/c0e48cfbbdef9cca954b0c4edd0052e1ec8a30aa/numpy/core/src/multiarray/item_selection.c
中的PyArray_Nonzero
是np.nonzero
的源代码。 - hpaulj