如何定义作为函数的枚举值?

35

我有一个情况,需要强制执行并给用户选择一种可作为参数传递给另一个函数的多个选项之一的功能:

我真的想要实现以下内容:

from enum import Enum

#Trivial Function 1
def functionA():
    pass

#Trivial Function 2
def functionB():
    pass

#This is not allowed (as far as i can tell the values should be integers)
#But pseudocode for what I am after
class AvailableFunctions(Enum):
    OptionA = functionA
    OptionB = functionB

那么下面的内容可以被执行:

def myUserFunction(theFunction = AvailableFunctions.OptionA):
   #Type Check
   assert isinstance(theFunction,AvailableFunctions) 

   #Execute the actual function held as value in the enum or equivalent
   return theFunction.value() 

另一个选项:将函数放入字典中,以函数名作为键。 - PM 2Ring
你发现了枚举的边缘情况...当Python构建一个类时,它会将函数属性添加为类的方法。这就是为什么它们不会成为“Enum”的值。正如我在我的答案中所展示的那样,很容易“欺骗”解释器,使其不将它们作为方法。 - Bakuriu
6个回答

54

你的假设是错误的。值可以是任意的,它们不仅限于整数。来自文档

上面的示例使用整数作为枚举值。使用整数简短方便(并由功能API默认提供),但并非强制执行。在绝大多数用例中,一个人并不关心枚举的实际值。但如果值很重要,枚举可以具有任意值。

然而函数的问题在于它们被认为是方法定义而不是属性!

In [1]: from enum import Enum

In [2]: def f(self, *args):
   ...:     pass
   ...: 

In [3]: class MyEnum(Enum):
   ...:     a = f
   ...:     def b(self, *args):
   ...:         print(self, args)
   ...:         

In [4]: list(MyEnum)  # it has no values
Out[4]: []

In [5]: MyEnum.a
Out[5]: <function __main__.f>

In [6]: MyEnum.b
Out[6]: <function __main__.MyEnum.b>

您可以解决这个问题,方法是使用一个包装类或者只使用functools.partial 或(仅适用于Python2)staticmethod

from functools import partial

class MyEnum(Enum):
    OptionA = partial(functionA)
    OptionB = staticmethod(functionB)

示例运行:

In [7]: from functools import partial

In [8]: class MyEnum2(Enum):
   ...:     a = partial(f)
   ...:     def b(self, *args):
   ...:         print(self, args)
   ...:         

In [9]: list(MyEnum2)
Out[9]: [<MyEnum2.a: functools.partial(<function f at 0x7f4130f9aae8>)>]

In [10]: MyEnum2.a
Out[10]: <MyEnum2.a: functools.partial(<function f at 0x7f4130f9aae8>)>

或者使用一个包装类:

In [13]: class Wrapper:
    ...:     def __init__(self, f):
    ...:         self.f = f
    ...:     def __call__(self, *args, **kwargs):
    ...:         return self.f(*args, **kwargs)
    ...:     

In [14]: class MyEnum3(Enum):
    ...:     a = Wrapper(f)
    ...:     

In [15]: list(MyEnum3)
Out[15]: [<MyEnum3.a: <__main__.Wrapper object at 0x7f413075b358>>]
此外,请注意,如果您希望,可以在枚举类中定义__call__方法,使值成为可调用对象:
In [1]: from enum import Enum

In [2]: def f(*args):
   ...:     print(args)
   ...:     

In [3]: class MyEnum(Enum):
   ...:     a = partial(f)
   ...:     def __call__(self, *args):
   ...:         self.value(*args)
   ...:         

In [5]: MyEnum.a(1,2,3)   # no need for MyEnum.a.value(1,2,3)
(1, 2, 3)

1
请注意,__call__方法应包含一个return语句,否则当您调用函数时,您将无法获得预期的响应。我曾经被卡在这里两个小时。回想起来似乎很明显,但以防其他人也遇到同样的问题... - Logister
仅当函数返回一个值时才使用@Logister。 - Bakuriu
2
不需要导入 partial,你可以直接使用 a = staticmethod(f) - mar77i
1
@mar77i 没错。我已经编辑了答案,以显示该选项。谢谢。 - Bakuriu
我没有让staticmethod起作用,但是functools.partial完美地解决了问题。感谢您提供这个聪明的解决方案! - Eerik Sven Puudist
显示剩余3条评论

8

从Python 3.11开始,有了更加简洁易懂的方法。在enum中添加了membernonmember函数等多项改进,所以现在您可以执行以下操作:

from enum import Enum, member

def fn(x):
    print(x)

class MyEnum(Enum):
    meth = fn
    mem = member(fn)
    @classmethod
    def this_is_a_method(cls):
        print('No, still not a member')
    def this_is_just_function():
        print('No, not a member')
    @member
    def this_is_a_member(x):
        print('Now a member!', x)

现在

>>> list(MyEnum)
[<MyEnum.mem: <function fn at ...>>, <MyEnum.this_is_a_member: <function MyEnum.this_is_a_member at ...>>]

>>> MyEnum.meth(1)
1

