使用自定义比较谓词的heapq

143

我正在尝试使用自定义排序谓词构建堆。由于输入堆中的值是“用户定义”类型,我无法修改它们内置的比较谓词。

是否有办法做到这样:

h = heapq.heapify([...], key=my_lt_pred)
h = heapq.heappush(h, key=my_lt_pred)

或者更好的做法是,我可以将heapq函数包装在自己的容器中,这样我就不需要不断传递谓词了。


1
可能是 https://dev59.com/SnRB5IYBdhLWcg3wQFPu 的重复问题。 - Rajendran T
可能是重复的问题:如何使heapq根据特定属性评估堆? - Cristian Ciupitu
8个回答

165
根据heapq文档,自定义堆排序的方法是让堆中的每个元素成为一个元组,其中第一个元组元素接受普通的Python比较。由于heapq模块的函数不是面向对象的,因此有些麻烦,始终需要显式地将堆对象(堆化的列表)作为第一个参数传递。我们可以通过创建一个非常简单的包装器类来两全其美,使我们能够指定key函数,并将堆表示为对象。下面的类保留了一个内部列表,其中每个元素都是一个元组,其中第一个成员是一个关键字,插入时使用传递给Heap实例化的key参数计算出来:
# -*- coding: utf-8 -*-
import heapq

class MyHeap(object):
    def __init__(self, initial=None, key=lambda x:x):
        self.key = key
        self.index = 0
        if initial:
            self._data = [(key(item), i, item) for i, item in enumerate(initial)]
            self.index = len(self._data)
            heapq.heapify(self._data)
        else:
            self._data = []

    def push(self, item):
        heapq.heappush(self._data, (self.key(item), self.index, item))
        self.index += 1

    def pop(self):
        return heapq.heappop(self._data)[2]

(加上额外的self.index部分是为了避免在计算键值时出现冲突,因为存储的值可能不是直接可比的 - 否则,heapq会出现TypeError错误)


5
非常好!您甚至可以进一步使用三元组(self.key(item),id,item),其中id可以作为类属性处理的整数,并在每次推入后递增。这样,您可以避免由key(item1)= key(item2)引发的异常。因为键将是唯一的。 - zeycus
7
我曾试图将这个(或基于此的)内容推荐到Python标准库中,但建议被拒绝了。 - jsbueno
1
遗憾的是,它符合大多数Python特性的面向对象风格,并且关键参数提供了额外的灵活性。 - zeycus
@jsbueno 抱歉问一个愚蠢的问题,但我是Python新手,看了你的例子,我不知道如何定义自定义谓词。难道不应该有一些函数来定义严格弱于排序吗?我假设self.key(item)用于谓词,这是您可以定义自定义行为的地方,但在您的示例中,您只返回了原始值,对吗?但是,当您需要自定义排序时,例如如果您想基于一对值进行排序,则需要比较两个对并返回<的true/false。 - ssj_100
8
如果元素不可比较且键值存在并列情况,此方法将失败。我会在元组中加入id(item)作为中间元素来打破并列。 - Georgi Yanchev
显示剩余4条评论

150

定义一个类,在其中覆盖 __lt__() 函数。以下是一个 Python 3.7 的示例:

import heapq

class Node(object):
    def __init__(self, val: int):
        self.val = val

    def __repr__(self):
        return f'Node value: {self.val}'

    def __lt__(self, other):
        return self.val < other.val

heap = [Node(2), Node(0), Node(1), Node(4), Node(2)]
heapq.heapify(heap)
print(heap)  # output: [Node value: 0, Node value: 2, Node value: 1, Node value: 4, Node value: 2]

heapq.heappop(heap)
print(heap)  # output: [Node value: 1, Node value: 2, Node value: 2, Node value: 4]


21
这似乎是到目前为止最清晰的解决方案! - Foobar
1
完全同意前两个评论。这似乎是Python 3的更好、更清洁的解决方案。 - Chiraz BenAbdelkader
此外,这里有一个非常相似的问题的解决方案:https://dev59.com/LHE95IYBdhLWcg3wApHk#40455775?noredirect=1#comment110575517_40455775 - Chiraz BenAbdelkader
2
我使用__gt__进行了测试,同样有效。为什么我们使用哪个魔术方法并不重要呢?我在heapq的文档中找不到任何相关信息。也许这与Python通常如何进行比较有关? - Josh Clark
12
在使用 heapq 进行比较时,Python 会首先查找 __lt__() 方法。如果没有定义,则会查找 __gt__() 方法。如果两者都没有定义,则会抛出 TypeError: '<' not supported between instances of 'Node' and 'Node' 错误。可以通过同时定义 __lt__()__gt__() 方法,并在每个方法中添加打印语句,然后让 __lt__() 返回 NotImplemented 来确认这一点。 - Fanchen Bao
1
为了使这个解决方案变得完整,需要有一个决胜者。为了在“self.val == other.val”时打破平局,在“__lt__”函数中引入另一个字段(优先级或与您的业务领域相关的其他内容)是一种选择,这样我们就可以比较这个字段并确保在这个字段方面没有相等的值。 - Yiling

29

heapq文档建议堆元素可以是元组,其中第一个元素是优先级并定义排序顺序。

更相关的是,文档包括讨论和示例代码,介绍如何实现自己的heapq包装函数来解决排序稳定性和相同优先级元素(以及其他问题)的处理。

