在列表中查找numpy数组的索引

18
import numpy as np
foo = [1, "hello", np.array([[1,2,3]]) ]

我期望

foo.index( np.array([[1,2,3]]) ) 

返回

2

但实际上我得到的是

ValueError: 数组中有多个元素的真值不明确。使用 a.any() 或 a.all()

有没有比我的当前解决方案更好的东西?它似乎效率低下。

def find_index_of_array(list, array):
    for i in range(len(list)):
        if np.all(list[i]==array):
            return i

find_index_of_array(foo, np.array([[1,2,3]]) )
# 2

1
非同质列表只是一个例子,还是你真的有一个包含许多不同类型的列表? - mgilson
1
@mgilson 只是我编造的例子。我正在处理一个等维度的NumPy数组列表。 - Lee88
你能否重构一下代码,使用 is 替代 == 进行比较吗? - Mad Physicist
@MadPhysicist -- 原来如果您使用相同的数组,numpy python _会_做正确的事情。 lst = [array]; lst.find(array) # 0. 这是因为is检查非常快(指针比较),并且由于在列表中搜索已经有引用的内容是相当常见的,所以Python在回退到==比较之前进行is比较。 - mgilson
相关链接:https://dev59.com/qYbca4cB1Zd3GeqPWHUV#27697254 和 https://dev59.com/Wm035IYBdhLWcg3wVebn - Alex Riley
显示剩余7条评论
6个回答

14
这里产生错误的原因显然是因为NumPy的ndarray重写了==,返回的是一个数组而不是布尔值。
据我所知,这里没有简单的解决方案。只要np.all(val == array)这一部分能够正常工作,以下方法就可以奏效。
next((i for i, val in enumerate(lst) if np.all(val == array)), -1)
无论那一位是否起作用都取决于数组中的其他元素以及它们是否可以与numpy数组进行比较。

请注意,这与list.index不同,后者在没有这样的项时会引发ValueError。但这是一个好的简单解决方案! - MSeifert
@MSeifert -- 是的。我想要一个更像str.find的API。如果你想要一个异常,那么你可以只去掉-1部分(只将生成器传递给next)。在这种情况下,如果没有找到,你会得到一个StopIteration - mgilson

3
这个怎么样?
arr = np.array([[1,2,3]])
foo = np.array([1, 'hello', arr], dtype=np.object)

# if foo array is of heterogeneous elements (str, int, array)
[idx for idx, el in enumerate(foo) if type(el) == type(arr)]

# if foo array has only numpy arrays in it
[idx for idx, el in enumerate(foo) if np.array_equal(el, arr)]

输出:

[2]

注意:即使 foo 是一个列表,这个方法仍然适用。我只是在这里把它作为一个 numpy 数组。

OP在评论中表示,真正的列表只会包含数组,因此这最多只是一个预处理步骤。 - Mad Physicist
1
从技术上讲,如果列表中只有一个数组,并且您事先知道该数组是您要查找的数组,则第一种方法才能可靠地工作。 - Mad Physicist
是的。OP的问题不是在问一个numpy数组的索引吗? - kmario23
不,OP的问题是如何让==比较器返回一个布尔值,以便他可以在numpy数组中找到正确的索引,就像你的第二个解决方案一样。 - Mad Physicist

2

为了提高性能,您可能只想处理输入列表中的NumPy数组。因此,在进入循环之前,我们可以进行类型检查并索引数组元素。

因此,实现如下:

def find_index_of_array_v2(list1, array1):
    idx = np.nonzero([type(i).__module__ == np.__name__ for i in list1])[0]
    for i in idx:
        if np.all(list1[i]==array1):
            return i

不幸的是,根据评论,OP的列表一开始就只包含numpy数组,因此这将比优化更加繁琐。 - Mad Physicist
@MadPhysicist 真的吗?我以为 OP 有一个样例 foo = [1, "hello", np.array([[1,2,3]]) ],这是一个混合的列表。我错过了那个提到“..列表只包含numpy数组”的地方吗? - Divakar
是的。这是问题的第三个评论。OP想要最通用的答案,这很有道理,这就是他使用人为构造的数组进行提问的原因,但不幸的是,这确实对你的答案产生了一些阻碍。 - Mad Physicist
@MadPhysicist 我想我会保留它,供将来可能有混合形式输入列表的读者使用。对他们可能有用。 - Divakar

