NumPy匹配索引维度

5

问题

我有两个numpy数组,Aindices

A的维度为m x n x 10000。 indices的维度为m x n x 5(从argpartition(A, 5)[:,:,:5]输出)。 我希望得到一个m x n x 5的数组,其中包含与indices相应的A元素。

尝试

indices = np.array([[[5,4,3,2,1],[1,1,1,1,1],[1,1,1,1,1]],
    [500,400,300,200,100],[100,100,100,100,100],[100,100,100,100,100]])
A = np.reshape(range(2 * 3 * 10000), (2,3,10000))

A[...,indices] # gives an array of size (2,3,2,3,5). I want a subset of these values
np.take(A, indices) # shape is right, but it flattens the array first
np.choose(indices, A) # fails because of shape mismatch. 

动机

我正在尝试使用np.argpartition获取按排序顺序排列的每个i<mj<n的前5个最大值A[i,j],因为数组可能会变得非常大。

2个回答

5
您可以使用高级索引进行操作,详情请参考advanced-indexing
m,n = A.shape[:2]
out = A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices]

样例运行 -

In [330]: A
Out[330]: 
array([[[38, 21, 61, 74, 35, 29, 44, 46, 43, 38],
        [22, 44, 89, 48, 97, 75, 50, 16, 28, 78],
        [72, 90, 48, 88, 64, 30, 62, 89, 46, 20]],

       [[81, 57, 18, 71, 43, 40, 57, 14, 89, 15],
        [93, 47, 17, 24, 22, 87, 34, 29, 66, 20],
        [95, 27, 76, 85, 52, 89, 69, 92, 14, 13]]])

In [331]: indices
Out[331]: 
array([[[7, 8, 1],
        [7, 4, 7],
        [4, 8, 4]],

       [[0, 7, 4],
        [5, 3, 1],
        [1, 4, 0]]])

In [332]: m,n = A.shape[:2]

In [333]: A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices]
Out[333]: 
array([[[46, 43, 21],
        [16, 97, 16],
        [64, 46, 64]],

       [[81, 14, 43],
        [87, 24, 47],
        [27, 52, 95]]])

要获取沿着最后一个轴的前5个最大元素对应的索引,我们可以使用argpartition,如下所示 -

indices = np.argpartition(-A,5,axis=-1)[...,:5]

为了保持从最高到最低的顺序,请使用range(5)而不是5

1
为了后代,以下使用Divakar的答案来实现原始目标,即以排序顺序返回所有中前5个值:
m, n = np.shape(A)[:2]

# get the largest 5 indices for all m, n
top_unsorted_indices = np.argpartition(A, -5, axis=2)[...,-5:]

# get the values corresponding to top_unsorted_indices
top_values = A[np.arange(m)[:,None,None], np.arange(n)[:,None], top_unsorted_indices]

# sort the top 5 values
top_sorted_indices = top_unsorted_indices[np.arange(m)[:,None,None], np.arange(n)[:,None], np.argsort(-top_values)]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接