>>> MyEnum.mem(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'MyEnum' object is not callable

>>> MyEnum.mem.value(1)
1

>>> MyEnum.this_is_a_method()
No, still not a member

>>> MyEnum.this_is_just_function()
No, not a member

>>> MyEnum.this_is_a_member()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'MyEnum' object is not callable

>>> MyEnum.this_is_a_member.value(1)
Now a member! 1

6

另一个不太臃肿的解决方案是将函数放在元组中。正如Bakuriu所提到的,您可能希望使枚举可调用。

from enum import Enum

def functionA():
    pass

def functionB():
    pass

class AvailableFunctions(Enum):
    OptionA = (functionA,)
    OptionB = (functionB,)

    def __call__(self, *args, **kwargs):
        self.value[0](*args, **kwargs)

现在您可以像这样使用它:
AvailableFunctions.OptionA() # calls functionA

2
除了 Bakuriu 的回答之外,如果您使用类似上面的包装器方法,您将失去有关原始函数的信息,例如 __name____repr__ 等等。在包装后,这会导致问题,例如如果您想要使用sphinx来生成源代码文档。因此,请将以下内容添加到您的包装器类中。
class wrapper:
    def __init__(self, function):
        self.function = function
        functools.update_wrapper(self, function)
    def __call__(self,*args, **kwargs):
        return self.function(*args, **kwargs)
    def __repr__(self):
        return self.function.__repr__()

2
@bakuriu的方法基础上,我想强调我们也可以使用多个函数作为值的字典,并且实现更广泛的多态性,类似于Java中的枚举。以下是一个虚构的示例,以说明我的意思:
from enum import Enum, unique

@unique
class MyEnum(Enum):
    test = {'execute': lambda o: o.test()}
    prod = {'execute': lambda o: o.prod()}

    def __getattr__(self, name):
        if name in self.__dict__:
            return self.__dict__[name]
        elif not name.startswith("_"):
            value = self.__dict__['_value_']
            return value[name]
        raise AttributeError(name)

class Executor:
    def __init__(self, mode: MyEnum):
        self.mode = mode

    def test(self):
        print('test run')

    def prod(self):
        print('prod run')

    def execute(self):
        self.mode.execute(self)

Executor(MyEnum.test).execute()
Executor(MyEnum.prod).execute()

显然,当只有一个函数时,字典方法并不提供额外的好处,因此在存在多个函数时使用此方法。确保所有值的键都是统一的,否则使用就不会具有多态性。 __getattr__ 方法是可选的,它只是为了语法糖(即,如果没有它,mode.execute() 将变成 mode.value['execute']()
由于字典不能被设置为只读,因此使用 namedtuple 将更好,并且对上述内容只需要进行微小的更改。
from enum import Enum, unique
from collections import namedtuple

EnumType = namedtuple("EnumType", "execute")

@unique
class MyEnum(Enum):
    test = EnumType(lambda o: o.test())
    prod = EnumType(lambda o: o.prod())

    def __getattr__(self, name):
        if name in self.__dict__:
            return self.__dict__[name]
        elif not name.startswith("_"):
            value = self.__dict__['_value_']
            return getattr(value, name)
        raise AttributeError(name)

我刚刚通过你的例子理解了什么是语法糖,谢谢。 - Ondřej Javorský

0
只是为了补充一下:有一种方法可以在下游代码中不使用“partial”或“member”的情况下使其工作,但你必须深入了解元类和“Enum”的实现。在其他答案的基础上,这样做是可行的:
from enum import Enum, EnumType, _EnumDict, member
import inspect


class _ExtendedEnumType(EnumType):
    # Autowraps class-level functions/lambdas in enum with member, so they behave as one would expect
    # I.e. be a member with name and value instead of becoming a method
    # This is a hack, going deep into the internals of the enum class
    # and performing an open-heart surgery on it...
    def __new__(metacls, cls: str, bases, classdict: _EnumDict, *, boundary=None, _simple=False, **kwds):
        non_members = set(classdict).difference(classdict._member_names)
        for k in non_members:
            if not k.startswith("_"):
                if classdict[k].__class__ in [classmethod, staticmethod]:
                    continue
                # instance methods don't make sense for enums, and may break callable enums
                if "self" in inspect.signature(classdict[k]).parameters:
                    raise TypeError(
                        f"Instance methods are not allowed in enums but found method"
                        f" {classdict[k]} in {cls}"
                    )
                # After all the input validation, we can finally wrap the function
                # For python<3.11, one should use `functools.partial` instead of `member`
                callable_val = member(classdict[k])
                # We have to use del since _EnumDict doesn't allow overwriting
                del classdict[k]
                classdict[k] = callable_val
                classdict._member_names[k] = None
        return super().__new__(metacls, cls, bases, classdict, boundary=boundary, _simple=_simple, **kwds)

class ExtendedEnum(Enum, metaclass=_ExtendedEnumType):
    pass

现在你可以做的是:
class A(ExtendedEnum):
    a = 3
    b = lambda: 4
    
    @classmethod
    def afunc(cls):
        return 5
    
    @staticmethod
    def bfunc():
        pass

一切都会按预期进行。
附注:为了一些更多的枚举魔法,我也喜欢添加。
    def __getitem__(cls, item):
        if hasattr(item, "name"):
            item = item.name
        # can't use [] because of particularities of super()
        return super().__getitem__(item)

为了使`A[A.a]`能够正常工作,建议将其扩展为`_ExtendedEnumType`。
同时,我还建议按照上述提议,将其设计为可调用的枚举类型。

`


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接