Python:基于交集的简单列表合并

54
考虑存在以下整数列表:

Consider there are some lists of integers as:

#--------------------------------------
0 [0,1,3]
1 [1,0,3,4,5,10,...]
2 [2,8]
3 [3,1,0,...]
...
n []
#--------------------------------------

这个问题是要合并至少有一个共同元素的列表。因此,仅针对给定部分的结果如下:

#--------------------------------------
0 [0,1,3,4,5,10,...]
2 [2,8]
#--------------------------------------

在大数据集上,如何最有效地实现此操作(元素仅为数字)? 树形结构是否值得考虑? 目前我通过将列表转换为set并迭代交集来完成工作,但速度较慢!此外,我觉得这是非常基础的!另外,我的实现缺少某些东西(未知),因为有时一些列表仍然未合并!话虽如此,如果您建议自己实现,请慷慨提供一个简单的示例代码[显然,Python 是我最喜欢的:)]或伪代码。
更新 1: 以下是我使用的代码:

#--------------------------------------
lsts = [[0,1,3],
        [1,0,3,4,5,10,11],
        [2,8],
        [3,1,0,16]];
#--------------------------------------

该函数有(漏洞!!):

#--------------------------------------
def merge(lsts):
    sts = [set(l) for l in lsts]
    i = 0
    while i < len(sts):
        j = i+1
        while j < len(sts):
            if len(sts[i].intersection(sts[j])) > 0:
                sts[i] = sts[i].union(sts[j])
                sts.pop(j)
            else: j += 1                        #---corrected
        i += 1
    lst = [list(s) for s in sts]
    return lst
#--------------------------------------
结果是:
#--------------------------------------
>>> merge(lsts)
>>> [0, 1, 3, 4, 5, 10, 11, 16], [8, 2]]
#--------------------------------------

更新2: 根据Niklas Baumstark提供的代码,对于简单情况来说,速度要稍快一些。还没有测试“Hooked”提供的方法,因为它是完全不同的方法(顺便说一句,它看起来很有趣)。 对于所有这些方法的测试程序可能非常困难或不可能确保结果。我将使用的真实数据集非常庞大且复杂,因此仅通过重复无法追踪任何错误。也就是说,在将其作为模块放入大型代码之前,我需要100%满意该方法的可靠性。因此,现在对于简单的数据集而言,Niklas的方法更快且答案当然是正确的。
但是我如何确信它对于真正的大型数据集有效呢? 因为我将无法通过视觉追踪错误!

更新3: 请注意,对于此问题,方法的可靠性比速度更重要。最终我希望能够将Python代码转换为Fortran以获得最大的性能。

更新4:
这篇文章和慷慨给出的答案、建设性的评论都有很多有趣的观点。我建议仔细阅读所有内容。感谢提出问题、给出惊人的答案、建设性的评论和讨论。


“大数据”是指许多列表或非常长的列表吗?也许聪明的多线程可以为您节省一些时间。 - Rik Poggi
@RikPoggi 几乎都是:许多列表,每个列表都可能很长。 - Developer
代码示例似乎不完整,特别是while语句。 - Janne Karila
@NiklasBaumstark,正如我之前所写,它从一开始就是完整的,但SO呈现中存在奇怪的困难。如果您尝试查看页面源代码,则可以进行检查。无论如何,我已将所有内容替换为HTML版本,现在对所有人都应该没问题了。 - Developer
听起来像是一种算法,用于给定图的各种块(例如每个节点及其边缘),您需要将连接的子图排序。 - Chris Morgan
显示剩余16条评论
19个回答

30
我的尝试:
def merge(lsts):
    sets = [set(lst) for lst in lsts if lst]
    merged = True
    while merged:
        merged = False
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = True
                    common |= x
            results.append(common)
        sets = results
    return sets

lst = [[65, 17, 5, 30, 79, 56, 48, 62],
       [6, 97, 32, 93, 55, 14, 70, 32],
       [75, 37, 83, 34, 9, 19, 14, 64],
       [43, 71],
       [],
       [89, 49, 1, 30, 28, 3, 63],
       [35, 21, 68, 94, 57, 94, 9, 3],
       [16],
       [29, 9, 97, 43],
       [17, 63, 24]]
print merge(lst)

基准测试:

import random

# adapt parameters to your own usage scenario
class_count = 50
class_size = 1000
list_count_per_class = 100
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.5

if False:  # change to true to generate the test data file (takes a while)
    with open("/tmp/test.txt", "w") as f:
        lists = []
        classes = [
            range(class_size * i, class_size * (i + 1)) for i in range(class_count)
        ]
        for c in classes:
            # distribute each class across ~300 lists
            for i in xrange(list_count_per_class):
                lst = []
                if random.random() < large_list_probability:
                    size = random.choice(large_list_sizes)
                else:
                    size = random.choice(small_list_sizes)
                nums = set(c)
                for j in xrange(size):
                    x = random.choice(list(nums))
                    lst.append(x)
                    nums.remove(x)
                random.shuffle(lst)
                lists.append(lst)
        random.shuffle(lists)
        for lst in lists:
            f.write(" ".join(str(x) for x in lst) + "\n")

