Python - 2D Numpy数组的交集

3
我正在拼命寻找一种高效的方法来检查两个2D numpy数组是否相交。
所以我有两个数组,每个数组都有任意数量的2D数组,例如:
A=np.array([[2,3,4],[5,6,7],[8,9,10]])
B=np.array([[5,6,7],[1,3,4]])
C=np.array([[1,2,3],[6,6,7],[10,8,9]])

我只需要一个True,如果至少有一个向量与另一个数组中的另一个向量相交,否则为false。因此,它应该给出以下结果:
f(A,B)  -> True
f(A,C)  -> False

我有点新手,开始用Python列表编写程序,虽然可以运行,但效率非常低。程序需要几天才能完成,因此我现在正在研究使用numpy.array解决方案,但这些数组真的不太容易处理。
以下是我的程序和Python列表解决方案的一些背景:
我正在进行类似于三维自避随机行走http://en.wikipedia.org/wiki/Self-avoiding_walk的操作。但是,我不是进行随机行走并希望它达到所需长度(例如,我想建立由1000个珠子组成的链),而是进行以下操作:
我创建一个“平坦”的链,长度为N:
X=[]
for i in range(0,N+1):
    X.append((i,0,0))

现在我将把这个扁平的链折叠起来:
  1. 随机选择一个元素(“枢轴元素”)
  2. 随机选择一个方向(要么全部元素朝左,要么朝右)
  3. 随机选择9种可能的空间旋转之一(3个轴* 3种可能的旋转90°、180°、270°)
  4. 旋转所选方向的所有元素到所选的旋转角度
  5. 检查所选方向的新元素是否与其他方向相交
  6. 无交点->接受新配置,否则->保持原有链。
步骤1-6必须进行大量次数(例如对于长度为1000的链,约为5000次),因此这些步骤必须有效地完成。我基于列表的解决方案如下:
def PivotFold(chain):
randPiv=random.randint(1,N)  #Chooses a random pivotelement, N is the Chainlength
Pivot=chain[randPiv]  #get that pivotelement
C=[]  #C is going to be a shifted copy of the chain
intersect=False
for j in range (0,N+1):   # Here i shift the hole chain to get the pivotelement to the origin, so i can use simple rotations around the origin
    C.append((chain[j][0]-Pivot[0],chain[j][1]-Pivot[1],chain[j][2]-Pivot[2]))
rotRand=random.randint(1,18)  # rotRand is used to choose a direction and a Rotation (2 possible direction * 9 rotations = 18 possibilitys)
#Rotations around Z-Axis
if rotRand==1:
    for j in range (randPiv,N+1):
        C[j]=(-C[j][1],C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
elif rotRand==2:
    for j in range (randPiv,N+1):
        C[j]=(C[j][1],-C[j][0],C[j][2])
        if C[0:randPiv].__contains__(C[j])==True:
            intersect=True
            break
...etc
if intersect==False: # return C if there was no intersection in C
    Shizz=C
else:
    Shizz=chain
return Shizz

函数PivotFold(chain)将在最初的平面链X上使用大量次数。它写得很朴素,也许您有一些技巧来改进它^^我认为numpy数组会很好,因为我可以有效地转移和旋转整个链而不必循环遍历所有元素...

哦,当然是2D的:D我之前想到3D是因为每个元素都包含3个数字。通过交叉,我的意思是我想检查是否有任何两个元素完全相等。所以[1,2,3]和[1,2,3]是相同的。但是例如[2,3,4]和[3,2,4]就不一样了。想象一下普通的3D向量...在空间中不应该有任何两个指向同一位置的向量。 - user3785759
您计划在何种上下文中使用它?这些数组代表什么?NumPy的效率高度依赖于批量执行操作;一次要计算多少个交集? - user2357112
1
我不确定NumPy是否适合这里。你可以使用常规Python数据结构得到相当高效的结果。例如,使用元组集而不是NumPy数组,线性时间交集只需not set1.isdisjoint(set2)。解决所有对交集问题,并找到N个数组之间的所有交集,时间大约与N个单独交集相当,而不是N^2,只要交集不太多。你能展示基于列表的解决方案吗? - user2357112
不完全相同但相关的内容:https://dev59.com/1n_aa4cB1Zd3GeqP9_EF - user2379410
我编辑了一些内容并将基于列表的解决方案添加到了我的原始帖子中。 - user3785759
显示剩余3条评论
6个回答

4
这应该就可以了:
In [11]:

def f(arrA, arrB):
    return not set(map(tuple, arrA)).isdisjoint(map(tuple, arrB))
In [12]:

f(A, B)
Out[12]:
True
In [13]:

f(A, C)
Out[13]:
False
In [14]:

f(B, C)
Out[14]:
False

要找交集?好的,set听起来像是一个合理的选择。 但是numpy.arraylist不可哈希?好的,将它们转换为tuple。 这就是思路。
使用numpy的方法涉及非常难以阅读的广播:
In [34]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1)
Out[34]:
array([[False, False],
       [ True, False],
       [False, False]], dtype=bool)
