取三个集合的公共元素

3

我有三个numpy数组。

[40  9  0 12 49  1  3  4 18 13 34 47]

[40  0 28 39 29 27 50  9 42 41]

[40  0  9 48 46  1 38 45 15 27 31 36  3 12 16 41 30 33 22 37 28  4  2  6 50
 29 32 49 35  7 11 23 44 42 14 13]

现在我想获取所有同时出现在两个或三个集合中的元素。就像上面的例子,前三个元素是所有三个集合共有的,所以它们将被保留。然后你会看到,12出现在第一组和第三组中,因此即使不在第二组中,也应该保留它。50出现在第二组和第三组中,因此即使不在第一组中,也应该保留它。
因此,基本上任何成对或全部共同点都应该被保留。
我做了类似于这样的事情,但很明显,这会保留所有来自三个集合的不同元素。
set(list(shortlistvar_rf)) & set(list(shortlistvar_f)) & set(list(shortlistvar_rl))

这些输入数组中是否可能存在重复项? - Divakar
3个回答

4
Numpy有许多1D数组的集合操作可供使用。在编写任何代码之前,请注意您所需的一般公式:
(a & b) | (b & c) | (c & a)

可以使用布尔代数来化简为:

(b & (a | c)) | (a & c)

这需要4个操作而不是5个。

考虑到这一点,您可以简单地执行以下操作:

>>> np.union1d(np.intersect1d(b, np.union1d(a, c)), np.intersect1d(a, c))
array([ 0,  1,  3,  4,  9, 12, 13, 27, 28, 29, 40, 41, 42, 49, 50])

3
>>> a = [40,  9,  0, 12 ,49  ,1  ,3  ,4 ,18 ,13 ,34 ,47]
>>> b = [40  ,0 ,28 ,39 ,29 ,27 ,50  ,9 ,42 ,41]
>>> c = [40  ,0  ,9 ,48 ,46  ,1 ,38 ,45 ,15 ,27 ,31 ,36  ,3 ,12 ,16 ,41 ,30 ,33 ,22 ,37 ,28  ,4  ,2  ,6 ,50,29 ,32 ,49 ,35  ,7 ,11 ,23 ,44 ,42 ,14 ,13]
>>> (set(a) & set(b)) | (set(a) & set(c)) | (set(b) & set(c))
{0, 1, 3, 4, 40, 9, 42, 41, 12, 13, 49, 50, 27, 28, 29}

这些是numpy数组。所以我应该先将它们转换为列表。因为集合操作不能在numpy数组上工作。 - Baktaawar
如果你可以用list()将它们转换为列表,那么你可能可以直接使用set()将它们转换为集合,而不必先将它们转换为列表。我之所以在这里使用了list(),是因为我的机器没有安装numpy。 - TigerhawkT3
当我执行set(numpy.array)时,它会报错,说集合操作未定义。 - Baktaawar
尝试使用类似于set(x for x in a)set(iter(a))的语法,这将为set()提供一个生成器表达式。我假设numpy数组的行为类似于可迭代对象,对吗? - TigerhawkT3

1
你可以将这三个输入数组的唯一元素版本连接成一个单独的数组。然后,对相同元素进行排序并找出其运行长度。与运行长度大于1相对应的元素将是至少在这三个原始输入数组中两个数组中的元素。
以下是实现代码 -
import numpy as np

# Get unique elements versions of input arrays
unqA = np.unique(A)
unqB = np.unique(B)
unqC = np.unique(C)

# Combine them into one single array and then sort it
comb_sorted = np.sort(np.hstack((unqA,unqB,unqC)))

# Find indices where group changes, where a group means a run of idential elements.
# These identical elements basically represent those common elements between inputs.
idx  = np.where(np.diff(comb_sorted))[0]
grp_change = np.hstack([ [-1],idx,[comb_sorted.size-1] ])+1

# Finally, get the runlengths of each group, detect those runlength > 1 and,
# get the corresponding elements from the combined array
common_ele = comb_sorted[grp_change[np.diff(grp_change)>1]]

基准测试

本节列出了一些运行时间测试,比较提议的方法与使用unionintersect进行numpy数组操作的其他方法,这些方法在 @Jaime's solution中。

案例#1:对于已经具有唯一元素的输入数组 -

设置输入数组:

A = np.random.randint(0,1000,[1,1000000])
B = np.random.randint(0,1000,[1,1000000])
C = np.random.randint(0,1000,[1,1000000])

A = A.ravel()
B = B.ravel()
C = C.ravel()

_, idx1 = np.unique(A, return_index=True)
A = A[np.sort(idx1)]

_, idx2 = np.unique(B, return_index=True)
B = B[np.sort(idx2)]

_, idx3 = np.unique(C, return_index=True)
C = C[np.sort(idx3)]

运行时:

In [6]: %timeit concat(A,B,C)
10000 loops, best of 3: 136 µs per loop

In [7]: %timeit union_intersect(A,B,C)
1000 loops, best of 3: 315 µs per loop

案例 #2:针对可能包含重复项的通用输入数组 -

设置输入数组:

A = np.random.randint(0,1000,[1,1000000])
B = np.random.randint(0,1000,[1,1000000])
C = np.random.randint(0,1000,[1,1000000])

A = A.ravel()
B = B.ravel()
C = C.ravel()

运行时:

In [24]: %timeit concat(A,B,C)
10 loops, best of 3: 102 ms per loop

In [25]: %timeit union_intersect(A,B,C)
10 loops, best of 3: 172 ms per loop

不错!你正在完成 np.intersect1d 的功能,并且泛化到了超过2个数组,这比我的解决方案要好得多。如果你能够重新设计它以使用布尔索引而不是整数索引,即摆脱 where 调用,请参见链接的代码示例,你应该可以大大提高性能。 - Jaime
@Jaime 我确实尝试使用“布尔索引”,但是由于我需要找到运行长度,所以我想我需要找到那些移动的索引,然后执行“diff”,因此在这里可能更喜欢使用np.where。很高兴看到这些解决方案,似乎本质上我最终得到的解决方案都是基于所有这些numpy内置函数的源代码! :) 感谢您分享这些链接! - Divakar
现在不能测试,但是 idx = np.concatenate(([True], comb_sorted[:-1] != comb_sorted[1:])); common_ele = comb_sorted[idx] 这个行不行? - Jaime

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接