使用一个数组的索引获取另一个数组元素的索引

4
假设我有一个形状为(2, 2, 2)的数组a:
a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])

有一个数组a,和一个指定轴为-1的max函数b=a.max(-1)(按行计算)。

b = np.array([[9, 19],
              [24, 18]])

我想使用a.reshape(-1)a展开后,利用对应的索引获取b中元素的索引。

array([ 7,  9, 19, 18, 24,  5, 18, 11])

结果应该是一个数组,该数组与展平后的 a 中的索引相同,这些索引对应于 b 数组中的元素所在的位置。
array([[1, 2],
       [4, 6]])

这基本上是使用pytorch中return_indices=True时maxpool2d的结果,但我正在寻找numpy中的实现。我已经使用了where, 但似乎不起作用, 是否有可能在一次操作中同时找到最大值和索引以提高效率?感谢任何帮助!


由于数字18a中重复出现了两次,应该返回哪个索引?最后一次出现的18的索引? - Omar AlSuwaidi
3个回答

1
我有一个类似于Andras的解决方案,基于np.argmax和np.arange。与“索引索引”的方法不同,我建议在np.argmax的结果上添加分段偏移量
import numpy as np
a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])
off = np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])

>>> off
array([[0, 2],
       [4, 6]])

这将导致:
>>> a.argmax(-1) + off
array([[1, 2],
       [4, 6]])

或者作为一行代码:
>>> a.argmax(-1) + np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
array([[1, 2],
       [4, 6]])

哇,谢谢@Per Joachims,这个同样有效!由于您是“新贡献者”,我认为接受您的答案会更好笑。您能否解释一下偏移量?我不太明白...谢谢!! - Sam-gege
我想我明白了:基本上 off 给出每行第一个元素的索引,然后加上 argmax 将给出该行中最大元素的索引。这很简单,再次感谢。 - Sam-gege

1

目前我能想到的唯一解决方案是生成一个二维(或三维,见下文)range,对扁平数组进行索引,并使用定义b的最大索引(即a.argmax(-1))进行索引:

import numpy as np

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
multi_inds = a.argmax(-1)
b_shape = a.shape[:-1]
b_size = np.prod(b_shape)
flat_inds = np.arange(a.size).reshape(b_size, -1)
flat_max_inds = flat_inds[range(b_size), multi_inds.ravel()]
max_inds = flat_max_inds.reshape(b_shape)

我用一些有意义的变量名分隔了步骤,这应该能解释正在发生什么。 multi_inds 告诉您在 a 中每个“行”中选择哪个“列”以获得最大值:
>>> multi_inds
array([[1, 0],
       [0, 0]])
flat_inds是一个索引列表,每行需选择一个值:
>>> flat_inds
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])

这是根据每行的最大索引精确索引的。 flat_max_inds 是您要查找的值,但它们在一个平面数组中:
>>> flat_max_inds
array([1, 2, 4, 6])

所以我们需要重新塑造它,使其与b.shape匹配:
>>> max_inds
array([[1, 2],
       [4, 6]])

一种稍微有些难懂但更加优雅的解决方案是使用一个三维索引数组,并对其进行广播索引:

import numpy as np

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
multi_inds = a.argmax(-1)
i, j = np.indices(a.shape[:-1])
max_inds = np.arange(a.size).reshape(a.shape)[i, j, multi_inds]

这样做可以避免将数据中间压缩为2D数组。
同样的方法也可以通过最后一部分从multi_inds获取b,即无需再次调用*max函数。
b = a[i, j, multi_inds]

哇,谢谢Andras。这看起来相当复杂。你知道任何numpy函数可以返回最大值和索引吗?argmax基本上是再次找到最大值,对吧? - Sam-gege
1
谢谢你的帮助!我需要试一下这个。 - Sam-gege
感谢您的更新答案。我认为我需要在maxpool中至少调用一次max,但我也想获取它的索引,以便在反向传播期间,我可以将梯度放回最大位置。我想知道,由于numpy数组是连续的,max()函数应该很容易根据扁平输入返回一个索引,但可惜它没有。再次感谢您的帮助! - Sam-gege
1
是的,你说得对,我忘记了,所以我只需要调用一次argmax。在我有了这些索引之后,我可以简单地调用a.reshape(-1)[np.array([[1, 2],[4, 6]])]来获取这些最大值。基本上,我不必调用max,然后再调用argmax。 - Sam-gege
@Sam-gege,我可能没有时间做这件事,特别是没有提供 [mcve] 的情况下。无论如何,我注意到在 argmax 文档中,它告诉你从 argmax 的结果中获取最大值的惯用方法:np.take_along_axis(a, multi_inds[..., None], axis=-1)(未经测试)。 - Andras Deak -- Слава Україні
显示剩余3条评论

0

这是一个很长的单行代码

new = np.array([np.where(a.reshape(-1)==x)[0][0] for x in a.max(-1).reshape(-1)]).reshape(2,2)

print(new)
array([[1, 2],
       [4, 3]])

然而,数字18重复出现了两次;那么目标是哪个索引。


谢谢你的回答,达米尔。然而,这是不正确的,因为应该返回第二个18的索引,因为那个18是其行中的最大值(18,11),而不是在行(19,18)中的第一个18。 - Sam-gege

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接