简而言之,他们的解决方案是将heapq中的每个元素作为三元组,其中包含优先级、条目计数和要插入的元素。 条目计数确保具有相同优先级的元素按它们添加到heapq的顺序排序。


这是正确的解决方案,heappush和heappushpop都可以直接使用元组。 - daisy
1
这个解决方案很干净,但不能覆盖所有自定义算法,例如字符串的最大堆。 - PaleNeutron
1
我真的不明白人们如何认为提交一个条目计数是一个干净的解决方案。Python之神们听我的祈祷,制作一个普通优先队列类。谢谢! - aleksandarbos

23
setattr(ListNode, "__lt__", lambda self, other: self.val <= other.val)

请使用此函数对heapq中的对象进行值比较


2
有趣的方法来避免重新定义/重新封装对象! - powersource97
谢谢!这正是我正在寻找的。 - sfsf9797
虽然这对于 Leetcode 可能有效,但它在 heapq 中不起作用。 - mittal
谢谢,这个回答值得置顶。 - Fangda Han

3
两种方法的局限性在于它们不允许将平局视为平局。在第一种方法中,通过比较项目来打破平局,在第二种方法中则通过比较输入顺序来打破平局。让平局保持平局会更快,并且如果有很多平局,这可能会产生很大的差异。基于上述和文档,目前还不清楚是否可以在heapq中实现此功能。奇怪的是,heapq不接受键,而同一模块中派生的函数却接受。
附注: 如果您按照第一个评论中的链接(“可能是重复...”)进行操作,则有另一个定义le的建议,看起来像是解决方案。

3
写“两个答案”的限制在于不再清楚它们是哪两个。 - trincot

1
在Python3中,您可以使用functools模块中的cmp_to_key。请参考cpython源代码

假设您需要一个三元组的优先队列,并指定最后一个属性作为优先级。

from heapq import *
from functools import cmp_to_key
def mycmp(triplet_left, triplet_right):
    key_l, key_r = triplet_left[2], triplet_right[2]
    if key_l > key_r:
        return -1  # larger first
    elif key_l == key_r:
        return 0  # equal
    else:
        return 1


WrapperCls = cmp_to_key(mycmp)
pq = []
myobj = tuple(1, 2, "anystring")
# to push an object myobj into pq
heappush(pq, WrapperCls(myobj))
# to get the heap top use the `obj` attribute
inner = pq[0].obj

性能测试:

环境

Python 3.10.2

代码

from functools import cmp_to_key
from timeit import default_timer as time
from random import randint
from heapq import *

class WrapperCls1:
    __slots__ = 'obj'
    def __init__(self, obj):
        self.obj = obj
    def __lt__(self, other):
        kl, kr = self.obj[2], other.obj[2]
        return True if kl > kr else False

def cmp_class2(obj1, obj2):
    kl, kr = obj1[2], obj2[2]
    return -1 if kl > kr else 0 if kl == kr else 1

WrapperCls2 = cmp_to_key(cmp_class2)

triplets = [[randint(-1000000, 1000000) for _ in range(3)] for _ in range(100000)]
# tuple_triplets = [tuple(randint(-1000000, 1000000) for _ in range(3)) for _ in range(100000)]

def test_cls1():
    pq = []
    for triplet in triplets:
        heappush(pq, WrapperCls1(triplet))
        
def test_cls2():
    pq = []
    for triplet in triplets:
        heappush(pq, WrapperCls2(triplet))

def test_cls3():
    pq = []
    for triplet in triplets:
        heappush(pq, (-triplet[2], triplet))

start = time()
for _ in range(10):
    test_cls1()
    # test_cls2()
    # test_cls3()
print("total running time (seconds): ", -start+(start:=time()))

结果

每个函数使用list而不是tuple:

  • WrapperCls1: 16.2毫秒
  • 具有__slots__的WrapperCls1: 9.8毫秒
  • WrapperCls2: 8.6毫秒
  • 将优先级属性移至第一位置(不支持自定义谓词): 6.0毫秒。

因此,该方法比使用具有覆盖__lt__()函数和__slots__属性的自定义类略快。


你有检查过那些结果吗?我得到了类似[<functools.KeyWrapper> ...]的东西,而不是值。 - jizhihaoSAMA
@jizhihaoSAMA 你是否忘记使用了 .obj 属性? - Voyager

1

简单的小技巧:

假设你有一个(name,age)的列表如下:

a = [('Tim',4), ('Radha',9), ('Rob',7), ('Krsna',3)]

而且你想根据年龄对这个列表进行排序,可以通过将它们添加到最小堆中来实现,而不是编写所有自定义比较器的代码,你只需在将元组推入队列之前翻转其内容的顺序即可。这是因为heapq.heappush()默认按元组的第一个元素排序。

import heapq
heap = []
heapq.heapify(heap)
for element in a:
    heapq.heappush(heap, (element[1],element[0]))

这是一个简单的技巧,如果这样做能满足你的需求,而且你不想陷入编写自定义比较器的麻烦中。

类似地,默认情况下它按升序对值进行排序。如果你想按年龄降序排序,请翻转内容并将元组的第一个元素的值设为负数:

import heapq
heap = []
heapq.heapify(heap)
for element in a:
    heapq.heappush(heap, (-element[1],element[0]))

0

简单而实用

一个简单的解决方案是将条目存储为每个元组的列表,为每个元组定义所需顺序的优先级,如果您需要元组内每个项的不同顺序,请将其定义为负数以获取降序。

请参阅此主题中的官方heapq Python文档 Priority Queue Implementation Notes


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接