在numpy数组中查找条件索引的最快方式

5
我正在尝试找到获取numpy的“where”语句在2D numpy数组上功能的最快方法;即检索满足条件的索引。与我使用过的其他语言(例如IDL,Matlab)相比,它明显要慢得多。
我已经将函数cython化,通过嵌套for循环遍历数组。速度提高了近一个数量级,但如果可能的话,我想进一步提高性能。

TEST.py:

from cython_where import *
import time
import numpy as np

data = np.zeros((2600,5200))
data[100:200,100:200] = 10

t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0

t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1

我的cython_where.pyx程序:

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)

def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
  assert data.dtype == DTYPE1

  cdef int xmax = data.shape[0]
  cdef int ymax = data.shape[1]
  cdef unsigned int x, y
  cdef int count = 0
  cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
  cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
  if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
    for x in xrange(xmax):
    for y in xrange(ymax):
      if(data[x,y] == val):
        xind[count] = x
        yind[count] = y
        count += 1

 return tuple([xind[0:count],yind[0:count]]),count

TEST.py的输出为: cython_test]$ python TEST.py 0.0139019489288 0.0982608795166

我也尝试过numpy的argwhere,它的速度与where差不多。我对numpy和cython都很新,如果你有任何其他增加性能的想法,我非常乐意听取!


正如标题所述,我想要找到一个二维数组中满足条件(例如 arr == 2)的索引的最快方法。我已经通过Cython优化改进了numpy where语句,就像我上面解释的那样。 - weather guy
你提到了numpy.where,但是numpy.where文档中给出的示例是:ix = np.in1d(x.ravel(), goodvalues).reshape(x.shape),用于检索索引。你试过了吗?这个方法更好吗?还是使用(a == 10).nonzero()? - P. Brunet
@P.Brunet,我尝试过了,它比常规的np.where(x==val)慢一点。除非你要测试多个值,否则我不确定为什么要使用这种方法。 - weather guy
编译后的 np.nonzero 代码(这是 where 使用的)使用 np.count_nonzero 来分配结果数组。因此,它最终会遍历两次数组,但计数迭代非常快。 - hpaulj
https://github.com/numpy/numpy/blob/c0e48cfbbdef9cca954b0c4edd0052e1ec8a30aa/numpy/core/src/multiarray/item_selection.c 中的 PyArray_Nonzeronp.nonzero 的源代码。 - hpaulj
1个回答

3

贡献:

  • Numpy can be speed up on flattened array for a 4x gain:

    %timeit np.where(data==10)
    1 loops, best of 3: 105 ms per loop
    
    %timeit np.unravel_index(np.where(data.ravel()==10),data.shape)
    10 loops, best of 3: 26.0 ms per loop
    
我认为您可以优化您的Cython代码,避免为每个单元格计算"k=i*ncol+j"。
  • Numba give a simple alternative :

    from numba import jit
    dtype=data.dtype
    @jit(nopython=True)
    def numbaeq(flatdata,x,nrow,ncol):
      size=ncol*nrow
      ix=np.empty(size,dtype=dtype)
      jx=np.empty(size,dtype=dtype)
      count=0
      k=0
      while k<size:
        if flatdata[k]==x :
          ix[count]=k//ncol
          jx[count]=k%ncol
          count+=1
        k+=1          
      return ix[:count],jx[:count]
    
    def whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)
    

这将提供:

    %timeit whereequal(data,10)
    10 loops, best of 3: 20.2 ms per loop

对于这种问题,numba在cython性能下的优化不够好。

  • k//ncolk%ncol可以通过优化的divmod操作同时计算。
  • 最终的步骤是汇编语言和并行化,但这是其他领域的事情。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接