setup = """
# Niklas'
def merge_niklas(lsts):
    sets = [set(lst) for lst in lsts if lst]
    merged = 1
    while merged:
        merged = 0
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = 1
                    common |= x
            results.append(common)
        sets = results
    return sets

# Rik's
def merge_rik(data):
    sets = (set(e) for e in data if e)
    results = [next(sets)]
    for e_set in sets:
        to_update = []
        for i, res in enumerate(results):
            if not e_set.isdisjoint(res):
                to_update.insert(0, i)

        if not to_update:
            results.append(e_set)
        else:
            last = results[to_update.pop(-1)]
            for i in to_update:
                last |= results[i]
                del results[i]
            last |= e_set
    return results

# katrielalex's
def pairs(lst):
    i = iter(lst)
    first = prev = item = i.next()
    for item in i:
        yield prev, item
        prev = item
    yield item, first

import networkx

def merge_katrielalex(lsts):
    g = networkx.Graph()
    for lst in lsts:
        for edge in pairs(lst):
            g.add_edge(*edge)
    return networkx.connected_components(g)

# agf's (optimized)
from collections import deque

def merge_agf_optimized(lists):
    sets = deque(set(lst) for lst in lists if lst)
    results = []
    disjoint = 0
    current = sets.pop()
    while True:
        merged = False
        newsets = deque()
        for _ in xrange(disjoint, len(sets)):
            this = sets.pop()
            if not current.isdisjoint(this):
                current.update(this)
                merged = True
                disjoint = 0
            else:
                newsets.append(this)
                disjoint += 1
        if sets:
            newsets.extendleft(sets)
        if not merged:
            results.append(current)
            try:
                current = newsets.pop()
            except IndexError:
                break
            disjoint = 0
        sets = newsets
    return results

# agf's (simple)
def merge_agf_simple(lists):
    newsets, sets = [set(lst) for lst in lists if lst], []
    while len(sets) != len(newsets):
        sets, newsets = newsets, []
        for aset in sets:
            for eachset in newsets:
                if not aset.isdisjoint(eachset):
                    eachset.update(aset)
                    break
            else:
                newsets.append(aset)
    return newsets

# alexis'
def merge_alexis(data):
    bins = range(len(data))  # Initialize each bin[n] == n
    nums = dict()

    data = [set(m) for m in data]  # Convert to sets
    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue  # already in the same bin

                if dest > r:
                    dest, r = r, dest  # always merge into the smallest bin

                data[dest].update(data[r])
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest

    # Filter out the empty bins
    have = [m for m in data if m]
    return have

def locatebin(bins, n):
    while bins[n] != n:
        n = bins[n]
    return n

lsts = []
size = 0
num = 0
max = 0
for line in open("/tmp/test.txt", "r"):
    lst = [int(x) for x in line.split()]
    size += len(lst)
    if len(lst) > max:
        max = len(lst)
    num += 1
    lsts.append(lst)
"""

setup += """
print "%i lists, {class_count} equally distributed classes, average size %i, max size %i" % (num, size/num, max)
""".format(class_count=class_count)

import timeit
print "niklas"
print timeit.timeit("merge_niklas(lsts)", setup=setup, number=3)
print "rik"
print timeit.timeit("merge_rik(lsts)", setup=setup, number=3)
print "katrielalex"
print timeit.timeit("merge_katrielalex(lsts)", setup=setup, number=3)
print "agf (1)"
print timeit.timeit("merge_agf_optimized(lsts)", setup=setup, number=3)
print "agf (2)"
print timeit.timeit("merge_agf_simple(lsts)", setup=setup, number=3)
print "alexis"
print timeit.timeit("merge_alexis(lsts)", setup=setup, number=3)

这些时间显然取决于基准测试的具体参数,例如类数、列表数、列表大小等。根据您的需求调整这些参数以获得更有用的结果。

以下是在我的机器上针对不同参数的一些示例输出。它们表明所有算法都有其优点和缺点,取决于它们接收的输入类型:

=====================
# many disjoint classes, large lists
class_count = 50
class_size = 1000
list_count_per_class = 100
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.5
=====================

niklas
5000 lists, 50 equally distributed classes, average size 298, max size 999
4.80084705353
rik
5000 lists, 50 equally distributed classes, average size 298, max size 999
9.49251699448
katrielalex
5000 lists, 50 equally distributed classes, average size 298, max size 999
21.5317108631
agf (1)
5000 lists, 50 equally distributed classes, average size 298, max size 999
8.61671280861
agf (2)
5000 lists, 50 equally distributed classes, average size 298, max size 999
5.18117713928
=> alexis
=> 5000 lists, 50 equally distributed classes, average size 298, max size 999
=> 3.73504281044

===================
# less number of classes, large lists
class_count = 15
class_size = 1000
list_count_per_class = 300
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.5
===================

