2D数组中重复的前N个最大值掩码化

9

我有一个2维numpy数组:

arr = np.array([[0.1, 0.1, 0.3, 0.4, 0.5], 
                [0.06, 0.1, 0.1, 0.1, 0.01], 
                [0.24, 0.24, 0.24, 0.24, 0.24], 
                [0.2, 0.25, 0.3, 0.12, 0.02]])
print (arr)
[[0.1  0.1  0.3  0.4  0.5 ]
 [0.06 0.1  0.1  0.1  0.01]
 [0.24 0.24 0.24 0.24 0.24]
 [0.2  0.25 0.3  0.12 0.02]]

我想筛选出前N个值,所以我使用了 argsort

N = 2
arr1 = np.argsort(-arr, kind='mergesort') < N
print (arr1)
[[False False False  True  True]
 [ True False False  True False] <- first top 2 are duplicates
 [ True  True False False False]
 [False  True  True False False]]

它表现良好,至少不会有像第二行那样的完全重复。

期望输出:

print (arr1)
[[False False False  True  True]
 [False  True  True False False]
 [ True  True False False False]
 [False  True  True False False]]

有没有更快的方法来处理它?


这不应该是 np.argsort(-arr, kind='mergesort')[:,:N] 吗? - Divakar
是的,argsort会按顺序给出前N个索引。因此需要进行切片操作。然后,使用它来创建一个掩码,我想说。 - Divakar
1个回答

5
使用切片获取前N个索引,并使用它们创建最终的掩码 -
idx = np.argsort(-arr, kind='mergesort')[:,:N]
mask = np.zeros(arr.shape, dtype=bool)
np.put_along_axis(mask, idx, True, axis=-1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接