Python:交集索引numpy数组

44

我如何获得两个NumPy数组之间交点的索引?我可以使用intersect1d获取交集值:

import numpy as np

a = np.array(xrange(11))
b = np.array([2, 7, 10])
inter = np.intersect1d(a, b)
# inter == array([ 2,  7, 10])

我该如何获取inter中的值在a中的索引?

6个回答

47

您可以使用in1d生成的布尔数组来索引一个arange。将a翻转,使得索引与值不同:

>>> a[::-1]
array([10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0])
>>> a = a[::-1]

intersect1d 仍然返回相同的值...

>>> numpy.intersect1d(a, b)
array([ 2,  7, 10])

但是in1d返回一个布尔类型的数组:

>>> numpy.in1d(a, b)
array([ True, False, False,  True, False, False, False, False,  True,
       False, False], dtype=bool)

哪个可以用来索引一定范围:

>>> numpy.arange(a.shape[0])[numpy.in1d(a, b)]
array([0, 3, 8])
>>> indices = numpy.arange(a.shape[0])[numpy.in1d(a, b)]
>>> a[indices]
array([10,  7,  2])
为了简化上述问题,你可以使用nonzero函数,因为它返回一组均匀列表的XY等坐标的元组,这可能是最正确的方法。
>>> numpy.nonzero(numpy.in1d(a, b))
(array([0, 3, 8]),)

或者,等价地:

>>> numpy.in1d(a, b).nonzero()
(array([0, 3, 8]),)

该结果可无障碍地用作与a形状相同的数组的索引。

>>> a[numpy.nonzero(numpy.in1d(a, b))]
array([10,  7,  2])

需要注意的是,在许多情况下,直接使用布尔数组本身就足够了,而不必将其转换为一组非布尔索引。

最后,您还可以将布尔数组传递给argwhere,它会生成一个略微不同形状的结果,不太适合用作索引,但可能对其他用途有用。

>>> numpy.argwhere(numpy.in1d(a, b))
array([[0],
       [3],
       [8]])

2
很粗糙,但它能用 :) 在Octave中更容易: [inter indexA indexB] = intersect(A,b) - invis
in1d和intersect1d不是相同的。intersect1d提供唯一值,而in1d提供所有交集,因此这个答案并不总是适用。 - Rik
1
@Rik,我想我不同意。 in1d确实不会筛选重复项,但它不应该这样做。它返回的是索引,从一组重复项中仅返回一个索引将是混乱的行为。问题没有指定哪种行为,因此此答案正好符合要求:“获取两个numpy数组之间交点的索引”。如果您不想有重复项,则必须事先将其筛除,这是合理且可以预期的。 - senderle
1
我理解你的意思,你可以先使用np.unique或者在return_index=true的情况下使用intersect1d来给出索引。根据他发布的代码,我认为他想要唯一值,但不确定。 - Rik
无论如何,我在下面发布了代码,以防有人需要唯一值。 - Rik
@Rik,说得有道理。我也是这么想的,但实际上我得出结论输入已经是唯一的。你的回答是一个很好的补充,谢谢。 - senderle

2

如果您需要获取由intersect1d给出的唯一值:

import numpy as np

a = np.array([range(11,21), range(11,21)]).reshape(20)
b = np.array([12, 17, 20])
print(np.intersect1d(a,b))
#unique values

inter = np.in1d(a, b)
print(a[inter])
#you can see these values are not unique

indices=np.array(range(len(a)))[inter]
#These are the non-unique indices

_,unique=np.unique(a[inter], return_index=True)

uniqueIndices=indices[unique]
#this grabs the unique indices

print(uniqueIndices)
print(a[uniqueIndices])
#now they are unique as you would get from np.intersect1d()

输出:

[12 17 20]
[12 17 20 12 17 20]
[1 6 9]
[12 17 20]

2
indices = np.argwhere(np.in1d(a,b))

1

对于Python版本>=3.5,有另一种解决方案

其他解决方案

我们逐步进行以下步骤。

基于问题中的原始代码

import numpy as np

a = np.array(range(11))
b = np.array([2, 7, 10])
inter = np.intersect1d(a, b)

首先,我们用零创建一个numpy数组。
c = np.zeros(len(a))
print (c)

output

>>> [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]

其次,使用交集索引更改c数组的值。因此,我们有:
c[inter] = 1
print (c)

输出

>>>[ 0.  0.  1.  0.  0.  0.  0.  1.  0.  0.  1.]

最后一步,使用np.nonzero()的特性,它将返回您想要的非零项的索引
inter_with_idx = np.nonzero(c)
print (inter_with_idx)

最终输出
array([ 2, 7, 10])

参考资料

[1] numpy.nonzero


如果有需要改进的地方,请告诉我。 - WY Hsu
有人可以解释一下为什么要踩我吗?感激不尽 :) - WY Hsu

1

从numpy版本1.15.0开始,intersect1d函数有一个return_indices选项:

numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False)

0

这是一篇非常老的帖子,但是 numpy.intersect1d() 函数有一个 return_indices 标志

common, inda, indb = numpy.intersect1d(a,b, return_indices=True) 

该函数会返回具有相同值的a的索引/位置,用inda表示;b则用indb表示。

然而,它返回的是第一个交点。例如,如果a不是唯一的且有4个相似的值,则返回a的索引为第一个相交点。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接