假设我有一个形如numpy数组:
arr=numpy.array([[1,1,0],[1,1,0],[0,0,1],[0,0,0]])
我想找出每一列第一个非零值的索引。
因此,在这种情况下,我希望返回以下内容:
[0,0,2]
我该怎么做呢?
在非零的掩码上沿着该轴(这里是列的第零轴)使用np.argmax
,以获取首次匹配(True 值)的索引 -
(arr!=0).argmax(axis=0)
扩展以涵盖通用轴说明符和对于元素沿该轴未找到任何非零值的情况,我们将有以下实现方式 -
def first_nonzero(arr, axis, invalid_val=-1):
mask = arr!=0
return np.where(mask.any(axis=axis), mask.argmax(axis=axis), invalid_val)
False
值上使用argmax()
将返回0
,因此如果所需的invalid_val
是0
,则我们可以直接使用mask.argmax(axis=axis)
获取最终输出。In [296]: arr # Different from given sample for variety
Out[296]:
array([[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 0]])
In [297]: first_nonzero(arr, axis=0, invalid_val=-1)
Out[297]: array([ 0, 1, -1])
In [298]: first_nonzero(arr, axis=1, invalid_val=-1)
Out[298]: array([ 0, 0, 1, -1])
扩展以涵盖所有比较操作
要找到第一个zeros
,只需将arr == 0
用作函数中的mask
。对于第一个等于某个值val
的ones
,使用arr == val
,对于此处可能出现的所有比较
情况都是如此。
要找到与某个比较条件相匹配的最后一个元素,我们需要沿着该轴翻转并使用与使用argmax
相同的思路,然后通过从轴长度进行偏移来补偿翻转,如下所示 -
def last_nonzero(arr, axis, invalid_val=-1):
mask = arr!=0
val = arr.shape[axis] - np.flip(mask, axis=axis).argmax(axis=axis) - 1
return np.where(mask.any(axis=axis), val, invalid_val)
样例运行 -
In [320]: arr
Out[320]:
array([[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 0]])
In [321]: last_nonzero(arr, axis=0, invalid_val=-1)
Out[321]: array([ 1, 2, -1])
In [322]: last_nonzero(arr, axis=1, invalid_val=-1)
Out[322]: array([ 0, 1, 1, -1])
再次强调,使用相应的比较器获取mask
,然后在列出的函数中使用,即可涵盖所有可能的comparisons
情况。
arr = np.array([[1,1,0],[1,1,0],[0,0,1],[0,0,0]])
def first_nonzero_index(array):
"""Return the index of the first non-zero element of array. If all elements are zero, return -1."""
fnzi = -1 # first non-zero index
indices = np.flatnonzero(array)
if (len(indices) > 0):
fnzi = indices[0]
return fnzi
np.apply_along_axis(first_nonzero_index, axis=1, arr=arr)
# result
array([ 0, 0, 2, -1])
解释
np.flatnonzero(array) 方法(如 Henrik Koberg 在评论中所建议的)返回“数组扁平化后非零索引”。该函数计算这些索引并返回第一个(如果所有元素都为零则返回-1)。
apply_along_axis 方法沿指定轴应用函数到1-D切片。在这里,由于 axis 为 1,因此该函数应用于行。
如果我们可以假设输入数组的所有行至少都包含一个非零元素,则可以将解决方案写成一行代码:
np.apply_along_axis(lambda a: np.flatnonzero(a)[0], axis=1, arr=arr)
可能的变化
原始答案
这里是另一种使用numpy.argwhere
的方法,它返回数组的非零元素的索引:
array = np.array([0,0,0,1,2,3,0,0])
nonzero_indx = np.argwhere(array).squeeze()
start, end = (nonzero_indx[0], nonzero_indx[-1])
print(array[start], array[end])
提供:
1 3
argwhere
这个名称相当不直观。 - aksg87flatnonzero(array)
而不是 argwhere(array).squeeze()
。 - Henrik Kobergapply_along_axis
是一个非向量化的便捷函数。在这种情况下,我宁愿选择被接受的答案。 - Henrik Koberg
argmax
仍然会不必要地浏览数组的其余部分(可能很大),假设可能在那里找到更大的值(不知道mask
没有比1
更大的值)。是否可以轻松避免这种情况,而不需要“手动”实现板块处理? - root