Python中的不相交集合实现

12

我对Python比较陌生。我正在学习并实现“不相交集合”:

class DisjointSet:
    def __init__(self, vertices, parent):
        self.vertices = vertices
        self.parent = parent

    def find(self, item):
        if self.parent[item] == item:
            return item
        else:
            return self.find(self.parent[item])

    def union(self, set1, set2):
        self.parent[set1] = set2

现在在驱动程序代码中:

def main():
    vertices = ['a', 'b', 'c', 'd', 'e', 'h', 'i']
    parent = {}

    for v in vertices:
        parent[v] = v

    ds = DisjointSet(vertices, parent)
    print("Print all vertices in genesis: ")
    ds.union('b', 'd')

    ds.union('h', 'b')
    print(ds.find('h')) # prints d (OK)
    ds.union('h', 'i')
    print(ds.find('i')) # prints i (expecting d)

main()

所以,一开始我将所有节点初始化为独立的不相交集。然后合并了bdhb,得到了集合:hbd,接着合并了hi,应该会(像我假设的那样)得到集合:ihbd。我理解由于在这行代码union(set1, set2)中设置了父节点:

self.parent[set1] = set2

我将h的父节点设置为i,从而将其从bd的集合中移除。如何实现一个ihbd的集合,在union()的参数顺序不同情况下也能产生相同的结果呢?


你不应该在构造函数中使用 parent 参数,因为调用者没有任何选择可以指定。相反,你应该在 init 中填充它,而不是在 main() 函数中。 - Matt Timmermans
1
这是一个 Py 实现:https://www.nayuki.io/res/disjoint-set-data-structure/disjointset.py。另外一个:https://github.com/mrapacz/disjoint-set。 - Abhijit Sarkar
2个回答

12
你的程序不能正常工作,因为你误解了不相交集合实现的算法。合并是通过修改节点的父节点而不是提供作为输入的节点来实现的。正如你已经注意到的那样,盲目地修改任何输入节点的父节点将破坏之前的合并。下面是正确的实现:
def union(self, set1, set2):
    root1 = self.find(set1)
    root2 = self.find(set2)
    self.parent[root1] = root2

我还建议阅读 不相交集合数据结构 以获取更多信息和可能的优化。


1
为了使您的实现更快,您可能希望在查找时更新父级。
    def find(self, item):
        if self.parent[item] == item:
            return item
        else:
            res = self.find(self.parent[item])
            self.parent[item] = res
            return res

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接