niklas
4500 lists, 15 equally distributed classes, average size 296, max size 999
1.79993700981
rik
4500 lists, 15 equally distributed classes, average size 296, max size 999
2.58237695694
katrielalex
4500 lists, 15 equally distributed classes, average size 296, max size 999
19.5465381145
agf (1)
4500 lists, 15 equally distributed classes, average size 296, max size 999
2.75445604324
=> agf (2)
=> 4500 lists, 15 equally distributed classes, average size 296, max size 999
=> 1.77850699425
alexis
4500 lists, 15 equally distributed classes, average size 296, max size 999
3.23530197144

===================
# less number of classes, smaller lists
class_count = 15
class_size = 1000
list_count_per_class = 300
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.1
===================

niklas
4500 lists, 15 equally distributed classes, average size 95, max size 997
0.773697137833
rik
4500 lists, 15 equally distributed classes, average size 95, max size 997
1.0523750782
katrielalex
4500 lists, 15 equally distributed classes, average size 95, max size 997
6.04466891289
agf (1)
4500 lists, 15 equally distributed classes, average size 95, max size 997
1.20285701752
=> agf (2)
=> 4500 lists, 15 equally distributed classes, average size 95, max size 997
=> 0.714507102966
alexis
4500 lists, 15 equally distributed classes, average size 95, max size 997
1.1286110878

1
你可以使用 not x.isdisjoint(common) 代替 x & common 来避免构建完整的交集。 - Janne Karila
lst = [[65, 17, 5, 30, 79, 56, 48, 62], [6, 97, 32, 93, 55, 14, 70, 32], [75, 37, 83, 34, 9, 19, 14, 64], [43, 71], [], [89, 49, 1, 30, 28, 3, 63], [35, 21, 68, 94, 57, 94, 9, 3], [16], [29, 9, 97, 43], [17, 63, 24]]结果[set([1, 3, 5, **9**, 14, 16, 17, 19, 21, 24, 28, 29, 30, 32, 34, 35, 37, 43, 48, 49, 55, 56, 57, 62, 63, 64, 65, 68, 70, 75, 79, 83, 89, 93, 94, 97]), set(), set([43, 71]), set(), set([9])]是不正确的。 - Developer
如果您尝试给定的列表,您会发现数字 9 存在于两个输出集合中。因此,代码存在问题,正如问题中最初提到的那样,并非可靠! - Developer
@Nik:我改进了我的代码,在你的时间测试下表现得更好了,你能否更新一下结果? - Rik Poggi
2
这个问题非常有趣,不管怎样 :-) 感谢提供测试环境。 - alexis
显示剩余13条评论

16

我试图在这个问题和重复的问题中总结关于这个主题所说和做的一切。

我尝试测试并计时每个解决方案(所有代码在此处)。

测试

这是测试模块中的TestCase

class MergeTestCase(unittest.TestCase):

    def setUp(self):
        with open('./lists/test_list.txt') as f:
            self.lsts = json.loads(f.read())
        self.merged = self.merge_func(deepcopy(self.lsts))

    def test_disjoint(self):
        """Check disjoint-ness of merged results"""
        from itertools import combinations
        for a,b in combinations(self.merged, 2):
            self.assertTrue(a.isdisjoint(b))

    def test_coverage(self):    # Credit to katrielalex
        """Check coverage original data"""
        merged_flat = set()
        for s in self.merged:
            merged_flat |= s

        original_flat = set()
        for lst in self.lsts:
            original_flat |= set(lst)

        self.assertTrue(merged_flat == original_flat)

    def test_subset(self):      # Credit to WolframH
        """Check that every original data is a subset"""
        for lst in self.lsts:
            self.assertTrue(any(set(lst) <= e for e in self.merged))

这个测试假设返回的是一组集合,因此我无法测试一些适用于列表的方案。

我无法测试以下内容:

katrielalex
steabert

我测试了其中的几个,有两个失败了:

  -- Going to test: agf (optimized) --
Check disjoint-ness of merged results ... FAIL

  -- Going to test: robert king --
Check disjoint-ness of merged results ... FAIL

计时

性能与所用的数据测试密切相关。

