在排序的二维numpy数组中沿轴查找第一个非零值

11

我正在尝试找到二维排序数组每行第一个非零值的最快方法。技术上,数组中唯一的值是零和一,并且它是“已排序”的。

例如,数组可能如下所示:

v =

0 0 0 1 1 1 1 
0 0 0 1 1 1 1 
0 0 0 0 1 1 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 0

我可以使用argmax函数

argmax(v, axis=1))

想要找到数组从0变为1的位置,但我认为这将在每行上执行穷举搜索。我的数组大小合理(~2000x2000)。在for循环中对每一行采用searchsorted方法进行搜索,argmax仍会比这种方法更有效吗?或者有更好的替代方案吗?

此外,数组总是这样的,即一行中第一个1的位置始终大于它上面一行中第一个1的位置(但不能保证最后几行中会有1)。我可以利用这一点,对于每一行使用一个“起始索引值”和一个for循环来查找前一行第一个1的位置,但我正确地认为numpy的argmax函数仍然比用Python编写的循环更快吗?

我只是想测试这些替代方案,但数组的边长可能会相当大(从250到10,000)。


我非常希望argmax函数能够更快。 如果性能很关键,您可以尝试用C编写扩展。 - SudoNhim
2个回答

7

使用np.where相对较快:

>>> a
array([[0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0]])
>>> np.where(a>0)
(array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5]), array([3, 4, 5, 6, 3, 4, 5, 6, 4, 5, 6, 6, 6, 6]))

该代码提供的是包含大于0值的元组坐标。

您还可以使用np.where测试每个子数组:

def first_true1(a):
    """ return a dict of row: index with value in row > 0 """
    di={}
    for i in range(len(a)):
        idx=np.where(a[i]>0)
        try:
            di[i]=idx[0][0]
        except IndexError:
            di[i]=None    

    return di       

输出:

{0: 3, 1: 3, 2: 4, 3: 6, 4: 6, 5: 6, 6: None}

ie, 第0行: 索引3>0; 第4行: 索引4>0; 第6行: 没有大于0的索引

正如您所怀疑的那样,argmax可能更快:

def first_true2():
    di={}
    for i in range(len(a)):
        idx=np.argmax(a[i])
        if idx>0:
            di[i]=idx
        else:
            di[i]=None    

    return di       
    # same dict is returned...

如果你能理解在所有零行中没有 None 的逻辑,这种方法会更快:

def first_true3():
    di={}
    for i, j in zip(*np.where(a>0)):
        if i in di:
            continue
        else:
            di[i]=j

    return di      

这里是使用argmax中的轴版本(根据您的评论建议):

def first_true4():
    di={}
    for i, ele in enumerate(np.argmax(a,axis=1)):
        if ele==0 and a[i][0]==0:
            di[i]=None
        else:
            di[i]=ele

    return di          

针对速度比较(以您的示例数组为例),我得出以下结果:

            rate/sec usec/pass first_true1 first_true2 first_true3 first_true4
first_true1   23,818    41.986          --      -34.5%      -63.1%      -70.0%
first_true2   36,377    27.490       52.7%          --      -43.6%      -54.1%
first_true3   64,528    15.497      170.9%       77.4%          --      -18.6%
first_true4   79,287    12.612      232.9%      118.0%       22.9%          --

如果我将其扩展到2000 X 2000的np数组,得到的结果如下:

            rate/sec  usec/pass first_true3 first_true1 first_true2 first_true4
first_true3        3 354380.107          --       -0.3%      -74.7%      -87.8%
first_true1        3 353327.084        0.3%          --      -74.6%      -87.7%
first_true2       11  89754.200      294.8%      293.7%          --      -51.7%
first_true4       23  43306.494      718.3%      715.9%      107.3%          --

实际上,argmax 的好处在于你可以指定一个轴,例如 argmax(a, axis=1),它会使用 C 语言编写的循环遍历行,因此你不必使用 Python 的 for 循环,这应该会更慢。 - user1554752
@user1554752:是的,但如果您使用argmax(a, axis=1),则在a中的行之间存在歧义,这些行为[1,x,x,x,][0,0,0,0],因为argmax(a, axis=1)将对任何一种情况返回0。您仍然需要循环遍历argmax返回的数组以测试此歧义,不是吗? - dawg
这就是我可以利用数据中的模式的地方,其中第一个1从未出现在其上方行的第一个1左侧的位置。一旦我从argmax获得数组(称其为indx),我就可以在其上运行argmin。如果它返回一个值p != 0,则从p向下的所有行都仅由零组成。 - user1554752

4

argmax() 使用C语言级别的循环,比Python循环要快得多,因此即使您在Python中编写了一个聪明的算法,也难以击败argmax()。您可以使用Cython来加速:

@cython.boundscheck(False)
@cython.wraparound(False) 
def find(int[:,:] a):
    cdef int h = a.shape[0]
    cdef int w = a.shape[1]
    cdef int i, j
    cdef int idx = 0
    cdef list r = []
    for i in range(h):
        for j in range(idx, w):
            if a[i, j] == 1:
                idx = j
                r.append(idx)
                break
        else:
            r.append(-1)
    return r

在我的电脑上,对于一个2000x2000的矩阵,运算时间为100微秒与3毫秒之间。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接