在Python中找到两个大数组(矩阵)之间的差集

20

我有两个大的二维数组,想要找到它们的差集,将它们的行看作元素。在Matlab中,可以使用以下代码实现: setdiff(A,B,'rows')。这两个数组足够大,所以我想到的明显的循环方法耗时太长。


“Set difference” 是什么意思? - reptilicus
@user1443118 我猜他的意思是“在A中而不在B中的值”,参见http://www.mathworks.com/help/techdoc/ref/setdiff.html。 - Hooked
"set difference" 的意思是指集合论中的差集操作? - Pablo Santa Cruz
你的二维数组长什么样?是一个列表的列表吗? - Pablo Santa Cruz
这些数组的维度是否相同? - reptilicus
显示剩余2条评论
3个回答

19

这个应该能工作,但在1.6.1版本中由于创建视图的不可用mergesort而无法使用。在即将发布的1.7.0版本中可以正常工作。这应该是最快的方式,因为视图不需要复制任何内存:

>>> import numpy as np
>>> a1 = np.array([[1,2,3],[4,5,6],[7,8,9]])
>>> a2 = np.array([[4,5,6],[7,8,9],[1,1,1]])
>>> a1_rows = a1.view([('', a1.dtype)] * a1.shape[1])
>>> a2_rows = a2.view([('', a2.dtype)] * a2.shape[1])
>>> np.setdiff1d(a1_rows, a2_rows).view(a1.dtype).reshape(-1, a1.shape[1])
array([[1, 2, 3]])

你可以用Python做到这一点,但可能会很慢:

>>> import numpy as np
>>> a1 = np.array([[1,2,3],[4,5,6],[7,8,9]])
>>> a2 = np.array([[4,5,6],[7,8,9],[1,1,1]])
>>> a1_rows = set(map(tuple, a1))
>>> a2_rows = set(map(tuple, a2))
>>> a1_rows.difference(a2_rows)
set([(1, 2, 3)])

谢谢。底部的方法最终崩溃了,但是一旦我弄清楚如何安装新版本的numpy,我会尝试顶部的方法。 - zss

8
这里有一个不错的纯 numpy 解决方案,适用于 1.6.1 版本。它会创建一个中间数组,所以这可能或可能不是你所需要的。它也不依赖于任何来自排序数组的加速(可能像 setdiff 那样)。
from numpy import *
# Create some sample arrays
A =random.randint(0,5,(10,3))
B =random.randint(0,5,(10,3))

作为一个例子,这是我得到的内容 - 请注意有一个共同的元素:
>>> A
array([[1, 0, 3],
       [0, 4, 2],
       [0, 3, 4],
       [4, 4, 2],
       [2, 0, 2],
       [4, 0, 0],
       [3, 2, 2],
       [4, 2, 3],
       [0, 2, 1],
       [2, 0, 2]])
>>> B
array([[4, 1, 3],
       [4, 3, 0],
       [0, 3, 3],
       [3, 0, 3],
       [3, 4, 0],
       [3, 2, 3],
       [3, 1, 2],
       [4, 1, 2],
       [0, 4, 2],
       [0, 0, 3]])

我们寻找行之间的(L1)距离为零时,这给了我们一个矩阵,在它为零的点上,这些是两个列表共同拥有的项目:

idx = where(abs((A[:,newaxis,:] - B)).sum(axis=2)==0)

作为检查:
>>> A[idx[0]]
array([[0, 4, 2]])
>>> B[idx[1]]
array([[0, 4, 2]])

能否请下投票者解释一下?我欢迎任何批评或意见,以改进我的工作。 - Hooked
感谢您提供的巧妙代码(我会记住新轴公式)。不幸的是,当我尝试时,出现了错误:“ValueError:数组太大。” - zss
当你运行A.size()B.size()时,数组有多大? - Hooked

-1

我不确定你想要什么,但这将为您提供一个布尔数组,其中包含两个数组不相等的位置,并且速度非常快:


import numpy as np
a = np.random.randn(5, 5)
b = np.random.randn(5, 5)
a[0,0] = 10.0
b[0,0] = 10.0 
a[1,1] = 5.0
b[1,1] = 5.0
c = ~(a-b==0)
print c

[[假 真 真 真 真] [真 假 真 真 真] [真 真 真 真 真] [真 真 真 真 真] [真 真 真 真 真]]


1
这不正确,它在比较元素。 OP 寻找的是行的集合差异。 - Hooked
确实,“a[0, c[0]]给出了a中第0行不在b中的元素”,但我理解问题的方式并不是找到每一行A和B中相同的元素,而是找到匹配的A行和B行。 - Hooked
然而,从匹配矩阵中,您可以轻松地使用np.all(match_matrix, axis=0)转到给定行匹配的数组。 - Okarin

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接