这里的问题(您可能已经知道,但还是重复一下)是
list.index
的工作方式如下:
for idx, item in enumerate(your_list):
if item == wanted_item:
return idx
该行代码
if item == wanted_item
是问题所在,因为它会隐式地将
item == wanted_item
转换为布尔值。但是,
numpy.ndarray
(除非它是标量)会引发以下错误:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
解决方案 1:适配器(薄包装)类
通常情况下,我会在需要使用 python 函数(如 list.index
)时,在numpy.ndarray
周围使用一个薄包装(适配器):
class ArrayWrapper(object):
__slots__ = ["_array"]
def __init__(self, array):
self._array = array
def __eq__(self, other_array):
return np.array_equal(self._array, other_array)
def __array__(self):
return self._array
def __repr__(self):
return repr(self._array)
这些薄包装比手动使用一些
enumerate
循环或推导式更昂贵,但您不必重新实现Python函数。假设列表仅包含numpy数组(否则您需要进行一些
if ... else ...
检查):
list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]
在这一步之后,您可以在此列表上使用所有Python函数:
>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1
这些包装器不再是numpy数组,但它们是薄包装器,因此额外的列表非常小。因此,根据您的需求,您可以保留封装的列表和原始列表,并选择在哪个列表上执行操作,例如,您现在也可以使用
list.count
计算相同的数组:
>>> list_of_wrapped_arrays.count(np.ones((3)))
2
或者
list.remove
:
>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]]),
array([[ 2., 2., 2.],
[ 2., 2., 2.],
[ 2., 2., 2.]]),
array([ 1., 1., 1.])]
解决方案2:子类和{{link1:
ndarray.view
}}
该方法使用
numpy.array
的显式子类。它的优点是您可以获得所有内置的数组功能,并且只修改所请求的操作(即
__eq__
)。
class ArrayWrapper(np.ndarray):
def __eq__(self, other_array):
return np.array_equal(self, other_array)
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1
再次以这种方式获取大多数列表方法:`list.remove`,`list.count`,除了`list.index`。
然而,如果某些操作隐式使用`__eq__`,则此方法可能会产生微妙的行为。您可以始终通过使用`np.asarray`或`.view(np.ndarray)`将其重新解释为普通的numpy数组:
>>> view_list[1]
ArrayWrapper([ 2., 2., 2.])
>>> view_list[1].view(np.ndarray)
array([ 2., 2., 2.])
>>> np.asarray(view_list[1])
array([ 2., 2., 2.])
替代方案:重写__bool__
(或对于Python 2,__nonzero__
)
您可以重写__bool__
或__nonzero__
,而不是在__eq__
方法中修复问题:
class ArrayWrapper(np.ndarray):
def __bool__(self):
return bool(np.all(self))
__nonzero__ = __bool__
再次这将使得
list.index
按预期工作:
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1
但这肯定会改变更多的行为!例如:
>>> if ArrayWrapper([1,2,3]):
... print('that was previously impossible!')
that was previously impossible!
is
替代==
进行比较吗? - Mad Physicistlst = [array]; lst.find(array) # 0
. 这是因为is
检查非常快(指针比较),并且由于在列表中搜索已经有引用的内容是相当常见的,所以Python在回退到==
比较之前进行is
比较。 - mgilson