交集两个集合,高效地保留所有(最多)三部分。

4
如果你有两个集合ab,并且对它们求交集,那么会出现三个有趣的部分(可能为空):属于a但不属于b的h(头)元素、既属于a又属于b的i(交)元素以及属于b但不属于a的t(尾)元素。
例如:{1, 2, 3} & {2, 3, 4} -> h:{1}, i:{2, 3}, t:{4} (不是实际的Python代码,显然)
在 Python 中编写此操作的一种简单方式是:
h, i, t = a - b, a & b, b - a

我认为这可能会更加高效:

h, t = a - (i := a & b), b - i

因为该方法首先计算交集,然后仅从 ab 中减去该交集。如果 i 很小且 ab 很大,则可以帮助 - 虽然我认为它取决于减法的实现是否真正更快。据我所知,它不太可能更糟。

我无法找到这样的运算符或函数,但是由于我可以想象出执行将 ab 的三分拆分为 hit 的高效实现,所以我是否错过了已经存在的类似功能?

from magical_set_stuff import hit

h, i, t = hit(a, b)

1
我进行了一个快速测试:h,t = a -(i:= a&b),b-ih,i,t = a-b,a&b,b-a 快大约1.5-2倍。令人惊讶的是,h,i,t = a -(a&b),a&b,b -(a&b) 稍微但一致地比第二种方法更快。通过在a上迭代并在b中进行成员检查来创建hi的循环方法可预测地表现不佳(2-6倍),比第一种方法慢。 - Pranav Hosangadi
1
这真是令人惊讶 - 可能是因为进行了一些水下优化,因为我想不出为什么两次访问和使用 i 会比重新评估 (a&b) 慢两倍。 - Grismar
仅澄清一下:通过减去a - bb - a的方法略慢于通过a - (a&b)b - (a&b)的方法,这令人惊讶,因为需要计算两次交集。 a-i,b-i始终比两者都快。 - Pranav Hosangadi
1
啊,谢谢 - 因为我刚刚自己运行了一个测试,那些解决方案大致比较为 4.3:2.4:3.6,所以我同意 a - i, b - i 的方法要快得多。在 Python 中编写循环解决方案,我可以达到约 4.6,但我无法击败上述任何一种方法。当然,使用 C 函数可能会更好。 - Grismar
1
我猜我们也可以只使用 set.differencet = b - (i:= a - (h:= a - b))。当交集很大时,这个似乎执行得更快。 - user7864386
3个回答

2

这段内容与Python无关,也没有在任何第三方库中看到这样的东西。

以下是一个或许出乎意料的方法,它对于集合大小的不同以及输入中可能存在的重叠程度几乎是不敏感的。我梦想着解决一个相关问题时想出了这个方法:假设你有3个输入集合,并希望推导出7个有趣的重叠集合(仅在A中、仅在B中、仅在C中、既在A和B中、既在A和C中、既在B和C中、还是在所有3个中)。这个版本将其简化为2个输入的情况。通常,为每个输入分配一个唯一的2的次幂,并将其用作位标志:

def hit(a, b):
    x2flags = defaultdict(int)
    for x in a:
        x2flags[x] = 1
    for x in b:
        x2flags[x] |= 2
    result = [None, set(), set(), set()]
    for x, flag in x2flags.items():
        result[flag].add(x)
    return result[1], result[3], result[2]

2

除非没有人能够超越我的解决方案或任何优秀而简洁的Python解决方案,否则我不会接受自己的答案。

但对于任何有兴趣了解一些数字的人:

from random import randint
from timeit import timeit


def grismar(a: set, b: set):
    h, i, t = set(), set(), b.copy()
    for x in a:
        if x in t:
            i.add(x)
            t.remove(x)
        else:
            h.add(x)
    return h, i, t


def good(a: set, b: set):
    return a - b, a & b, b - a


def better(a: set, b: set):
    h, t = a - (i := a & b), b - i
    return h, i, t


def ok(a: set, b: set):
    return a - (a & b), a & b, b - (a & b)


from collections import defaultdict
def tim(a, b):
    x2flags = defaultdict(int)
    for x in a:
        x2flags[x] = 1
    for x in b:
        x2flags[x] |= 2
    result = [None, set(), set(), set()]
    for x, flag in x2flags.items():
        result[flag].add(x)
    return result[1], result[3], result[2]


def pychopath(a, b):
    h, t = set(), b.copy()
    h_add = h.add
    t_remove = t.remove
    i = {x for x in a
         if x in t and not t_remove(x) or h_add(x)}
    return h, i, t


def enke(a, b):
    t = b - (i := a - (h := a - b))
    return h, i, t


xs = set(randint(0, 10000) for _ in range(10000))
ys = set(randint(0, 10000) for _ in range(10000))

# validation
g = (f(xs, ys) for f in (grismar, good, better, ok, tim, enke))
l = set(tuple(tuple(sorted(s)) for s in t) for t in g)
assert len(l) == 1, 'functions are equivalent'

# warmup, not competing
timeit(lambda: grismar(xs, ys), number=500)

# competition
print('a - b, a & b, b - a ', timeit(lambda: good(xs, ys), number=10000))
print('a - (i := a & b), b - i ', timeit(lambda: better(xs, ys), number=10000))
print('a - (a & b), a & b, b - (a & b) ', timeit(lambda: ok(xs, ys), number=10000))
print('tim ', timeit(lambda: tim(xs, ys), number=10000))
print('grismar ', timeit(lambda: grismar(xs, ys), number=10000))
print('pychopath ', timeit(lambda: pychopath(xs, ys), number=10000))
print('b - (i := a - (h := a - b)) ', timeit(lambda: enke(xs, ys), number=10000))

结果:

a - b, a & b, b - a  5.6963334
a - (i := a & b), b - i  5.3934624
a - (a & b), a & b, b - (a & b)  9.7732018
tim  16.3080373
grismar  7.709292500000004
pychopath  6.76331460000074
b - (i := a - (h := a - b))  5.197220600000001

到目前为止,@enke在评论中提出的优化方案似乎是最优的:
t = b - (i := a - (h := a - b))
return h, i, t

编辑:添加了@Pychopath的结果,它确实比我的结果快得多,尽管@enke的结果仍然是最优秀的(也很可能不只是用Python)。如果@enke发布他们自己的答案,我会很乐意接受它作为答案。


1

你的优化版本,在你的基准测试中似乎比你的版本快了约20%:

def hit(a, b):
    h, t = set(), b.copy()
    h_add = h.add
    t_remove = t.remove
    i = {x for x in a
         if x in t and not t_remove(x) or h_add(x)}
    return h, i, t

如果两个集合的大小差异较大,你可能希望在开始时执行此操作:

    if len(a) > len(b):
        return hit(b, a)[::-1]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接