迄今为止,有三个答案对他们自己和其他解决方案进行了时间测试。由于他们使用了不同的测试数据,因此得出了不同的结果。

  1. Niklas基准测试非常易于调整。使用他的基准测试可以通过更改一些参数进行不同的测试。

    我使用了他在自己答案中使用的相同的三组参数,并将它们放入了三个不同的文件中:

    filename = './lists/timing_1.txt'
    class_count = 50,
    class_size = 1000,
    list_count_per_class = 100,
    large_list_sizes = (100, 1000),
    small_list_sizes = (0, 100),
    large_list_probability = 0.5,
    
    filename = './lists/timing_2.txt'
    class_count = 15,
    class_size = 1000,
    list_count_per_class = 300,
    large_list_sizes = (100, 1000),
    small_list_sizes = (0, 100),
    large_list_probability = 0.5,
    
    filename = './lists/timing_3.txt'
    class_count = 15,
    class_size = 1000,
    list_count_per_class = 300,
    large_list_sizes = (100, 1000),
    small_list_sizes = (0, 100),
    large_list_probability = 0.1,
    

    这是我得到的结果:

    来自文件:timing_1.txt

    Timing with: >> Niklas << Benchmark
    Info: 5000 lists, average size 305, max size 999
    
    Timing Results:
    10.434  -- alexis
    11.476  -- agf
    11.555  -- Niklas B.
    13.622  -- Rik. Poggi
    14.016  -- agf (optimized)
    14.057  -- ChessMaster
    20.208  -- katrielalex
    21.697  -- steabert
    25.101  -- robert king
    76.870  -- Sven Marnach
    133.399  -- hochl
    

    从文件:timing_2.txt

    Timing with: >> Niklas << Benchmark
    Info: 4500 lists, average size 305, max size 999
    
    Timing Results:
    8.247  -- Niklas B.
    8.286  -- agf
    8.637  -- Rik. Poggi
    8.967  -- alexis
    9.090  -- ChessMaster
    9.091  -- agf (optimized)
    18.186  -- katrielalex
    19.543  -- steabert
    22.852  -- robert king
    70.486  -- Sven Marnach
    104.405  -- hochl
    

    来自文件:timing_3.txt

    Timing with: >> Niklas << Benchmark
    Info: 4500 lists, average size 98, max size 999
    
    Timing Results:
    2.746  -- agf
    2.850  -- Niklas B.
    2.887  -- Rik. Poggi
    2.972  -- alexis
    3.077  -- ChessMaster
    3.174  -- agf (optimized)
    5.811  -- katrielalex
    7.208  -- robert king
    9.193  -- steabert
    23.536  -- Sven Marnach
    37.436  -- hochl
    
  2. 使用Sven的测试数据,我得到了以下结果:

  3. Timing with: >> Sven << Benchmark
    Info: 200 lists, average size 10, max size 10
    
    Timing Results:
    2.053  -- alexis
    2.199  -- ChessMaster
    2.410  -- agf (optimized)
    3.394  -- agf
    3.398  -- Rik. Poggi
    3.640  -- robert king
    3.719  -- steabert
    3.776  -- Niklas B.
    3.888  -- hochl
    4.610  -- Sven Marnach
    5.018  -- katrielalex
    
  4. 最终,通过Agf的基准测试,我得到了:

  5. Timing with: >> Agf << Benchmark
    Info: 2000 lists, average size 246, max size 500
    
    Timing Results:
    3.446  -- Rik. Poggi
    3.500  -- ChessMaster
    3.520  -- agf (optimized)
    3.527  -- Niklas B.
    3.527  -- agf
    3.902  -- hochl
    5.080  -- alexis
    15.997  -- steabert
    16.422  -- katrielalex
    18.317  -- robert king
    1257.152  -- Sven Marnach
    
    正如我一开始所说的,所有的代码都可以在这个git仓库中找到。所有的合并函数都在名为core.py的文件中,文件中每个以_merge结尾的函数将在测试期间自动加载,所以添加/测试/改进你自己的解决方案不应该很难。
    如果有什么不对的地方,请让我知道,编码已经进行了很多,我需要几双新鲜的眼睛 :)

我的 Niklas B. 答案重写怎么样?我只是想知道那个时间是否相关。 - ChessMaster
@ChessMaster:有时表现略好,有时则稍差,这就是我没有将它列入结果的原因。如果您感兴趣,可以在链接中自行尝试,其中有一个名为core.py的文件,包含了所有合并函数的git存储库。每个以_merge结尾的函数都会自动加载。我刚刚把你的代码推送上去了,所以你将在跳过模式下找到它。 :) - Rik Poggi
感谢您的出色努力。 - Developer
3
有时候我真的很惊讶,这个网站上回答的质量和知识量之高。把这个汇编整理出来做得非常好! - Niklas B.
尼克拉斯-比的答案在哪里?此页面上的14个答案中没有一个是尼克拉斯-比的... - tommy.carstensen

7

使用矩阵操作

在回答问题之前,我想说:

这种方法是错误的。它容易产生数值不稳定性,并且比其他方法慢得多,请自行决定是否使用。

话虽如此,我还是忍不住要从动态角度解决这个问题(希望您能从中获得新的视角)。理论上来说,这应该总是有效的,但特征值计算经常会失败。思路是将列表视为从行到列的。如果两行共享一个公共值,则它们之间存在连接流。如果我们将这些流看作水,我们会发现当它们之间存在连接路径时,这些流会聚集成小池塘。为了简单起见,我将使用一个较小的数据集,但它也适用于您的数据集:

from numpy import where, newaxis
from scipy import linalg, array, zeros

X = [[0,1,3],[2],[3,1]]

我们需要将数据转换成流程图。如果第i行流入值j,我们将其放入矩阵中。这里有3行和4个唯一值:
A = zeros((4,len(X)), dtype=float)
for i,row in enumerate(X):
    for val in row: A[val,i] = 1

