如何将一个数组与一系列数组进行比较?

5
假设我有一个列表,其中包含许多numpy ndarrays(甚至是torch张量):
a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = [a, b, c]

现在如果我想检查数组b是否在collection中(假设我不知道哪些数组存在于collection中),那么尝试:b in collection会抛出以下错误:

ValueError: 数组的真实值是模棱两可的,使用a.any()或a.all()

对包含数组的元组也是同样的情况。

解决这个问题的一种方法是使用列表推导式:

True in [(b == x).all() for x in collection]

但是这需要用到一个for循环,我想知道是否有更加“高效”的方法来完成这个任务?

3个回答

3

我会一直使用numpy数组:

import numpy as np
a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
array = np.dstack([a, b, c])
print(array.shape)
# (3, 3, 3)
np.all(array == b, axis=1).all(axis=1).any()
# True

非常好的解决方案 :) - anon01
@Paul H 最后一行给了我False。也许应该使用np.stack()而不是np.dstack()? - Omar AlSuwaidi

1
您可以在numpy数组中沿着axis=0叠加任意形状的张量,然后使用np.all一次性比较所有剩余轴(这只是PaulH答案的稍微清晰版本):
ugly_shaped_tensor_list = [np.random.rand(3,7,5,3) for j in range(10)]
known_tensor = ugly_shaped_tensor_list[1]

# stack all tensors in a single array along axis=0:
tensor_stack = np.stack(ugly_shaped_tensor_list)

# compare all axes except the "list" axis, 0:
matches = np.all(tensor_stack == known_tensor, axis=(1,2,3,4))
# array([False,  True, False, False, False, False, False, False, False, False])
matches.any()
# True

请问您能否解释一下关键字“axis”具体是指什么?如果将其设置为2,那意味着什么,它到底在比较什么? - Omar AlSuwaidi
1
把它们看作张量的笛卡尔坐标系标签。对于矩阵(嵌套数组),轴0 == 行值,轴1 == 列值。 - anon01

1

好的,这比预期的简单得多...

你可以直接将数组/张量 堆叠 成一个 更高的 维度(在本例中是深度/通道),然后结果就是一个数组,其中包含所有其他数组,但独立存储在“不同的维度”中。

a, b, c = np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)
collection = np.stack((a, b, c))

现在,您可以像将其与列表进行比较一样,在collection中简单地检查b
b in collection
> True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接