Python字典中枚举作为键

4
假设我有一个枚举变量:
class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

我想创建一个ColorDict类,它可以像本地Python字典一样工作,但只接受Color枚举或其相应的字符串值作为键。
d = ColorDict() # I want to implement a ColorDict class such that ...

d[Color.RED] = 123
d["RED"] = 456  # I want this to override the previous value
d[Color.RED]    # ==> 456
d["foo"] = 789  # I want this to produce an KeyError exception

这个“pythonic way”实现ColorDict类应该怎么做?我是使用继承(覆盖Python的本地dict)还是组合(将一个dict作为成员)?

继承或组合真的取决于你。如果使用继承,你将不得不覆盖所有接受输入的方法,如__setitem__.update,这可能很容易。 - juanpa.arrivillaga
这是个人选择,但在这些情况下我通常更喜欢组合。通过明确要公开哪些操作,限制您需要执行的工作量,特别是如果您不关心实现整个字典接口,它可以使界面更易于理解。 - flakes
@Mark 代码片段包含我希望实现的行为,而不是我当前观察到的行为。我更新了注释以使其更清晰。对于造成的困惑,我感到抱歉。 - Ying Xiong
1
另一种选择是继承 collections.abc.MutableMapping,这将涉及组合,但您只需要实现最少量的方法即可。 - juanpa.arrivillaga
谢谢@英雄,我在你发布之前意识到自己看错了。 - Mark
KeyError 是一种查找错误(即找不到某个东西),但赋值操作并不是一种查找。我会考虑使用 ValueError 来表示错误的键值,但说实话,我不确定哪一个更合适。 - VPfB
2个回答

4

一个简单的解决方案是稍微修改你的Color对象,然后子类化dict以添加对键的测试。我会这样做:

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

    @classmethod
    def is_color(cls, color):
        if isinstance(color, cls):
            color=color.value
        if not color in cls.__members__:
            return False
        else:
            return True


class ColorDict(dict):
    
    def __setitem__(self, k, v):
        if Color.is_color(k):
            super().__setitem__(Color(k), v)
        else:
            raise KeyError(f"Color {k} is not valid")

    def __getitem__(self, k):
        if isinstance(k, str):
            k = Color(k.upper())
        return super().__getitem__(k)

d = ColorDict()

d[Color.RED] = 123
d["RED"] = 456
d[Color.RED]
d["foo"] = 789

Color类中,我添加了一个测试函数来返回TrueFalse,以判断颜色是否在允许的列表中。使用upper()函数将字符串转换为大写,以便与预定义的值进行比较。
然后,我子类化了dict对象,重写了__setitem__特殊方法,包括对传递的值进行测试,并覆盖了__getitem__以将任何作为str传递的键转换为正确的Enum。根据您想要使用ColorDict类的具体情况,您可能需要覆盖更多的函数。这里有一个很好的解释:如何正确地子类化字典并覆盖__getitem__和__setitem__

1
小注释 - 使用 test_ 前缀来命名过滤函数(test_color)可能不是一个好主意,因为一些测试框架可能会错误地将其识别为测试用例。is_color 将是一个惯用的过滤器名称。 - Nathaniel Ford
那是一个非常好的观点!我会进行修改。 - defladamouse
1
和@VPfB的评论一样。我们可能需要做类似于super().__setitem__(Color(k), v)这样的事情。 - Ying Xiong
谢谢你们两个,我已经添加了建议的Color(k),这似乎解决了问题。但现在需要覆盖__getitem__以匹配。 - defladamouse

2

一种方法是使用抽象基类collections.abc.MutableMapping,这样,您只需重写抽象方法,然后可以确保访问始终通过您的逻辑进行 - 您也可以使用dict,但例如,覆盖dict.__setitem__将不会影响dict.updatedict.setdefault等...因此,您也必须手动覆盖这些方法。通常,更容易只使用抽象基类:

from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):

    def __init__(self): # could handle more ways of initializing  but for simplicity...
        self._data = {}

    def __getitem__(self, item):
        return self._data[color]

    def __setitem__(self, item, value):
        color = self._handle_item(item)
        self._data[color] = value

    def __delitem__(self, item):
        del self._data[color]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def _handle_item(self, item):
        try:
            color = Color(item)
        except ValueError:
            raise KeyError(item) from None
        return color

请注意,您也可以添加以下内容:
    def __repr__(self):
        return repr(self._data)

为了更加便于调试。
在repl中的一个例子:
In [3]: d = ColorDict() # I want to implement a ColorDict class such that ...
   ...:
   ...: d[Color.RED] = 123
   ...: d["RED"] = 456  # I want this to override the previous value
   ...: d[Color.RED]    # ==> 456
Out[3]: 456

In [4]: d["foo"] = 789  # I want this to produce an KeyError exception
   ...:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-9cf80d6dd8b4> in <module>
----> 1 d["foo"] = 789  # I want this to produce an KeyError exception

<ipython-input-2-a0780e16594b> in __setitem__(self, item, value)
     17
     18     def __setitem__(self, item, value):
---> 19         color = self._handle_item(item)
     20         self._data[color] = value
     21

<ipython-input-2-a0780e16594b> in _handle_item(self, item)
     34             color = Color(item)
     35         except ValueError:
---> 36             raise KeyError(item) from None
     37         return color
     38     def __repr__(self): return repr(self._data)

KeyError: 'foo'

In [5]: d
Out[5]: {<Color.RED: 'RED'>: 456}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接