2
这里的问题(您可能已经知道,但还是重复一下)是list.index的工作方式如下:
for idx, item in enumerate(your_list):
    if item == wanted_item:
        return idx

该行代码 if item == wanted_item 是问题所在,因为它会隐式地将 item == wanted_item 转换为布尔值。但是,numpy.ndarray(除非它是标量)会引发以下错误:ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

解决方案 1:适配器(薄包装)类

通常情况下,我会在需要使用 python 函数(如 list.index)时,在numpy.ndarray周围使用一个薄包装(适配器):

class ArrayWrapper(object):

    __slots__ = ["_array"]  # minimizes the memory footprint of the class.

    def __init__(self, array):
        self._array = array

    def __eq__(self, other_array):
        # array_equal also makes sure the shape is identical!
        # If you don't mind broadcasting you can also use
        # np.all(self._array == other_array)
        return np.array_equal(self._array, other_array)

    def __array__(self):
        # This makes sure that `np.asarray` works and quite fast.
        return self._array

    def __repr__(self):
        return repr(self._array)

这些薄包装比手动使用一些enumerate循环或推导式更昂贵,但您不必重新实现Python函数。假设列表仅包含numpy数组(否则您需要进行一些if ... else ... 检查):
list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]

在这一步之后,您可以在此列表上使用所有Python函数:
>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1

这些包装器不再是numpy数组,但它们是薄包装器,因此额外的列表非常小。因此,根据您的需求,您可以保留封装的列表和原始列表,并选择在哪个列表上执行操作,例如,您现在也可以使用list.count计算相同的数组:
>>> list_of_wrapped_arrays.count(np.ones((3)))
2

或者 list.remove:

>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]), 
 array([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]]), 
 array([ 1.,  1.,  1.])]

解决方案2:子类和{{link1:ndarray.view}}
该方法使用numpy.array的显式子类。它的优点是您可以获得所有内置的数组功能,并且只修改所请求的操作(即__eq__)。
class ArrayWrapper(np.ndarray):
    def __eq__(self, other_array):
        return np.array_equal(self, other_array)

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]

>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]

>>> view_list.index(np.array([2,2,2]))
1

再次以这种方式获取大多数列表方法:`list.remove`,`list.count`,除了`list.index`。
然而,如果某些操作隐式使用`__eq__`,则此方法可能会产生微妙的行为。您可以始终通过使用`np.asarray`或`.view(np.ndarray)`将其重新解释为普通的numpy数组:
>>> view_list[1]
ArrayWrapper([ 2.,  2.,  2.])

>>> view_list[1].view(np.ndarray)
array([ 2.,  2.,  2.])

>>> np.asarray(view_list[1])
array([ 2.,  2.,  2.])

替代方案:重写__bool__(或对于Python 2,__nonzero__

您可以重写__bool____nonzero__,而不是在__eq__方法中修复问题:

class ArrayWrapper(np.ndarray):
    # This could also be done in the adapter solution.
    def __bool__(self):
        return bool(np.all(self))

    __nonzero__ = __bool__

再次这将使得list.index按预期工作:
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1

但这肯定会改变更多的行为!例如:
>>> if ArrayWrapper([1,2,3]):
...     print('that was previously impossible!')
that was previously impossible!

我一直在寻找一种覆盖数组类本身的方法,但这对于现有的对象没有帮助。非常好的解决方案。 - Mad Physicist
@MadPhysicist 是的,你也可以使用ndarray.view和一个子类覆盖__eq__来工作。步骤仍然相同,您需要通过一次遍历来创建这些视图的列表,然后再应用list.index操作。 - MSeifert
是的,这样做的好处是您可以将子类列表用作唯一列表,而无需对周围代码进行进一步修改。 - Mad Physicist
@MSeifert 作为一个Python的初学者,这对我来说在教学方面非常有用。谢谢。 - Lee88

0
这应该可以完成工作:
[i for i,j in enumerate(foo) if j.__class__.__name__=='ndarray']
[2]

0

你可以使用视图来覆盖 equals 方法

import numpy as np

class Vector(np.ndarray):
    def __eq__(self, other: np.ndarray) -> bool:
        return np.array_equal(super(),other)

data=list(np.random.random((100,3)))
element=data[3]

print(data.index(element.view(Vector))) #prints 3
print(element.view(Vector) in data) #prints True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接