通常情况下,您需要更改4来捕获您拥有的唯一值的数量。如果该集合是从0开始的整数列表(如我们所拥有的),则可以将其作为最大数字。现在我们执行特征值分解。确切地说,是SVD,因为我们的矩阵不是方阵。
S  = linalg.svd(A)

我们希望仅保留该答案的3x3部分,因为它将代表池流的流动。实际上,我们只需要这个矩阵的绝对值;我们只关心在此“集群”空间中是否存在流动。

M  = abs(S[2])

我们可以将矩阵M视为马尔可夫矩阵,并通过行归一化使其明确。一旦完成此操作,我们会计算该矩阵的(左)特征值分解。
M /=  M.sum(axis=1)[:,newaxis]
U,V = linalg.eig(M,left=True, right=False)
V = abs(V)

现在,一个不联通(非遍历)的马尔可夫矩阵具有良好的特性,对于每个非连接的集群,都存在一个特征值为1。与这些1值相关联的特征向量是我们想要的:

idx = where(U > .999)[0]
C = V.T[idx] > 0

由于前面提到的数值不稳定性,我必须使用.999。此时,我们已经完成了!每个独立的集群现在可以拉出相应的行:

for cluster in C:
    print where(A[:,cluster].sum(axis=1))[0]

这将按预期产生以下结果:
[0 1 3]
[2]

X更改为你的lst,你将得到:[ 0 1 3 4 5 10 11 16] [2 8]

补充

这有什么用?我不知道你的基础数据来自哪里,但是当连接不是绝对的时会发生什么?假设行180%的时间有条目3 - 你怎样推广这个问题呢?上述流方法也可以正常工作,而且完全由.999值参数化,离单位越远,关联性就越松散。


可视化表示

既然一张图片胜过千言万语,这里是矩阵A和我的示例以及您的lst的V的绘图。请注意,在V中分为两个簇(在置换后是一个块对角矩阵,有两个块),因为每个示例只有两个唯一的列表!

My Example Your sample data


更快的实现

事后看来,我意识到您可以跳过SVD步骤,仅计算单个分解:

M = dot(A.T,A)
M /=  M.sum(axis=1)[:,newaxis]
U,V = linalg.eig(M,left=True, right=False)

这种方法的优点(除了速度之外)在于,M现在是对称的,因此计算可以更快、更精确(无需担心虚数)。

你是如何生成这些非常好看的图像的? - Zach Young
1
@ZacharyYoung,请查看MatPlotLib以获取绘图信息,参见“Gallery”。 - Developer
1
@开发者:哈哈,太棒了!我看到过很多有关Matplotlib的SO问题,但从未去看过它。看来我刚刚找到了我新的最爱的SO标签 :) - Zach Young
1
在阶段A = zeros((4,len(X)), dtype=float),我们需要知道整个巨大列表中有多少个唯一值存在!对于给定的示例,它是“4”,那么对于许多未知的大型列表,如何在其中找到唯一值呢? - Developer
@开发人员,让我按顺序回答那些问题。@Zachary Young关于matplotlib的说法是正确的,它是一个非常快速和易于生成图像的库。你对需要知道独特值的确切数量也是正确的,这就是4的来源。在许多问题中,您可能已经知道这一点,甚至有一个上界。只要它更大,你就没问题了。您也不受整数范围的限制,因为您始终可以将值映射到整数的有序集合上。当我测试您的“idx”时,我将“4”更改为“16”,我将在编辑中更明确地说明。 - Hooked
显示剩余7条评论

6

以下是我的回答。我还没有与今天的答案批次进行核对。

基于交集的算法是O(N^2)的,因为它们会将每个新集合与所有现有集合进行比较,所以我采用了一种索引每个数字并且运行接近O(N)(如果我们接受字典查找是O(1)的话)的方法。然后我进行了基准测试,并感到自己像个完全的白痴,因为它运行得更慢了,但仔细检查后发现,测试数据最终只有少数不同的结果集,因此二次算法没有太多的工作要做。使用超过10-15个不同的箱子进行测试,我的算法会快得多。尝试使用超过50个不同的箱子进行测试,速度会非常快。

(编辑:基准测试的运行方式也存在问题,但我的诊断是错误的。我修改了我的代码以适应重复测试的方式)。

def mergelists5(data):
    """Check each number in our arrays only once, merging when we find
    a number we have seen before.
    """

    bins = range(len(data))  # Initialize each bin[n] == n
    nums = dict()

    data = [set(m) for m in data ]  # Convert to sets    
    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue # already in the same bin

                if dest > r:
                    dest, r = r, dest   # always merge into the smallest bin

                data[dest].update(data[r]) 
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest 

    # Filter out the empty bins
    have = [ m for m in data if m ]
    print len(have), "groups in result"
    return have


def locatebin(bins, n):
    """
    Find the bin where list n has ended up: Follow bin references until
    we find a bin that has not moved.
    """
    while bins[n] != n:
        n = bins[n]
    return n

