比较多个numpy数组

14

如何比较超过2个的numpy数组?

import numpy 
a = numpy.zeros((512,512,3),dtype=numpy.uint8)
b = numpy.zeros((512,512,3),dtype=numpy.uint8)
c = numpy.zeros((512,512,3),dtype=numpy.uint8)
if (a==b==c).all():
     pass

这会引发一个ValueError错误,而且我不想一次比较两个数组。


3
“我不想逐一比较数组” - 那么这是你最好的选择。 - user2357112
如果有许多数组,那么@user2357112这并不容易。 - Jayanth Reddy
6个回答

12

对于三个数组,你可以检查第一个和第二个数组之间对应元素的相等性,然后检查第二个和第三个数组之间对应元素的相等性,这将给我们两个布尔标量,最后查看这两个标量是否都为True,以得到最终标量输出,代码如下:

```python scalar_output = np.all(arr1 == arr2) and np.all(arr2 == arr3) ```
np.logical_and( (a==b).all(), (b==c).all() )

如需比较多个数组,可以将它们堆叠起来,在堆叠的轴上进行差分运算,并检查所有这些差分是否都等于零。如果是,则所有输入数组之间相等,否则不相等。实现代码如下:

L = [a,b,c]    # List of input arrays
out = (np.diff(np.vstack(L).reshape(len(L),-1),axis=0)==0).all()

1
重塑(reshape)函数中的 len(L),-1 参数是必需的吗?即使没有这个参数,它也能正常工作。 - Jayanth Reddy
@JayanthReddy 不是的,它不包含。该轴将包含与堆叠轴合并的输入数组的第一个轴。因此,我们需要使用该重塑将该轴分成两个部分。为了举例说明,请尝试使用a = np.random.randint(0,9,(4,5,3)); b = a.copy(); c = a.copy() - Divakar

8

如果有三个数组,你应该只比较其中两个:

if np.array_equal(a, b) and np.array_equal(b, c):
    do_whatever()

如果有多个数组,我们假设它们都合并成了一个大数组arrays。那么你可以这样做:

if np.all(arrays[:-1] == arrays[1:]):
    do_whatever()

@JayanthReddy:可能是因为你把“arrays”变成了列表或其他什么东西。 - user2357112

6

扩展之前的答案,我会使用itertools中的combinations方法来构造所有可能的对,然后对每个对进行比较。例如,如果我有三个数组并且想要确认它们都相等,我会使用:

from itertools import combinations

for pair in combinations([a, b, c], 2):
    assert np.array_equal(pair[0], pair[1])

1
构造所有的对并不是必要的(因为a=b和b=c意味着a=c)。一个简单的for循环一次比较两个数组会更加计算效率。 - Jayanth Reddy
3
没错!这明显是低效的。只是一种直观透明地将每个项目与其他项目比较的方法。 - Elsewhere
1
这个答案适用于 float 数组,你应该使用 allclose 而不是 array_equalallclose 不是传递的。@Elsewhere 也许你可以根据这个扩展你的答案,因为问题并没有明确要求比较相等,也没有限制为 int 数组。 - a_guest

1

一行代码解决方案:

arrays = [a, b, c]    
all([np.array_equal(a, b) for a, b in zip(arrays, arrays[1:])])

我们测试相邻数组之间的相等性。

虽然这段代码可能回答了问题,但建议添加一些说明它的作用。 - dan1st

1

支持不同形状和NaN的解决方案

与数组列表的第一个元素进行比较:

import numpy as np

a = np.arange(3)
b = np.arange(3)
c = np.arange(3)
d = np.arange(4)

lst_eq = [a, b, c]
lst_neq = [a, b, d]

def all_equal(lst):
    for arr in lst[1:]:
        if not np.array_equal(lst[0], arr, equal_nan=True):
            return False
    return True

print('all_equal(lst_eq)=', all_equal(lst_eq))
print('all_equal(lst_neq)=', all_equal(lst_neq))

output

all_equal(lst_eq)= True
all_equal(lst_neq)= False

对于形状相等且不支持nan的情况

将所有内容合并到一个数组中,沿着新轴计算绝对差异,并检查沿着新维度的最大元素是否等于0或低于某个阈值。这应该非常快。

import numpy as np

a = np.arange(3)
b = np.arange(3)
c = np.arange(3)
d = np.array([0, 1, 3])

lst_eq = [a, b, c]
lst_neq = [a, b, d]

def all_equal(lst, threshold = 0):
    arr = np.stack(lst, axis=0)

    return np.max(np.abs(np.diff(arr, axis=0))) <= threshold

print('all_equal(lst_eq)=', all_equal(lst_eq))
print('all_equal(lst_neq)=', all_equal(lst_neq))

output

all_equal(lst_eq)= True
all_equal(lst_neq)= False

0

这可能会起作用。

import numpy

x = np.random.rand(10)
arrays = [x for _ in range(10)]

print(np.allclose(arrays[:-1], arrays[1:]))  # True

arrays.append(np.random.rand(10))

print(np.allclose(arrays[:-1], arrays[1:]))  # False

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接