在numpy的一个三维数组中,如何提取第三维中最大元素的索引?

4

例子输入3D数组的形状(2,2,2):

[[[ 1, 2],
  [ 4, 3]],
 [[ 5, 6],
  [ 8, 7]]]

我的三维数组的形状为(N,N,N),在上面的例子中N = 2。

我需要获取所有索引,使得第三个维度的索引属于第三个维度中的最大元素,以上3D数组的输出:

[[0, 0, 1],  # for element 2
 [0, 1, 0],  # for element 4
 [1, 0, 1],  # for element 6
 [1, 1, 0]]  # for element 8

如果我可以使用argmaxargwhere函数就太好了。我希望避免迭代,并看看是否可能使用numpy函数来完成此操作。


6如何映射到[1, 0, 1]?我不理解你想要的输出。 - YXD
@YXD 在这个三维数组中,“6”位于第一个二维数组中,然后在该第一个二维数组中,“6”位于位置(0,1):因此输出为(1, 0, 1)。 - bits
1个回答

1

这是一种方法,使用 np.meshgrid 来获取第一和第二轴上的所有索引,然后使用 np.column_stack 将它们与第三轴上的最大索引堆叠在一起 -

d = a.argmax(-1)
m,n = a.shape[:2]
c,r = np.mgrid[:m,:n]
out = np.column_stack((c.ravel(),r.ravel(),d.ravel()))

样例运行 -

In [96]: a
Out[96]: 
array([[[38, 49, 15, 61, 29],
        [31, 88, 45, 88, 20],
        [17, 97, 58, 61, 14],
        [43, 77, 56, 92, 89]],

       [[48, 91, 49, 35, 58],
        [53, 34, 58, 92, 52],
        [20, 35, 70, 41, 81],
        [60, 42, 85, 82, 41]],

       [[45, 41, 32, 41, 25],
        [59, 32, 90, 18, 47],
        [24, 93, 29, 89, 12],
        [80, 27, 12, 51, 33]]])

In [97]: out
Out[97]: 
array([[0, 0, 3],
       [0, 1, 1],
       [0, 2, 1],
       [0, 3, 3],
       [1, 0, 1],
       [1, 1, 3],
       [1, 2, 4],
       [1, 3, 2],
       [2, 0, 0],
       [2, 1, 2],
       [2, 2, 1],
       [2, 3, 0]])

或者,由于这些索引基本上是重复的,我们可以使用np.repeatnp.tile来获取这些索引数组,然后像之前一样使用np.column_stack,如下所示 -

d0 = np.arange(m).repeat(n)
d1 = np.tile(np.arange(n),m)
out = np.column_stack((d0,d1,d.ravel()))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接