查找列表中是否包含特定的numpy数组

5
import numpy as np

a = np.eye(2)
b = np.array([1,1],[0,1])

my_list = [a, b]

a in my_list 返回true,但b in my_list 返回 "ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"。我可以通过先将数组转换为字符串或列表来解决此问题,但是有没有更好(更符合Python风格)的方法?

2个回答

3
问题在于,在numpy中,==运算符返回一个数组:
>>> a == b
array([[ True, False],
       [ True,  True]], dtype=bool)

您可以使用 .array_equal()来将数组与纯布尔值进行比较。
>>> any(np.array_equal(a, x) for x in my_list)
True
>>> any(np.array_equal(b, x) for x in my_list)
True
>>> any(np.array_equal(np.array([a, a]), x) for x in my_list)
False
>>> any(np.array_equal(np.array([[0,0],[0,0]]), x) for x in my_list)
False

如果我检查第一个值(如上面的a in my_list)成功的原因是由于短路:如果数组中的每个元素在第一次检查时都为真,则不会检查列表的其余部分,这是它成功的原因吗? - Chris Midgley
@ChrisMidgley:是的,它是短路计算(顺便提一下,any()函数也是短路计算)。此外,如果所有元素都为True,则布尔值为True是明确无误的。但是,如果是True和False的混合,则NumPy无法决定隐式转换,从而引发错误。 - kennytm

-1

关于问题的更多信息。如果您使用以下方式形成my_list:

my_list = [b,a] 

失败的那个...是一个有趣的问题。


如果您想了解为什么,请查看/谷歌PyObject_RichCompareBool的文档。 - seberg

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接