In [36]:

(A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
Out[36]:
True

一些 timeit 的结果:

In [38]:
#Dan's method
%timeit set_comp(A,B)
10000 loops, best of 3: 34.1 µs per loop
In [39]:
#Avoiding lambda will speed things up
%timeit f(A,B)
10000 loops, best of 3: 23.8 µs per loop
In [40]:
#numpy way probably will be slow, unless the size of the array is very big (my guess)
%timeit (A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()
10000 loops, best of 3: 49.8 µs per loop

此外,numpy方法会占用大量内存,因为A[...,np.newaxis]==B[...,np.newaxis].T步骤会创建一个三维数组。


啊,抱歉@CTZhum,我在写我的回复时才看到你的回复。它更简洁,很好 :) - daniel
1
嘿@daniel, numpy 的解决方案更加简洁:(A[...,np.newaxis]==B[...,np.newaxis].T).all(1).any()。 干杯。 - CT Zhu
嘿,大家好,谢谢你们的回答!看起来不错,我明天会尝试实现它 :) - user3785759
能够实现你的set方法 :) 对于大数组似乎运行良好。非常感谢! - user3785759

3

使用与此处概念相同的思路,您可以执行以下操作:

def make_1d_view(a):
    a = np.ascontiguousarray(a)
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(dt).ravel()

def f(a, b):
    return len(np.intersect1d(make_1d_view(A), make_1d_view(b))) != 0

>>> f(A, B)
True
>>> f(A, C)
False

这对浮点类型无效(它不会将+0.0和-0.0视为相同的值),而np.intersect1d使用排序,因此其性能是线性对数级别的,而不是线性的。您可以通过在代码中复制np.intersect1d的源代码,并在布尔索引数组上调用np.any而不是检查返回数组的长度来挤出一些性能。

3

您还可以使用一些 np.tilenp.swapaxes 代码完成工作!

def intersect2d(X, Y):
        """
        Function to find intersection of two 2D arrays.
        Returns index of rows in X that are common to Y.
        """
        X = np.tile(X[:,:,None], (1, 1, Y.shape[0]) )
        Y = np.swapaxes(Y[:,:,None], 0, 2)
        Y = np.tile(Y, (X.shape[0], 1, 1))
        eq = np.all(np.equal(X, Y), axis = 1)
        eq = np.any(eq, axis = 1)
        return np.nonzero(eq)[0]

为了更具体地回答这个问题,您只需要检查返回的数组是否为空。

1
这应该会更快,不像for循环解决方案那样是O(n^2),但它并不完全符合numpythonic的要求。不确定如何更好地利用numpy。
def set_comp(a, b):
   sets_a = set(map(lambda x: frozenset(tuple(x)), a))
   sets_b = set(map(lambda x: frozenset(tuple(x)), b))
   return not sets_a.isdisjoint(sets_b)

0

这个问题可以使用numpy_indexed包高效地解决(免责声明:我是它的作者):

import numpy_indexed as npi
len(npi.intersection(A, B)) > 0

0
我认为如果两个数组具有子数组集,则您想要true!您可以使用以下代码:
def(A,B):
 for i in A:
  for j in B:
   if i==j
   return True
 return False 

谢谢您的回答,但我正在寻找基于numpy函数而没有任何循环的解决方案。 - user3785759
哦,这个怎么样? numpy.array_equal(a1, a2)[source]¶如果两个数组具有相同的形状和元素,则为True,否则为False。 - Mazdak
这个比较是针对数组的,在这种情况下,如果您不想使用循环,必须使用子数组。我没有任何主意! - Mazdak
是的,这个确实非常接近我想要的。但是array_equal只适用于1D数组。所以我只能用它来检查单个元素,例如np.array_equal([1,2,3],[1,2,3])会返回true...但是我必须循环遍历我的两个数组中的所有元素来检查交集,这又不是我想要的 :/ - user3785759
是的,我为您编写了代码。你必须在循环中使用它,但如果现在不想用循环,我没有任何想法!但我认为对于这种情况,我们不能不使用循环,即使在numpy中有一个函数可以做到这一点,它实际上也使用循环,因为循环的目的就是减少我们的计算量! - Mazdak
@bluebird7,在这种情况下,您需要使用(i == j).all()。只是提供信息。 - daniel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接