覆写__eq__和__hash__来比较两个实例的字典属性

3

我很难理解如何根据每个实例拥有的底层字典属性正确比较对象。

由于我正在覆盖__eq__,所以我需要同时覆盖__hash__吗?我对何时/在何处这样做没有明确的了解,真的需要一些帮助。

我创建了一个简单的示例来说明我遇到的最大递归异常。 RegionalCustomerCollection按地理区域组织帐户ID。RegionalCustomerCollection对象如果它们的区域及其相应的accountid是相等的,则被认为是相等的。基本上,所有的items()内容应该是相等的。

from collections import defaultdict

class RegionalCustomerCollection(object):

    def __init__(self):
        self.region_accountids = defaultdict(set) 

    def get_region_accountid(self, region_name=None):
        return self.region_accountids.get(region_name, None)

    def set_region_accountid(self, region_name, accountid):
        self.region_accountids[region_name].add(accountid)

    def __eq__(self, other):
        if (other == self):
            return True

        if isinstance(other, RegionalCustomerCollection):
            return self.region_accountids == other.region_accountids

        return False 

    def __repr__(self):
        return ', '.join(["{0}: {1}".format(region, acctids) 
                          for region, acctids 
                          in self.region_accountids.items()])

让我们创建两个对象实例,并填充一些示例数据:

>>> a = RegionalCustomerCollection()
>>> b = RegionalCustomerCollection()
>>> a.set_region_accountid('northeast',1)
>>> a.set_region_accountid('northeast',2)
>>> a.set_region_accountid('northeast',3)
>>> a.set_region_accountid('southwest',4)
>>> a.set_region_accountid('southwest',5)
>>> b.set_region_accountid('northeast',1)
>>> b.set_region_accountid('northeast',2)
>>> b.set_region_accountid('northeast',3)
>>> b.set_region_accountid('southwest',4)
>>> b.set_region_accountid('southwest',5)

现在让我们尝试比较这两个实例并生成递归异常:

>>> a == b
... 
RuntimeError: maximum recursion depth exceeded while calling a Python object

如果您想将一个对象作为字典的键,那么就需要实现__eq____hash__方法。除非您需要进行比较(例如hash(self) == hash(other)),否则您并不一定需要实现__hash__方法。 - jonrsharpe
啊哈,我明白了。谢谢你的解释。 - Dowwie
2个回答

3

由于对象是可变的,因此您的对象不应返回哈希值。如果您将此对象放入字典或集合中,然后在之后更改它,则可能永远无法再次找到它。

为了使对象不可哈希,您需要执行以下操作:

class MyClass(object):
    __hash__ = None

这将确保对象不可哈希化。
 [in] >>> m = MyClass()
 [in] >>> hash(m)
[out] >>> TypeError: unhashable type 'MyClass'

您的问题得到了解答吗?我怀疑没有,因为您明确在寻找一个哈希函数。

就您收到的运行时错误而言,是由以下代码行引起的:

    if self == other:
        return True

这会让您陷入无限递归循环中。请尝试以下方法:

    if self is other:
        return True

谢谢,Steve。你的解释很清晰,消除递归的解决方案很有效。 - Dowwie

1
你不需要覆盖__hash__来比较两个对象(如果你想要自定义哈希,例如在插入集合或字典时提高性能,则需要覆盖)。
此外,你在这里有无限递归:
    def __eq__(self, other):
        if (other == self):
            return True

        if isinstance(other, RegionalCustomerCollection):
            return self.region_accountids == other.region_accountids
return False
如果两个对象都是RegionalCustomerCollection类型,则会出现无限递归,因为==调用__eq__

感谢Rodrigo对于“hash”覆盖使用案例(性能)的澄清。 - Dowwie

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接