为一个类覆盖__contains__方法

16

我需要在Python中模拟枚举,我是通过编写类来实现的,例如:

class Spam(Enum):
    k = 3
    EGGS = 0
    HAM = 1
    BAKEDBEANS = 2

现在我想测试某个常量是否是特定枚举类的有效选择,使用以下语法:

if (x in Foo):
    print("seems legit")

因此,我尝试创建一个“枚举”基类,在其中覆盖了__contains__方法,如下所示:

class Enum:
    """
    Simulates an enum.
    """

    k = 0 # overwrite in subclass with number of constants

    @classmethod
    def __contains__(cls, x):
        """
        Test for valid enum constant x:
            x in Enum
        """
        return (x in range(cls.k))
然而,当在类上使用 in 关键字(就像上面的示例一样)时,我会收到错误消息:
TypeError: argument of type 'type' is not iterable

为什么会这样?我能以某种方式获得我想要的语法糖吗?


我认为添加一段澄清注释,标明代码可以正常运行:if x in Foo(): print('seems legit')会使得问题和答案更容易理解。因为 type(Foo()) 是 Foo 类型,而 type(Foo) 是元类。 - phil_20686
2个回答

24

为什么会这样呢?

当你使用特殊语法,如a in Foo时,会在Foo的类型上查找__contains__方法。然而,你的__contains__实现存在于Foo本身,而不是它的类型type上。因此,出现错误,因为type没有实现(或迭代)这个函数。

如果您先实例化一个对象,然后在创建之后再添加一个__contains__函数到实例变量中,也会出现同样的情况:该函数不会被调用。

>>> class Empty: pass
... 
>>> x = Empty()
>>> x.__contains__ = lambda: True
>>> 1 in x
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: argument of type 'Empty' is not iterable

我是否可以以某种方式获得我想要的语法糖?

是的。如上所述,该方法在Foo的类型上进行查找。类的类型称为元类,因此您需要一个实现__contains__的新元类。

尝试这个:

class MetaEnum(type):
    def __contains__(cls, x):
            return x in range(cls.k)

正如你所看到的,元类上的方法将元类实例——类——作为它们的第一个参数。这应该是有意义的。它非常类似于类方法,只是方法存在于元类而不是类中。

从具有自定义元类的类继承也会继承元类,因此您可以像这样创建基础类:

class BaseEnum(metaclass=MetaEnum):
    pass

class MyEnum(BaseEnum):
    k = 3

print(1 in MyEnum) # True

2
啊,我有一个模糊的概念,元类存在,但我从来没有自己的用例。这很好用,谢谢。 - clstaudt
8年后,这个精彩的答案让我找到了正确的解决方案,解决了我在将一些包切换到Python 3.8时遇到的问题。如果有人因类似的问题而来到这里,我在这里发布了我的解决方案,希望对他们有用。 - Pierre D

1

我的使用情况是测试我Enum成员的名称。

通过对此解决方案进行轻微修改:

from enum import Enum, EnumMeta, auto


class MetaEnum(EnumMeta):
    def __contains__(cls, item):
        return item in cls.__members__.keys()


class BaseEnum(Enum, metaclass=MetaEnum):
    pass


class LogSections(BaseEnum):
    configuration = auto()
    debug = auto()
    errors = auto()
    component_states = auto()
    alarm = auto()


if __name__ == "__main__":
    print('configuration' in LogSections)
    print('b' in LogSections)

True
False

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接