使用NumPy查找数组中N个重复项的索引

3
我有一个数组,它是使用sp.distance.cdist获得的,这个数组长这样:
 [ 0.          5.37060126  2.68530063  4.65107712  2.68530063  4.65107712
   2.04846297  7.41906423  4.11190697  6.50622284  4.11190697  6.50622284]
 [ 5.37060126  0.          4.65107712  2.68530063  4.65107712  2.68530063
   7.41906423  2.04846297  6.50622284  4.11190697  6.50622284  4.11190697]
 [ 2.68530063  4.65107712  0.          2.68530063  4.65107712  5.37060126
   4.11190697  6.50622284  2.04846297  4.11190697  6.50622284  7.41906423]
 [ 4.65107712  2.68530063  2.68530063  0.          5.37060126  4.65107712
   6.50622284  4.11190697  4.11190697  2.04846297  7.41906423  6.50622284]
 [ 2.68530063  4.65107712  4.65107712  5.37060126  0.          2.68530063
   4.11190697  6.50622284  6.50622284  7.41906423  2.04846297  4.11190697]
 [ 4.65107712  2.68530063  5.37060126  4.65107712  2.68530063  0.
   6.50622284  4.11190697  7.41906423  6.50622284  4.11190697  2.04846297]
 [ 2.04846297  7.41906423  4.11190697  6.50622284  4.11190697  6.50622284
   0.          9.4675272   4.7337636   8.19911907  4.7337636   8.19911907]
 [ 7.41906423  2.04846297  6.50622284  4.11190697  6.50622284  4.11190697
   9.4675272   0.          8.19911907  4.7337636   8.19911907  4.7337636 ]
 [ 4.11190697  6.50622284  2.04846297  4.11190697  6.50622284  7.41906423
   4.7337636   8.19911907  0.          4.7337636   8.19911907  9.4675272 ]
 [ 6.50622284  4.11190697  4.11190697  2.04846297  7.41906423  6.50622284
   8.19911907  4.7337636   4.7337636   0.          9.4675272   8.19911907]
 [ 4.11190697  6.50622284  6.50622284  7.41906423  2.04846297  4.11190697
   4.7337636   8.19911907  8.19911907  9.4675272   0.          4.7337636 ]
 [ 6.50622284  4.11190697  7.41906423  6.50622284  4.11190697  2.04846297
   8.19911907  4.7337636   9.4675272   8.19911907  4.7337636   0.        ]]

使用numpy,我想要搜索一些值,例如在2.72.3之间,并且在数组的行中找到它们时,同时返回它们的索引。我已经阅读了很多资料,例如.argmin()部分实现了我想要的功能(但它只显示零或低于零的值的位置,并且只有一个匹配项)。在.argmin的文档中,我找不到任何关于如何找到不为零的最小值并且不在第一个匹配项后停止的相关信息。我需要在这个区间内进行操作。为了更好地解释,这就是我期望得到的:

e.g.:

[row (0), index (2), index (4)]
[row (1), index (3), index (5)]
[row (2), index (0), index (3)]

什么是最好的方法来做到这一点?与此同时,我将继续尝试,如果我找到解决方案,我会在这里发布。

谢谢。


你想用数组索引做什么?你最好只是使用 (v > 2.3) & (v < 2.7) 的结果(一个布尔类型的数组),而不是使用一个索引数组。 - Sven Marnach
2个回答

2
你需要的是 np.argwhere 函数,它会告诉你在一个数组中哪些位置满足某个条件。
v = np.array([[ 0.     ,     5.37060126,  2.68530063 , 4.65107712 , 2.5 ],
              [ 5.37060126 ,  4.65107712 , 2.68530063 ,.11190697,1 ]])


np.argwhere((v > 2.3) & (v < 2.7))

array([[0, 2],
        [0, 4],
         [1, 2]])

1
您需要使用的是 numpy.where,它返回一个元组,其中包含对于一个 numpy.ndarray 的值,在某些条件下为 True 的每个维度的索引。以下是使用您的数据的示例:
i, j = np.where(((a > 2.3) & (a < 2.7)))
#(array([ 0,  0,  2,  2,  4,  4,  6,  6,  8,  8, 10, 10], dtype=int64),
# array([2, 4, 3, 5, 0, 3, 1, 2, 0, 5, 1, 4], dtype=int64))

然后,您可以使用 groupby 将输出放入所需的格式中:
from itertools import groupby
for k,g in itertools.groupby(zip(i, j), lambda x: x[0]):
    print k, [tmp[1] for tmp in zip(*g)]
#0 [0, 4]
#2 [2, 5]
#4 [4, 3]
#6 [6, 2]
#8 [8, 5]
#10 [10, 4]

我尝试使用您的解决方案,但总是出现错误。我还在阅读有关如何使用groupby的内容,但目前为止没有什么收获。for k,g in groupby(zip(i,j), lambda x: x[0]): NameError: name 'i' is not defined - muammar
@muammar i, j 是由 np.where 返回的索引... 我忘记放答案了... 现在我已经编辑过了,应该可以工作了! - Saullo G. P. Castro
argwherewhere都使用nonzero - hpaulj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接