高效地查找两个二维 numpy 数组的行交集

6
我将尝试找到一种有效的方法来查找两个np.arrays的行交集。
这两个数组具有相同的形状,并且每行中不能出现重复的值。
例如:
import numpy as np

a = np.array([[2,5,6],
              [8,2,3],
              [4,1,5],
              [1,7,9]])

b = np.array([[2,3,4],  # one element(2) in common with a[0] -> 1
              [7,4,3],  # one element(3) in common with a[1] -> 1
              [5,4,1],  # three elements(5,4,1) in common with a[2] -> 3
              [7,6,9]]) # two element(9,7) in common with a[3] -> 2

我的期望输出是:np.array([1,1,3,2])

使用循环很容易实现:

def get_intersect1ds(a, b):
    result = np.empty(a.shape[0], dtype=np.int)
    for i in xrange(a.shape[0]):
        result[i] = (len(np.intersect1d(a[i], b[i])))
    return result

结果:

>>> get_intersect1ds(a, b)
array([1, 1, 3, 2])

但是有没有更有效的方法来完成它呢?

哦,ab 在每一行中可以有重复的值吗? - YXD
@MrE 说得好,重复不可能发生。谢谢。 - Akavall
你期望输入的数组有多大? - Warren Weckesser
@WarrenWeckesser,400万除以25,我可能会经常执行这个操作。 - Akavall
3个回答

7
如果一行内没有重复项,你可以尝试复制np.intersect1d在底层所做的事情(请参见源代码这里):
>>> c = np.hstack((a, b))
>>> c
array([[2, 5, 6, 2, 3, 4],
       [8, 2, 3, 7, 4, 3],
       [4, 1, 5, 5, 4, 1],
       [1, 7, 9, 7, 6, 9]])
>>> c.sort(axis=1)
>>> c
array([[2, 2, 3, 4, 5, 6],
       [2, 3, 3, 4, 7, 8],
       [1, 1, 4, 4, 5, 5],
       [1, 6, 7, 7, 9, 9]])
>>> c[:, 1:] == c[:, :-1]
array([[ True, False, False, False, False],
       [False,  True, False, False, False],
       [ True, False,  True, False,  True],
       [False, False,  True, False,  True]], dtype=bool)
>>> np.sum(c[:, 1:] == c[:, :-1], axis=1)
array([1, 1, 3, 2])

你能解释一下 c[:, 1:] == c[:, :-1] 这行代码背后的算法吗? - Elad Maimoni

2

这个答案可能不可行,因为如果输入的形状为(N,M),它会生成一个大小为(N,M,M)的中间数组,但使用广播总是很有趣:

In [43]: a
Out[43]: 
array([[2, 5, 6],
       [8, 2, 3],
       [4, 1, 5],
       [1, 7, 9]])

In [44]: b
Out[44]: 
array([[2, 3, 4],
       [7, 4, 3],
       [5, 4, 1],
       [7, 6, 9]])

In [45]: (np.expand_dims(a, -1) == np.expand_dims(b, 1)).sum(axis=-1).sum(axis=-1)
Out[45]: array([1, 1, 3, 2])

对于大型数组,可以通过分批处理来使方法更加内存友好。


1

我想不出一个纯粹的numpy解决方案,但以下建议应该可以加速,潜在地显著提高速度:

  1. 使用numba。只需在您的get_intersect1ds函数上添加@autojit即可。
  2. 在调用intersect1d时传递assume_unique = True

很遗憾,我没有访问numba的权限,但我在考虑使用cython。我认为这也应该可以工作。感谢您的建议。 - Akavall

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接