在自定义类对象列表中使用 __contains__ 方法

6

我有一个简单的类定义如下:

class User(object):
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __contains__(self, item):
        return item == self.id

使用这个类,我可以对单个实例进行简单的检查,像这样:
>>> user1 = User(1, 'name_1')
>>> 1 in user1
True
>>> 2 in user1
False

这是按预期工作的。

但如何检查一个值是否在一个User对象列表中呢?它似乎总是返回False。

例如:

from random import randint
from pprint import pprint
users = [User(x, 'name_{}'.format(x)) for x in xrange(5)]
pprint(users, indent=4)

for x in xrange(5):
    i = randint(2,6) 
    if i in users:
        print("User already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        users.append(User(i, 'new_user{}'.format(i)))
pprint(users, indent=4)

这将创建类似于以下输出的结果:
[   0 => name_0, 
    1 => name_1, 
    2 => name_2, 
    3 => name_3, 
    4 => name_4]
6 is not in users list. Creating new user with id: 6
6 is not in users list. Creating new user with id: 6
6 is not in users list. Creating new user with id: 6
3 is not in users list. Creating new user with id: 3
3 is not in users list. Creating new user with id: 3
[   0 => name_0,
    1 => name_1,
    2 => name_2,
    3 => name_3,
    4 => name_4,
    6 => new_user6,
    6 => new_user6,
    6 => new_user6,
    3 => new_user3,
    3 => new_user3]

问题在于id为6的用户应该只被创建1次,因为它之前没有被创建过。第二次和第三次尝试创建6时应该失败。id为3的用户根本不应该被重新创建,因为它是users变量的初始化的一部分。
我该如何修改我的__contains__方法以便能够正确地使用in来与我的类的多个实例进行比较?

2
if any(i in u for u in users) - khelwood
3
你确定要实现__contains__来检查ID吗?我认为像123 in user这样的表达并不是很直观。 - Andrea Corbellini
你想让 'name_1' in user1 也为 True 吗? - Peter Wood
@PeterWood,不,它只需要检查id - Andy
3个回答

6
如果users是用户列表并且您检查if i in users,那么您不是在检查User.__contains__。您正在检查list.__contains__。无论您在User.__contains__中做什么都不会影响检查i是否在列表中的结果。
如果您想检查i是否与users中的任何User匹配,则可以执行以下操作:
if any(i in u for u in users)

或者更明确一点:
if any(u.id==i for u in users)

尽量避免使用User.__contains__


6
这似乎是一个你真正想要定义__eq__以接受对其他User对象和int的比较的情况。这将使包含User集合的检查自动工作,并且在一般用法中比在非容器类型上实现__contains__更有意义。
import sys
from operator import index

class User(object):  # Explicit inheritance from object can be removed for Py3-only code
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

    def __eq__(self, item):
        if isinstance(item, User):
            return self.id == item.id and self.name == item.name
        try:
            # Accept any int-like thing
            return self.id == index(item)
        except TypeError:
            return NotImplemented

    # Canonical mirror of __eq__ only needed on Py2; Py3 defines it implicitly
    if sys.version_info < (3,):
        def __ne__(self, other):
            equal = self.__eq__(other)
            return equal if equal is NotImplemented else not equal

    def __hash__(self):
        return self.id

现在你可以将你的类型与普通集合(包括set和dict键)一起使用,并且可以轻松查找。
from operator import attrgetter

# Use set for faster lookup; can sort for display when needed
users = {User(x, 'name_{}'.format(x)) for x in xrange(5)}
pprint(sorted(users, key=attrgetter('id')), indent=4)

for x in xrange(5):
    i = randint(2,6) 
    if i in users:
        print("User already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        users.add(User(i, 'new_user{}'.format(i)))
pprint(sorted(users, key=attrgetter('id')), indent=4)

这个user = {...}行不会失败吗,因为没有定义__getitem__ - NewGuy
@NewGuy: 为什么需要 __getitem__?这是序列和映射的特殊方法; User 都不是。要使对象成为 set 的成员(现在是集合推导而不是列表推导),它必须定义 __eq____hash__;除非您想创建自己的类似于 listdict 的类,否则不需要 __getitem__ - ShadowRanger

5
这是对__contains__的误用。你希望在类(如UserList)上实现__contains__
更好的方式是直接在生成器表达式或列表推导中访问id属性(而不使用in运算符)。例如:
class User(object):
    def __init__(self, id=None, name=None):
        self.id = id
        self.name = name

user = User(1, 'name_1')
assert 1 == user.id

user_list = [user, User(2, 'name_2')]
assert any(2 == u.id for u in user_list)

那么对于你的随机示例,你可以使用一个set或dictionary来存储已存在的用户的ID。

users = [User(x, 'name_{}'.format(x)) for x in xrange(5)]
ids = set(u.id for u in users)

for x in xrange(5):
    i = randint(2,6) 
    if i in ids:
        print("User id already exists: {}".format(i))
    else:
        print("{} is not in users list. Creating new user with id: {}".format(i, i))
        ids.add(i)
        users.append(User(i, 'new_user{}'.format(i)))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接