此代码使用集合来解决timeit重复的问题,但如果您只是将列表附加到彼此,并且仅在构造have时丢弃重复项,则它同样有效(实际上稍微更快)。因此,它可能具有更像Fortran的优点(我从未想过以积极的方式这样说! :-) - alexis
有趣的分析,并且这也是最快的解决方案之一。我在各种情况下都测试过它,在我的汇总中使用了它 :) - Rik Poggi
谢谢。如果你没有数不清的冲突,它应该会更好。实际上,我删除了优化合并跟踪的结构,因为我测试的数据只有几千个冲突,看起来不值得增加复杂性。我想这完全取决于你的数据真正的样子。 - alexis
在 Python 3+ 中使用 range 函数的行为不同,可能无法洗牌。将 bins 定义为 list(range(len(data))) 可以解决此问题。有关详细信息,请参见此 stackoverflow 链接:https://dev59.com/SWIj5IYBdhLWcg3wOC2S - Liquidgenius

5

编辑:好的,其他问题已关闭,在这里发帖。

很好的问题!如果您将其视为图中的连接组件问题,则会简单得多。以下代码使用优秀的networkx图库和此问题中的pairs函数。

def pairs(lst):
    i = iter(lst)
    first = prev = item = i.next()
    for item in i:
        yield prev, item
        prev = item
    yield item, first

lists = [[1,2,3],[3,5,6],[8,9,10],[11,12,13]]

import networkx
g = networkx.Graph()
for sub_list in lists:
    for edge in pairs(sub_list):
            g.add_edge(*edge)

networkx.connected_components(g)
[[1, 2, 3, 5, 6], [8, 9, 10], [11, 12, 13]]

说明

我们创建一个新的(空)图形g。对于lists中的每个子列表,将其元素视为图形的节点,并在它们之间添加一条边。(由于我们只关心连通性,所以不需要添加所有的边——只需相邻的即可!)注意,add_edge接受两个对象,将它们视为节点(如果它们还不存在,则添加它们),并在它们之间添加一条边。

然后,我们只需要找到图的连通分量——这是一个已解决的问题!——并将它们输出作为我们的交集集合。


2
其实我一直以为这个问题早已有了解决方案,所以我不认为重新激活这个旧的线程有太多意义。然而,由于我也曾考虑过使用图库来解决这个问题,所以我将这个解决方案集成到我的基准测试中。不幸的是,它的表现似乎并不太好,看起来 Python 程序员在实现和优化集合方面做得很好 :) - Niklas B.
感谢您采用不同的方法。将networkx应用于此目的确实非常好。这是一个非常优秀的软件包。 - Developer
@NiklasB. 谢谢 =) 当我看到这个问题时,似乎没有人真正写出了一个可行的解决方案,但我想我可能读错了。 - Katriel
我喜欢networkx。顺便说一句...如果对于每个列表,您只是在列表中的第一个元素和所有其他元素之间添加了一条边,那么这可能会更简单。这可能只需要大约4行代码。 - Rusty Rob
@robertking 可以随意编辑 =)... 我只是写下了脑海中的想法! - Katriel
@katrielalex 嘿,我添加了自己的答案。希望它应该更快,因为集合中的每个节点仅连接到集合中的一个节点,所以最长路径为2,因此计算连通性不需要太多操作。 - Rusty Rob

4

这个新的函数只进行了最少必要数量的不相交测试,这是其他类似解决方案无法做到的。它还使用了一个deque尽可能避免一些线性时间操作,比如列表切片和从列表早期删除。

from collections import deque

def merge(lists):
    sets = deque(set(lst) for lst in lists if lst)
    results = []
    disjoint = 0
    current = sets.pop()
    while True:
        merged = False
        newsets = deque()
        for _ in xrange(disjoint, len(sets)):
            this = sets.pop()
            if not current.isdisjoint(this):
                current.update(this)
                merged = True
                disjoint = 0
            else:
                newsets.append(this)
                disjoint += 1
        if sets:
            newsets.extendleft(sets)
        if not merged:
            results.append(current)
            try:
                current = newsets.pop()
            except IndexError:
                break
            disjoint = 0
        sets = newsets
    return results

在给定的数据集中,集合之间的重叠越少,与其他函数相比,它的表现就会更好。
以下是一个示例案例。如果你有4个集合,需要进行比较:
1,2 1,3 1,4 2,3 2,4 3,4
如果1与3重叠,则需要重新测试2是否与1重叠,以便安全地跳过2对3的测试。
有两种方法可以解决这个问题。第一种方法是在每次重叠和合并后重新启动对集合1与其他集合的测试。第二种方法是继续进行测试,将1与4进行比较,然后返回并重新测试。后者会产生更少的不相交测试,因为更多的合并发生在单个传递中,所以在重新测试时,剩下要测试的集合更少。
问题是如何跟踪哪些集合必须重新测试。在上面的示例中,1需要重新测试与2,但不需要重新测试与4,因为1已经处于当前状态,而在4第一次测试时就已经测试过了。
“disjoint”计数器允许跟踪这一点。
我的答案并没有帮助找到改进的算法来将代码转换为FORTRAN;它只是我认为在Python中实现该算法最简单、最优雅的方式。
根据我的测试(或接受的答案中的测试),它比下一个最快的解决方案稍微快一点(高达10%)。
def merge0(lists):
    newsets, sets = [set(lst) for lst in lists if lst], []
    while len(sets) != len(newsets):
        sets, newsets = newsets, []
        for aset in sets:
            for eachset in newsets:
                if not aset.isdisjoint(eachset):
                    eachset.update(aset)
                    break
            else:
                newsets.append(aset)
    return newsets

