如何从numpy数组的每一行中仅获取第一个True值?

8
我有一个4x3的布尔类型numpy数组,我想返回一个相同大小的数组,其中除了原始数组每行的第一个True值所在位置外,其余都为False。因此,如果我有一个起始数组:
all_bools = np.array([[False, True, True],[True, True, True],[False, False, True],[False,False,False]])
all_bools
array([[False,  True,  True], # First true value = index 1
       [ True,  True,  True], # First true value = index 0
       [False, False,  True], # First true value = index 2
       [False, False, False]]) # No True Values

then I'd like to return

[[False, True, False],
 [True, False, False],
 [False, False, True],
 [False, False, False]]

因此,第一行、第二行和第三行的索引1、0和2已经被设置为True,除此之外没有其他值。基本上,原始方法中任何True值(除了每行的第一个True值)都被设置为False。

我一直在使用np.where和np.argmax进行尝试,但还没有找到一个好的解决方案 - 非常感谢任何帮助。需要运行多次,所以我想避免迭代。

2个回答

15
你可以使用 cumsum 函数,通过将结果与 1 进行比较来找到第一个布尔值。
all_bools.cumsum(axis=1).cumsum(axis=1) == 1 
array([[False,  True, False],
       [ True, False, False],
       [False, False,  True],
       [False, False, False]])

这也解释了@a_guest提出的问题。第二个cumsum调用是必需的,以避免匹配第一个和第二个True值之间的所有False值。


如果性能很重要,请使用argmax并设置值:

y = np.zeros_like(all_bools, dtype=bool)
idx = np.arange(len(x)), x.argmax(axis=1)
y[idx] = x[idx]

y
array([[False,  True, False],
       [ True, False, False],
       [False, False,  True],
       [False, False, False]])

Perfplot性能计时
我将利用这个机会展示perfplot及其一些时间数据,因为看到不同大小的输入会如何影响我们的解决方案是很好的。

import numpy as np
import perfplot

def cs1(x):
    return  x.cumsum(axis=1).cumsum(axis=1) == 1 

def cs2(x):
    y = np.zeros_like(x, dtype=bool)
    idx = np.arange(len(x)), x.argmax(axis=1)
    y[idx] = x[idx]
    return y

def a_guest(x):
    b = np.zeros_like(x, dtype=bool)
    i = np.argmax(x, axis=1)
    b[np.arange(i.size), i] = np.logical_or.reduce(x, axis=1)
    return b

perfplot.show(
    setup=lambda n: np.random.randint(0, 2, size=(n, n)).astype(bool),
    kernels=[cs1, cs2, a_guest],
    labels=['cs1', 'cs2', 'a_guest'],
    n_range=[2**k for k in range(1, 8)],
    xlabel='N'
)

enter image description here

这种趋势在更大的N中也持续存在。 cumsum 很昂贵,而我的第二个解决方案与 @a_guest 的之间存在恒定时间差异。


4
您可以使用以下方法,使用np.argmax和与np.logical_or.reduce的乘积处理全部为False的行:
b = np.zeros_like(a, dtype=bool)
i = np.argmax(a, axis=1)
b[np.arange(i.size), i] = np.logical_or.reduce(a, axis=1)

时间结果

不同版本按性能递增顺序排列,即最快的方法最后出现:

In [1]: import numpy as np

In [2]: def f(a):
   ...:     return a.cumsum(axis=1).cumsum(axis=1) == 1
   ...: 
   ...: 

In [3]: def g(a):
   ...:     b = np.zeros_like(a, dtype=bool)
   ...:     i = np.argmax(a, axis=1)
   ...:     b[np.arange(i.size), i] = np.logical_or.reduce(a, axis=1)
   ...:     return b
   ...: 
   ...: 

In [4]: x = np.random.randint(0, 2, size=(1000, 1000)).astype(bool)

In [5]: %timeit f(x)
10.4 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [6]: %timeit g(x)
120 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]: def h(a):
   ...:     y = np.zeros_like(x)
   ...:     idx = np.arange(len(x)), x.argmax(axis=1)
   ...:     y[idx] += x[idx]
   ...:     return y
   ...: 
   ...: 

In [8]: %timeit h(x)
92.1 µs ± 3.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [9]: def h2(a):
    ...:     y = np.zeros_like(x)
    ...:     idx = np.arange(len(x)), x.argmax(axis=1)
    ...:     y[idx] = x[idx]
    ...:     return y
    ...: 
    ...: 

In [10]: %timeit h2(x)
78.5 µs ± 353 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接