How to add aliases to an enum type

I have an abstract base class GameNodeState that contains a Type enum:

import abc
import enum


class GameNodeState(metaclass=abc.ABCMeta):
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()

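The names in the enum are generic because they have to make sense for any subclass of GameNodeState. But when I subclass GameNodeState, as with GameState and RoundState, I would like to be able to add more specific aliases for the members of GameNodeState.Type whenever the enum is accessed through the subclass. For example, if the GameState subclass aliases INTERMEDIATE as ROUND, and RoundState aliases INTERMEDIATE as TURN, I would like the following behavior: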
>>> GameNodeState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>

>>> RoundState.Type.TURN
<Type.INTERMEDIATE: 2>

>>> RoundState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>

>>> GameNodeState.Type.TURN
AttributeError: TURN

My first idea was this:

class GameState(GameNodeState):
    class Type(GameNodeState.Type):
        ROUND = GameNodeState.Type.INTERMEDIATE.value


class RoundState(GameNodeState):
    class Type(GameNodeState.Type):
        TURN = GameNodeState.Type.INTERMEDIATE.value

But an enum that already defines members cannot be subclassed.
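For reference, here is a minimal sketch of that restriction. The names Base, Extended, Empty and FromEmpty are made up for illustration and are not part of my real code. The exact error message differs between Python versions, so only the exception type is checked.

```python
import enum


class Base(enum.Enum):
    INIT = enum.auto()
    INTERMEDIATE = enum.auto()


def try_subclass():
    """Return the error raised when subclassing an enum that has members."""
    try:
        class Extended(Base):  # hypothetical subclass trying to add an alias
            TURN = 2
    except TypeError as exc:
        return exc
    return None


class Empty(enum.Enum):
    """An enum without members can still be subclassed."""


class FromEmpty(Empty):
    A = 1


error = try_subclass()
print(type(error).__name__)  # TypeError
print(FromEmpty.A)           # FromEmpty.A
```

So the restriction only applies once the base enum has members, which is exactly the situation with GameNodeState.Type.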


Note: the GameNodeState hierarchy obviously has more attributes and methods; I have stripped it down to the bare minimum here to focus on this one issue.


Don't use subclasses! - Nizam Mohamed
1 Answer

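Refinement

(The original solution is below.)

I extracted an intermediate concept from the original code below, namely the union of enums. It can be used to get the behavior described in the question, and it is useful in other situations too. The code can be found here, and I have also posted a Code Review question.

For reference, I am including the code here as well: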

import enum
import itertools as itt
from functools import reduce
import operator
from typing import Literal, Union

import more_itertools as mitt


AUTO = object()


class UnionEnumMeta(enum.EnumMeta):
    """
    The metaclass for enums which are the union of several sub-enums.

    Union enums have the _subenums_ attribute which is a tuple of the enums forming the
    union.
    """

    @classmethod
    def make_union(
        mcs, *subenums: enum.EnumMeta, name: Union[str, Literal[AUTO], None] = AUTO
    ) -> enum.EnumMeta:
        """
        Create an enum whose set of members is the union of members of several enums.

        Order matters: where two members in the union have the same value, they will
        be considered as aliases of each other, and the one appearing in the first
        enum in the sequence will be used as the canonical members (the aliases will
        be associated to this enum member).

        :param subenums: Sequence of sub-enums to make a union of.
        :param name: Name to use for the enum class. AUTO will result in a combination
                     of the names of all subenums, None will result in "UnionEnum".
        :return: An enum class which is the union of the given subenums.
        """
        subenums = mcs._normalize_subenums(subenums)

        class UnionEnum(enum.Enum, metaclass=mcs):
            pass

        union_enum = UnionEnum
        union_enum._subenums_ = subenums

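        # Collect every name that appears in more than one subenum. (Intersecting
        # all name sets would miss names shared by only some of the subenums.)
        seen_names, duplicate_names = set(), set()
        for subenum in subenums:
            duplicate_names |= seen_names & set(subenum.__members__)
            seen_names |= set(subenum.__members__)
        if duplicate_names: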
            raise ValueError(
                f"Found duplicate member names in enum union: {duplicate_names}"
            )

        # If aliases are defined, the canonical member will be the one that appears
        # first in the sequence of subenums.
        # dict union keeps last key so we have to do it in reverse:
        union_enum._value2member_map_ = value2member_map = reduce(
            operator.or_, (subenum._value2member_map_ for subenum in reversed(subenums))
        )
        # union of the _member_map_'s but using the canonical member always:
        union_enum._member_map_ = member_map = {
            name: value2member_map[member.value]
            for name, member in itt.chain.from_iterable(
                subenum._member_map_.items() for subenum in subenums
            )
        }
        # only include canonical aliases in _member_names_
        union_enum._member_names_ = list(
            mitt.unique_everseen(
                itt.chain.from_iterable(subenum._member_names_ for subenum in subenums),
                key=member_map.__getitem__,
            )
        )

        if name is AUTO:
            name = (
                "".join(subenum.__name__.removesuffix("Enum") for subenum in subenums)
                + "UnionEnum"
            )
            UnionEnum.__name__ = name
        elif name is not None:
            UnionEnum.__name__ = name

        return union_enum

    def __repr__(cls):
        return f"<union of {', '.join(map(str, cls._subenums_))}>"

    def __instancecheck__(cls, instance):
        return any(isinstance(instance, subenum) for subenum in cls._subenums_)

    @classmethod
    def _normalize_subenums(mcs, subenums):
        """Remove duplicate subenums and flatten nested unions"""
        # we will need to collapse at most one level of nesting, with the inductive
        # hypothesis that any previous unions are already flat
        subenums = mitt.collapse(
            (e._subenums_ if isinstance(e, mcs) else e for e in subenums),
            base_type=enum.EnumMeta,
        )
        subenums = mitt.unique_everseen(subenums)
        return tuple(subenums)


def enum_union(*enums, **kwargs):
    return UnionEnumMeta.make_union(*enums, **kwargs)

With this, we can define an extend_enum decorator that computes the union of the base enum and the enum "extension", which gives the desired behavior:

def extend_enum(base_enum):
    def decorator(extension_enum):
        return enum_union(base_enum, extension_enum)

    return decorator

Usage:

class GameNodeState(metaclass=abc.ABCMeta):
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()


class RoundState(GameNodeState):
    @extend_enum(GameNodeState.Type)
    class Type(enum.Enum):
        TURN = GameNodeState.Type.INTERMEDIATE.value


class GameState(GameNodeState):
    @extend_enum(GameNodeState.Type)
    class Type(enum.Enum):
        ROUND = GameNodeState.Type.INTERMEDIATE.value

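Now all of the examples above produce the same output (plus the added instance check, i.e. isinstance(RoundState.Type.TURN, RoundState.Type) returns True).
I think this is a cleaner solution because it does not involve descriptors, and it does not need to know anything about the owner class (it works just as well for top-level classes).
Attribute lookup through subclasses and instances of GameNodeState should automatically resolve to the correct "extension" (i.e. the union), as long as the extension enum is added under the same name as in the GameNodeState superclass, so that it hides the original definition.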
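The alias semantics that the union relies on can also be seen with the public enum API alone. The sketch below is a deliberately simpler technique than UnionEnumMeta: it rebuilds a brand-new enum with the functional API instead of patching private attributes. The names simple_union, Base, RoundExtension and RoundType are made up for illustration. Because the members are recreated, identity with the members of the original enums is lost, which the metaclass version above is designed to preserve.

```python
import enum


class Base(enum.Enum):
    INIT = 1
    INTERMEDIATE = 2
    END = 3


class RoundExtension(enum.Enum):
    TURN = Base.INTERMEDIATE.value


def simple_union(name, *enums):
    """Build a new enum holding the names of all given enums (hypothetical helper).

    Names mapping to an already used value become aliases of the first member
    with that value, following the standard enum rules.
    """
    pairs = [
        (member_name, member.value)
        for e in enums
        for member_name, member in e.__members__.items()
    ]
    return enum.Enum(name, pairs)


RoundType = simple_union("RoundType", Base, RoundExtension)

print(RoundType.TURN is RoundType.INTERMEDIATE)  # True: TURN is an alias
print([m.name for m in RoundType])               # ['INIT', 'INTERMEDIATE', 'END']
print(isinstance(RoundType.TURN, RoundType))     # True
print(RoundType.TURN is Base.INTERMEDIATE)       # False: members were recreated
```

If comparing members of the union with members of the base enum by identity matters, this shortcut is not enough, and something like the metaclass approach is needed.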

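Original

Not sure whether this is a bad idea, but here is a solution that wraps the enum in a descriptor, which picks the set of aliases based on the class the enum is accessed from.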

import enum
import itertools
from typing import Dict, Type


class ExtensibleClassEnum:
    class ExtensionWrapperMeta(enum.EnumMeta):
        @classmethod
        def __prepare__(mcs, name, bases):
            # noinspection PyTypeChecker
            classdict: enum._EnumDict = super().__prepare__(name, bases)
            classdict["_ignore_"] = ["base_descriptor", "extension_enum"]
            return classdict

        # noinspection PyProtectedMember
        def __new__(mcs, cls, bases, classdict):
            base_descriptor = classdict.pop("base_descriptor")
            extension_enum = classdict.pop("extension_enum")
            wrapper_enum = super().__new__(mcs, cls, bases, classdict)
            wrapper_enum.base_descriptor = base_descriptor
            wrapper_enum.extension_enum = extension_enum

            base, extension = base_descriptor.base_enum, extension_enum
            if set(base._member_map_.keys()) & set(extension._member_map_.keys()):
                raise ValueError("Found duplicate names in extension")
            # dict union keeps last key so we have to do it in reverse:
            wrapper_enum._value2member_map_ = (
                extension._value2member_map_ | base._value2member_map_
            )
            # union of both _member_map_'s but using the canonical member always:
            wrapper_enum._member_map_ = {
                name: wrapper_enum._value2member_map_[member.value]
                for name, member in itertools.chain(
                    base._member_map_.items(), extension._member_map_.items()
                )
            }
            # aliases shouldn't appear in _member_names_
            wrapper_enum._member_names_ = list(
                m.name for m in wrapper_enum._value2member_map_.values()
            )
            return wrapper_enum

        def __repr__(self):
            # have to use vars() to avoid triggering the descriptor
            base_descriptor = vars(self)["base_descriptor"]
            return (
                f"<extension wrapper enum for {base_descriptor.base_enum}"
                f" in {base_descriptor._extension2owner[self]}>"
            )

    def __init__(self, base_enum):
        if not issubclass(base_enum, enum.Enum):
            raise TypeError(base_enum)
        self.base_enum = base_enum
        # The user won't be able to retrieve the descriptor object itself, just
        # the enum, so we have to forward calls to register_extension:
        self.base_enum.register_extension = staticmethod(self.register_extension)

        # mapping of owner class -> extension for subclasses that define an extension
        self._extensions: Dict[Type, ExtensibleClassEnum.ExtensionWrapperMeta] = {}
        # reverse mapping
        self._extension2owner: Dict[ExtensibleClassEnum.ExtensionWrapperMeta, Type] = {}

        # add the base enum as the base extension via __set_name__:
        self._pending_extension = base_enum

    @property
    def base_owner(self):
        # will be initialised after __set_name__ is called with base owner
        return self._extension2owner[self.base_enum]

    def __set_name__(self, owner, name):
        # step 2 of register_extension: determine the class that defined it
        self._extensions[owner] = self._pending_extension
        self._extension2owner[self._pending_extension] = owner
        del self._pending_extension

    def __get__(self, instance, owner):
        # Only compute extensions once:
        if owner in self._extensions:
            return self._extensions[owner]

        # traverse in MRO until we find the closest supertype defining an extension
        for supertype in owner.__mro__:
            if supertype in self._extensions:
                extension = self._extensions[supertype]
                break
        else:
            raise TypeError(f"{owner} is not a subclass of {self.base_owner}")

        # Cache the result
        self._extensions[owner] = extension
        return extension

    def make_extension(self, extension: enum.EnumMeta):
        class ExtensionWrapperEnum(
            enum.Enum, metaclass=ExtensibleClassEnum.ExtensionWrapperMeta
        ):
            base_descriptor = self
            extension_enum = extension

        return ExtensionWrapperEnum

    def register_extension(self, extension_enum):
        """Decorator for enum extensions"""
        # need a way to determine owner class
        # add a temporary attribute that we will use when __set_name__ is called:
        if hasattr(self, "_pending_extension"):
            # __set_name__ not called after the previous call to register_extension
            raise RuntimeError(
                "An extension was created outside of a class definition",
                self._pending_extension,
            )
        self._pending_extension = self.make_extension(extension_enum)
        return self

It is used like this:

class GameNodeState(metaclass=abc.ABCMeta):
    @ExtensibleClassEnum
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()


class RoundState(GameNodeState):
    @GameNodeState.Type.register_extension
    class Type(enum.Enum):
        TURN = GameNodeState.Type.INTERMEDIATE.value


class GameState(GameNodeState):
    @GameNodeState.Type.register_extension
    class Type(enum.Enum):
        ROUND = GameNodeState.Type.INTERMEDIATE.value

Then:

>>> (RoundState.Type.TURN 
...  == RoundState.Type.INTERMEDIATE 
...  == GameNodeState.Type.INTERMEDIATE 
...  == GameState.Type.INTERMEDIATE 
...  == GameState.Type.ROUND)
...
True

>>> RoundState.Type.__members__
mappingproxy({'INIT': <Type.INIT: 1>,
              'INTERMEDIATE': <Type.INTERMEDIATE: 2>,
              'END': <Type.END: 3>,
              'TURN': <Type.INTERMEDIATE: 2>})

>>> list(RoundState.Type)
[<Type.INTERMEDIATE: 2>, <Type.INIT: 1>, <Type.END: 3>]

>>> GameNodeState.Type.TURN
Traceback (most recent call last):
  ...
  File "C:\Program Files\Python39\lib\enum.py", line 352, in __getattr__
    raise AttributeError(name) from None
AttributeError: TURN
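The owner-dependent lookup that the descriptor performs in __get__ can be isolated in a much smaller sketch. This is not the ExtensibleClassEnum implementation, only an illustration of the mechanism: PerOwner, register, Node, Round and Turn are hypothetical names, and plain strings stand in for the enums.

```python
class PerOwner:
    """Descriptor returning the value registered for the closest class in the MRO."""

    def __init__(self, default):
        self._default = default
        self._by_owner = {}

    def __set_name__(self, owner, name):
        # The class that defines the descriptor gets the default value.
        self._by_owner[owner] = self._default

    def register(self, owner, value):
        # Hypothetical registration hook; the answer uses a decorator instead.
        self._by_owner[owner] = value

    def __get__(self, instance, owner=None):
        owner = owner if owner is not None else type(instance)
        # Walk the MRO until a class with a registered value is found.
        for klass in owner.__mro__:
            if klass in self._by_owner:
                return self._by_owner[klass]
        raise TypeError(f"{owner!r} has no registered value")


class Node:
    kind = PerOwner("generic")


class Round(Node):
    pass


class Turn(Round):
    pass


# vars() bypasses __get__, so it reaches the descriptor object itself.
vars(Node)["kind"].register(Round, "round-specific")

print(Node.kind)    # generic
print(Round.kind)   # round-specific
print(Turn.kind)    # round-specific, found via Round in the MRO
print(Turn().kind)  # round-specific, instance access works too
```

The need for vars() here is the same reason the answer forwards register_extension through the base enum: ordinary attribute access always goes through __get__ and never returns the descriptor.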

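Awesome! Does it also work with `is` instead of `==`? - Ethan Furman
Yes, it works with is, because it fetches the canonical member for the given name directly from the original enum. For that very reason, however, isinstance(RoundState.Type.TURN, RoundState.Type) does not work. Are you the author of enum? Would you object to this kind of thing (conceptually) because it is fragile, or for some other reason, or would you say it is fine if it is refined enough? - Anakhand
I'm still thinking about it. :-) - Ethan Furman
Hi @EthanFurman, I extracted a refinement from the code above and posted it on Code Review SE ([link](https://codereview.stackexchange.com/q/253628/178268)); I thought you might want to weigh in. - Anakhand
Thanks for the message! I have read the comments. - Ethan Furman
Given that your refinement is so much better than the original, you should put the refinement first and keep the initial attempt as history. - Ethan Furman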