不需要使用非Pythonic计数器(i,range)或复杂的变异(del,pop,insert),其他实现中使用了这些。它仅使用简单的迭代,在最简单的方式下合并重叠的集合,并在每次通过数据时构建一个新的列表。
我(更快更简单)版本的测试代码:
import random

tenk = range(10000)
lsts = [random.sample(tenk, random.randint(0, 500)) for _ in range(2000)]

setup = """
def merge0(lists):
  newsets, sets = [set(lst) for lst in lists if lst], []
  while len(sets) != len(newsets):
    sets, newsets = newsets, []
    for aset in sets:
      for eachset in newsets:
        if not aset.isdisjoint(eachset):
          eachset.update(aset)
          break
      else:
        newsets.append(aset)
  return newsets

def merge1(lsts):
  sets = [set(lst) for lst in lsts if lst]
  merged = 1
  while merged:
    merged = 0
    results = []
    while sets:
      common, rest = sets[0], sets[1:]
      sets = []
      for x in rest:
        if x.isdisjoint(common):
          sets.append(x)
        else:
          merged = 1
          common |= x
      results.append(common)
    sets = results
  return sets

lsts = """ + repr(lsts)

import timeit
print timeit.timeit("merge0(lsts)", setup=setup, number=10)
print timeit.timeit("merge1(lsts)", setup=setup, number=10)

非常感谢您指出我的错误,我认为我的代码现在已经没问题了。 - ChessMaster
popinsert有什么问题吗?这是将列表视为堆栈的明显方法,而del只是从列表中删除元素。此外,下一个最快的解决方案恰好是我的,所以你应该对其进行时间测试。通过你的测试速度,我发现我的代码仍然更快,尼克拉斯的测试有时会比你的快,有时则相反,但差异似乎非常小,没有本地声明无法遮盖的。先生,您无需吹嘘,如果您对我的代码有疑问,您可以在那里提问,而不是在这里嘲笑它。 - Rik Poggi
你真的看了我的代码吗?还是你只是随意地抛出大O符号?我不是瞎猜,我已经计时和分析了我的代码。而且根据你自己的测试结果,我的代码比你的更快。有一个insert(0,i)的操作,其开销与append(i)完全相同(我使用insert来避免之后反转列表)。del并不会花费太多时间,无论如何,这就是我的代码如何工作的:它会尽可能少地检查,在保持结果表更新的同时做到最小化消耗。 - Rik Poggi
基准测试代码有缺陷:timeit() 不会为每个迭代重新运行设置代码,这些合并程序使用列表具有破坏性。将调用更改为 print timeit.timeit("merge0(deepcopy(lsts))", setup=setup, number=10),您的时间将增加 10 倍。(在设置中添加 from copy import deepcopy)。 - alexis
@alexis 抱歉,它们不会。列表推导式创建一个新的集合列表,而不会改变原始的列表。 - agf
显示剩余3条评论

3
这里有一个使用不相交集合数据结构(具体来说是不相交森林)实现的示例,感谢comingstorm合并有共同元素的集合问题中提供的提示。我使用路径压缩进行了轻微 (~5%) 的速度提升;虽然它并非完全必要(而且它会防止 find 是尾递归,这可能会使事情变慢),但注意我使用了 dict 来表示 不相交森林;由于数据是 int,因此数组也可以工作,尽管速度可能不会更快。
def merge(data):
    parents = {}
    def find(i):
        j = parents.get(i, i)
        if j == i:
            return i
        k = find(j)
        if k != j:
            parents[i] = k
        return k
    for l in filter(None, data):
        parents.update(dict.fromkeys(map(find, l), find(l[0])))
    merged = {}
    for k, v in parents.items():
        merged.setdefault(find(v), []).append(k)
    return merged.values()

这种方法在Rik的基准测试中与其他最佳算法相当。

就我个人而言,我似乎发现相反的情况,这需要大约3-4倍的时间。猜测(如前所述),这在很大程度上取决于你正在测试的数据集。 - DSM

2
这将是我的更新方法:
def merge(data):
    sets = (set(e) for e in data if e)
    results = [next(sets)]
    for e_set in sets:
        to_update = []
        for i,res in enumerate(results):
            if not e_set.isdisjoint(res):
                to_update.insert(0,i)

        if not to_update:
            results.append(e_set)
        else:
            last = results[to_update.pop(-1)]
            for i in to_update:
                last |= results[i]
                del results[i]
            last |= e_set

    return results

