如何最好地断言numpy.array的相等性?

150

我想为我的应用程序编写一些单元测试,并需要比较两个数组。由于array.__eq__返回一个新数组(因此TestCase.assertEqual失败),那么断言相等的最佳方法是什么?

目前我正在使用:

self.assertTrue((arr1 == arr2).all())

但我并不是很喜欢它


4
注意,你的示例可能会意外地产生True,例如(np.array([1, 1]) == np.array([1])).all()将会返回True。 - M. Bernhardt
3
这段代码的意思是:确保数组array1和array2是相等的。 - Miszo97
9个回答

162

请查看numpy.testing中的assert函数,例如:

assert_array_equal

对于浮点数数组等式测试可能会失败,assert_almost_equal更加可靠。

更新

几个版本之前,numpy获得了assert_allclose,现在它是我最喜欢的函数,因为它允许我们指定绝对误差和相对误差,并且不需要十进制舍入作为紧密标准。


32
这与unittest如何交互?我认为在这方面简要说明一些内容会很有帮助。 - Heberto Mayorquin
我从未使用过unittest。然而,它与numpy、scipy和statsmodels一起非常好地运作。只需在测试函数或方法中使用asserts即可。 - Josef
这并不验证两个参数是否都是numpy数组。例如,它会在一个数组和一个列表上成功。为了测试,验证它们实际上是数组可能很有用,但我猜这需要手动检查类型? - max
不,它并不是这样的。我们只使用它来验证数字,一些函数返回标量、元组、numpy数组或pandas Series。当我们想要验证类或dtype时,我们会单独进行验证。Pandas还有一个assert函数,它也检查索引和列名是否符合预期,据我所知,它使用numpy assert来验证数值,可能使用assert_equal来验证其他数据类型。 - Josef
4
@RamonMartinez assert_allclose 与 unittest 兼容良好 :) - kotakotakota
18
如果您使用Python的unittest,您可以使用self.assertIsNone(np.testing.assert_array_equal(a, b)),因为它会在数组相等时返回None - mjkrause

34

我认为 (arr1 == arr2).all() 看起来很不错。但你也可以使用:

numpy.allclose(arr1, arr2)

但是它并不完全相同。

一个替代方案,几乎与您的示例相同:

numpy.alltrue(arr1 == arr2)

请注意,scipy.array实际上是一个引用numpy.array的对象。这使得查找文档变得更容易。


25

我发现使用self.assertEqual(arr1.tolist(), arr2.tolist())是使用unittest比较数组的最简单方法。

我同意它不是最美观的解决方案,可能也不是最快的,但它可能更符合你其余测试用例,你可以得到所有unittest错误描述,并且实现非常简单。


6
请注意,这种方法无法很好地处理np.nan,因为np.nan != np.nan,而self.assertEqual无法考虑到这一点。 - blacksite

8
self.assertTrue(np.array_equal(x, y, equal_nan=True))

equal_nan = True 如果你希望 np.nan == np.nan 返回 True

或者你可以使用numpy.allclose 来进行容差比较。


8
自 Python 3.2 起,您可以使用 assertSequenceEqual(array1.tolist(), array2.tolist())。这样做的好处是可以显示出两个数组中不同的项。

8
抱歉,当数组为float类型时,它的工作效果不太好。我们确实需要assertSequenceAlmostEqual - Grwlf

7
在我的测试中,我使用以下内容:
numpy.testing.assert_array_equal(arr1, arr2)

这是最好的,因为它提供了一个错误消息,指向错误出现的地方。 - tyrex

3

使用numpy

numpy.array_equal(a, b)

2

np.linalg.norm(arr1 - arr2) < 1e-6


4
请提供一些背景信息。 - Tobias Wilfert

0
使用Python 3.10.12内置的unittest模块可以很好地处理嵌套数组(深度相等)的测试。
self.assertEqual([
    ["1","0","1","1","0","1","1"]
], [
    ["1","0","1","1","0","1","x"]
])

并且它会打印一个友好的失败输出信息。
First differing element 0:
['1', '0', '1', '1', '0', '1', '1']
['1', '0', '1', '1', '0', '1', 'x']

- [['1', '0', '1', '1', '0', '1', '1']]
?                                  ^

+ [['1', '0', '1', '1', '0', '1', 'x']]
?   

根据你的问题,需要注意的是:如果你总是将指针与同一个数组进行比较(或者在原地修改数组然后将其与自身进行比较),结果每次都会返回true...这将是一个错误。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接