在2D numpy数组中获取大于阈值的元素的索引

48

我有一个二维的numpy数组:

x = np.array([
 [  1.92043482e-04,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   2.41005634e-03,   0.00000000e+00,
    7.19330120e-04,   0.00000000e+00,   0.00000000e+00,   1.42886875e-04,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   9.79279411e-05,   7.88888657e-04,   0.00000000e+00,
    0.00000000e+00,   1.40425916e-01,   0.00000000e+00,   1.13955893e-02,
    7.36868947e-03,   3.67091988e-04,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   1.72037105e-03,   1.72377961e-03,
    0.00000000e+00,   0.00000000e+00,   1.19532061e-01,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   3.37249481e-04,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   1.75111492e-03,   0.00000000e+00,
    0.00000000e+00,   1.12639313e-02],
 [  0.00000000e+00,   0.00000000e+00,   1.10271735e-04,   5.98736562e-04,
    6.77961628e-04,   7.49569659e-04,   0.00000000e+00,   0.00000000e+00,
    2.91697850e-03,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   3.30257021e-04,   2.46629275e-04,
    0.00000000e+00,   1.87586441e-02,   6.49103144e-04,   0.00000000e+00,
    1.19046355e-04,   0.00000000e+00,   0.00000000e+00,   2.69499898e-03,
    1.48525386e-02,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   1.18803119e-03,
    3.93100829e-04,   0.00000000e+00,   3.76245304e-04,   2.79537738e-02,
    0.00000000e+00,   1.20738457e-03,   9.74669064e-06,   7.18680093e-04,
    1.61546793e-02,   3.49360861e-04,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
    0.00000000e+00,   0.00000000e+00]])

我如何获得大于 0.01 的元素的索引?

目前,我正在执行 t = np.argmax(x, axis=1),以获取每个最大值的索引,其结果为:[21 35]。我该如何实现上述要求?

2个回答

66

你可以使用np.argwhere来返回符合布尔条件的数组中所有条目的索引:

>>> x = np.array([[0,0.2,0.5],[0.05,0.01,0]])

>>> np.argwhere(x > 0.01)
array([[0, 1],
       [0, 2],
       [1, 0]])    

嗯,我觉得我不太清楚如何解释np数组。在(array([0, 1]), array([1, 0]))中,为什么是array[1,0]而不是[0,1]? - Arman
因为它们是numpy对象。 - maxymoo
我不确定那是否给了我正确的结果。我进行了t = np.where(x > 0.01,得到的输出是:(array([0, 0, 0, 0, 1, 1, 1, 1]), array([21, 23, 34, 49, 17, 24, 35, 40]))。但是第一个元素并没有包含21,而这个数字在使用argmax时被返回。 - Arman
哦,抱歉实际上第一个数组是x坐标,第二个数组是y坐标,你需要将它们压缩在一起以获得x-y坐标,我会编辑我的答案。 - maxymoo
实际上你可以使用 argwhere - maxymoo

1

我们也可以使用np.nonzero()来获取一个元组,其中包含每个维度的数组,这些数组包含条件为True的索引。

x_indices, y_indices = np.nonzero(x > 0.01)
# (array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int64), array([21, 23, 34, 49, 17, 24, 35, 40], dtype=int64))

它的一个好处是可以立即用于索引数组。例如,如果我们想要过滤大于0.01的元素,则

x[np.nonzero(x>0.01)]

nonzero 按维度分组索引,而 argwhere 按元素分组(这只是从不同的角度看同一件事),因此以下语句为真:

(np.argwhere(x>0.01).T == np.nonzero(x>0.01)).all()   # True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接