注意:在合并过程中,空列表将被删除。

更新:可靠性。

为了100%的成功可靠性,您需要进行两个测试:

  • Check that all the resulting sets are mutually disjointed:

    merged = [{0, 1, 3, 4, 5, 10, 11, 16}, {8, 2}, {8}]
    
    from itertools import combinations
    for a,b in combinations(merged,2):
        if not a.isdisjoint(b):
            raise Exception(a,b)    # just an example
    
  • Check that the merged set cover the original data. (as suggested by katrielalex)

我认为这可能需要一些时间,但如果您想要100%确定,那么这可能会很值得。


lst = [[65, 17, 5, 30, 79, 56, 48, 62], [6, 97, 32, 93, 55, 14, 70, 32], [75, 37, 83, 34, 9, 19, 14, 64], [43, 71], [], [89, 49, 1, 30, 28, 3, 63], [35, 21, 68, 94, 57, 94, 9, 3], [16], [29, 9, 97, 43], [17, 63, 24]]代码输出结果为 [set([65, 96, 37, 70, 72, 75]), set([32, 33, 34, 36, 6, 9, 10, 14, 69, 71, 75, 82, 83, 85]), set([64, 19]), set(), set([16])] 是错误的。 - Developer
@开发者:我明白了,那是因为有一个列表,其中有两个不同的数字,每个数字都与两个不相交的集合中的一个数字相同。我会看一下的。 - Rik Poggi
@Rik:我无法重现你的时间。即使使用我的修复版本,差异也仅约为10%(我已将基准测试添加到我的答案中)。请添加测试代码,否则这没有太大用处。 - Niklas B.
@NiklasBaumstark:我修复了我的代码并发布了我的计时代码,如果有问题请让我知道。似乎有时我们的解决方案不同,因为你的代码没有添加空集,但我不确定,也许是我的错。 - Rik Poggi
@开发人员:更新答案,现在应该可以工作了。如果您进行其他测试,请告诉我。 - Rik Poggi
显示剩余11条评论

2

仅为娱乐...

def merge(mylists):
    results, sets = [], [set(lst) for lst in mylists if lst]
    upd, isd, pop = set.update, set.isdisjoint, sets.pop
    while sets:
        if not [upd(sets[0],pop(i)) for i in xrange(len(sets)-1,0,-1) if not isd(sets[0],sets[i])]:
            results.append(pop(0))
    return results

以及我对最佳答案的改写

def merge(lsts):
  sets = map(set,lsts)
  results = []
  while sets:
    first, rest = sets[0], sets[1:]
    merged = False
    sets = []
    for s in rest:
      if s and s.isdisjoint(first):
        sets.append(s)
      else:
        first |= s
        merged = True
    if merged: sets.append(first)
    else: results.append(first)
  return results

你的算法不正确。你只更新了第一个集合。尝试使用 merge([[1], [2, 3], [3, 4]]) - agf
@agf 算法我认为是可以的,但我尝试使用 update = sets[0].updateisdisjoint = sets[0].isdisjoint 来缩小代码,但在这种情况下效果不佳,谢谢。 - ChessMaster
没错,这修复了它。但是在我的测试中,你添加的那个表现非常糟糕。 - agf

0
以下是一个函数(Python 3.1),用于检查合并函数的结果是否正确。它会检查以下内容:
  • 结果集是否不相交?(并集元素数量等于各自元素数量之和)
  • 结果集中的元素是否与输入列表中的元素相同?
  • 每个输入列表是否都是某个结果集的子集?
  • 每个结果集是否都是最小的,即无法将其分成两个更小的集合?
  • 它不会检查是否存在空的结果集 - 我不知道您是否需要它们...

.

from itertools import chain

def check(lsts, result):
    lsts = [set(s) for s in lsts]
    all_items = set(chain(*lsts))
    all_result_items = set(chain(*result))
    num_result_items = sum(len(s) for s in result)
    if num_result_items != len(all_result_items):
        print("Error: result sets overlap!")
        print(num_result_items, len(all_result_items))
        print(sorted(map(len, result)), sorted(map(len, lsts)))
    if all_items != all_result_items:
        print("Error: result doesn't match input lists!")
    if not all(any(set(s).issubset(t) for t in result) for s in lst):
        print("Error: not all input lists are contained in a result set!")

    seen = set()
    todo = list(filter(bool, lsts))
    done = False
    while not done:
        deletes = []
        for i, s in enumerate(todo): # intersection with seen, or with unseen result set, is OK
            if not s.isdisjoint(seen) or any(t.isdisjoint(seen) for t in result if not s.isdisjoint(t)):
                seen.update(s)
                deletes.append(i)
        for i in reversed(deletes):
            del todo[i]
        done = not deletes
    if todo:
        print("Error: A result set should be split into two or more parts!")
        print(todo)

如果你能用单元测试语言编写这个,那就太棒了 =) - Katriel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接