如何检查一个包含多个numpy数组的列表中是否包含给定的测试数组?

4

我有一个numpy数组列表,比如说:

a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]

and I have a test array, say

b = np.random.rand(3, 3)

我想要检查 a 中是否包含 b。然而,我希望这个过程可以更简单易懂。
b in a 

抛出以下错误:

ValueError: 具有多个元素的数组的真值是不明确的。请使用 a.any() 或 a.all()

我想要的正确方式是什么?


你尝试过列表推导式吗? - Divakar
你说的列表推导式是什么意思?在我看来,列表推导式的意思类似于... - Guldam Kwak
[ a for a in some_iterable] - Guldam Kwak
在这个任务中,列表推导式的作用是什么? - Guldam Kwak
为什么不将a定义为一个3x3x3的数组? - Nils Werner
6个回答

5

您可以将 a 转换为形状为 (3, 3, 3) 的一维数组:

a = np.asarray(a)

然后将其与b进行比较(由于我们正在比较浮点数,因此应使用isclose()

np.all(np.isclose(a, b), axis=(1, 2))

例如:

a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
a = np.asarray(a)
b = a[1, ...]       # set b to some value we know will yield True

np.all(np.isclose(a, b), axis=(1, 2))
# array([False,  True, False])

1

正如@jotasi所指出的那样,由于数组内的逐元素比较,真值是不确定的。此前有一个关于这个问题的答案在这里。总体而言,您可以通过以下各种方式完成任务:

  1. 列表转数组:

您可以将列表转换为(3,3,3)形状的数组,然后使用“in”运算符,如下所示:

    >>> a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]
    >>> a= np.asarray(a)
    >>> b= a[1].copy()
    >>> b in a
    True
  1. np.all:

    >>> any(np.all((b==a),axis=(1,2)))
    True
    
  2. list-comperhension: This done by iterating over each array:

    >>> any([(b == a_s).all() for a_s in a])
    True
    
以下是上述三种方法的速度比较:

速度比较

import numpy as np
import perfplot

perfplot.show(
    setup=lambda n: np.asarray([np.random.rand(3*3).reshape(3,3) for i in range(n)]),
    kernels=[
        lambda a: a[-1] in a,
        lambda a: any(np.all((a[-1]==a),axis=(1,2))),
        lambda a: any([(a[-1] == a_s).all() for a_s in a])
        ],
    labels=[
        'in', 'np.all', 'list_comperhension'
        ],
    n_range=[2**k for k in range(1,20)],
    xlabel='Array size',
    logx=True,
    logy=True,
    )

0
好的,所以in不起作用,因为它实际上正在执行
def in_(obj, iterable):
    for elem in iterable:
        if obj == elem:
            return True
    return False

现在的问题是,对于两个ndarrays aba == b 是一个数组(试一下),而不是布尔值,所以 if a == b 会失败。解决方案是定义一个新函数。
def array_in(arr, list_of_arr):
     for elem in list_of_arr:
        if (arr == elem).all():
            return True
     return False

a = [np.arange(5)] * 3
b = np.ones(5)

array_in(b, a) # --> False

0
这个错误是因为如果abnumpy数组,那么a == b不会返回TrueFalse,而是在逐个比较ab的元素后返回boolean值的array
你可以尝试像这样做:
np.any([np.all(a_s == b) for a_s in a])
  • [np.all(a_s == b) for a_s in a] 在这里,您正在创建一个布尔值列表,遍历 a 的元素,并检查所有元素是否与 ba 的特定元素相同。

  • 使用 np.any 可以检查数组中是否有任何元素为 True


0

正如这个答案所指出的,文档说明:

对于诸如列表、元组、集合、冻结集、字典或collections.deque等容器类型,表达式x in y等价于any(x is e or x == e for e in y)。

a[0]==b是一个数组,其中包含a[0]b的逐元素比较。这个数组的整体真值显然是模棱两可的。如果所有元素匹配,它们是否相同?或者大多数匹配,或者至少有一个匹配?因此,numpy强制你明确你的意思。你想知道的是测试所有元素是否相同。你可以使用numpyall方法来实现:

any((b is e) or (b == e).all() for e in a)

或者放在一个函数中:

def numpy_in(arrayToTest, listOfArrays):
    return any((arrayToTest is e) or (arrayToTest == e).all()
               for e in listOfArrays)

0
使用numpy中的array_equal函数。
    import numpy as np
    a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
    b = np.random.rand(3,3)

    for i in a:
        if np.array_equal(b,i):
            print("yes")

实际上,(a==b).all() 不比 np.array_equal(a, b) 慢。主要的区别是 np.array_equal 首先会测试数组的形状。 - jotasi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接