问题
我有两个numpy数组,A
和indices
。
A
的维度为m x n x 10000。
indices
的维度为m x n x 5(从argpartition(A, 5)[:,:,:5]
输出)。
我希望得到一个m x n x 5的数组,其中包含与indices
相应的A
元素。
尝试
indices = np.array([[[5,4,3,2,1],[1,1,1,1,1],[1,1,1,1,1]],
[500,400,300,200,100],[100,100,100,100,100],[100,100,100,100,100]])
A = np.reshape(range(2 * 3 * 10000), (2,3,10000))
A[...,indices] # gives an array of size (2,3,2,3,5). I want a subset of these values
np.take(A, indices) # shape is right, but it flattens the array first
np.choose(indices, A) # fails because of shape mismatch.
动机
我正在尝试使用np.argpartition
获取按排序顺序排列的每个i<m
,j<n
的前5个最大值A[i,j]
,因为数组可能会变得非常大。