在unittest中比较(断言相等)包含numpy数组的两个复杂数据结构

29

我使用Python的unittest模块,想要检查两个复杂数据结构是否相等。这些对象可以是包含各种值的字典列表:数字、字符串、Python容器(列表/元组/字典)和numpy数组。后者是提出此问题的原因,因为我不能只做以下操作:


assert a == b
self.assertEqual(big_struct1, big_struct2)

因为它能够产生一个

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

我想我需要编写自己的相等性测试来解决这个问题,适用于任意结构。 我目前的想法是一个递归函数,它:

  • 尝试将 arg1 的当前“节点”直接与 arg2 的相应节点进行比较;
  • 如果没有引发异常,则继续(“终端”节点/叶子也在此处理);
  • 如果捕获到 ValueError ,则会向更深处移动,直到找到 numpy.array
  • 比较数组(例如像这样)。

看起来有点麻烦的是跟踪两个结构的“对应”节点,但也许 zip 就是我需要的全部。

问题是:是否有更好(更简单)的方法来解决这个问题? 也许 numpy 提供了一些工具来处理这个问题?如果没有其他建议,我将实现这个想法(除非我有更好的想法),并作为答案发布。

P.S. 我模糊地感觉到我可能见过一个解决这个问题的问题,但现在我找不到它了。

P.P.S. 另一种方法是遍历结构并将所有 numpy.array 转换为列表,但这是否更容易实现?对我来说似乎是一样的。


编辑:子类化 numpy.ndarray 听起来非常有前途,但显然我没有两个硬编码到测试中的比较。 不过,其中一个确实是硬编码的,所以我可以:

  • 使用自定义的 numpy.array 子类填充它;
  • 按照jterrace的答案中的内容将 isinstance(other,SaneEqualityArray)更改为 isinstance(other,np.ndarray)
  • 始终将其用作比较中的LHS。

我的问题是:

  1. 它会工作吗(我的意思是,对我来说听起来很好,但也许一些棘手的边缘情况处理不正确)? 我的自定义对象是否始终作为递归相等性检查中的LHS,如我所预期的那样?
  2. 同样地,是否有更好的方法(假设我至少得到具有真正numpy数组的结构之一)?

编辑2:我试过了,(看起来)工作正常的实现显示在这个答案中


我想象编写一个适用于任意数据结构的相等性测试可能会非常困难。这些数据结构真的没有固定的结构吗? - loopbackbee
@goncalopp 有好几个,相当复杂,而且理论上可能会发生变化。我不想依赖它,特别是因为即使我知道 X 的位置,我也不知道如何比较两个结构中除了 X 以外的所有内容。 - Lev Levitsky
个人而言,我会选择递归函数方法。不过,我会首先明确检查对象的type类型——如果你的数据结构较大,进行盲目比较可能是可行的,但如果出现ValueError错误,则会浪费时间,因为需要重新检查值。 - loopbackbee
@goncalopp 感谢您的建议。性能并不是关键问题,这只是为了测试目的。我更关心的是尽量减少实施和维护解决方案所需的工作量。 - Lev Levitsky
7个回答

14

虽然我有评论,但是会太长了...

有趣的事实是,你不能使用==来测试数组是否相同。建议您使用np.testing.assert_array_equal

  1. 该函数检查dtype、shape等等,
  2. 它不会因为这个很妙的数学问题 (float('nan') == float('nan')) == False 而失败(通常情况下,Python序列的==也会更有趣,因为它使用PyObject_RichCompareBool进行快速检查,但对于 NaNs 来说是不正确的);
  3. 另外还有assert_allclose,因为如果要进行实际计算,浮点数的相等性可能非常棘手,而通常希望相差很小,因为这些值可能取决于硬件或者由您对它们进行的处理而可能是随机的。

如果您需要像这样深度嵌套的内容,我几乎建议您尝试使用pickle进行序列化,但这过于严格(并且第3点当然完全错误),例如您的数组的内存布局不重要,但在其序列化中却很重要。


使用pickle进行序列化存在其自身的问题......了解numpy.testing实用程序的好处,但我仍然不确定如何在此处应用它们。 - Lev Levitsky

10

assertEqual函数将调用对象的__eq__方法,对于复杂数据类型应该进行递归。例外是numpy,它没有一个明智的__eq__方法。使用这个问题的numpy子类,你可以恢复相等行为的正常:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

这个测试通过了。


我宁愿去处理__nonzero__而不是搞乱__eq__。虽然numpy这样做有很好的理由,但这只会引发错误。 - seberg
@LevLevitsky,导致错误的不是__eq__,导致错误并且定义不明确的是__nonzero__(即bool(np.ndarray)),更改__nonzero__可能不会改变工作正常的程序,除非它依赖于抛出的错误。在我看来,这似乎是一个很大的优势... - seberg
@seberg 谢谢,但我没有计划在我测试的函数中更改任何内容。我仍然希望它们将 numpy.array 放入返回值中,而不是其自定义子类。这是针对您关于“更改工作程序”的担忧的回复。 - Lev Levitsky
顺便提一下,您传递给 SaneEqualityArray 构造函数的参数实际上被解释为其形状,而内容并未初始化。我也犯过这个错误:ndarray 不能使用 array 的签名创建。 - Lev Levitsky
我将修改后的代码作为单独的回答添加了进去,它也展示了创建实例的正确(或至少可行)方法。 - Lev Levitsky
显示剩余4条评论

7

所以,jterrace所阐述的想法对我来说似乎是可行的,只需稍作修改:

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

正如我所说,具有这些对象的容器应该在等式检查的左侧。我从现有的numpy.ndarray创建SaneEqualityArray对象,方法如下:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

根据ndarray构造函数的签名:
ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

这个类定义于测试套件内,只用于测试目的。等式检查的右侧是由被测试函数返回的实际对象,并包含真正的 numpy.ndarray 对象。
附言:非常感谢迄今发布的两个答案的作者,它们都非常有用。如果有人发现这种方法存在任何问题,我将非常感激您的反馈。

2

我会定义自己的assertNumpyArraysEqual()方法,明确地进行您想要使用的比较。这样,您的生产代码不会改变,但您仍然可以在单元测试中进行合理的断言。请确保在包含__unittest = True的模块中定义它,以便它不会出现在堆栈跟踪中:

import numpy
__unittest = True

def assertNumpyArraysEqual(self, other):
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other)
        raise AssertionError("Elements don't match!")

谢谢,这是一个好主意,如果两个数组都是被测试的函数生成的话,这将是最佳选择。 - Lev Levitsky

1
我遇到了相同的问题,并开发了一个函数来比较相等性,基于为对象创建一个固定的哈希值。这样做的额外优势是可以通过将其哈希与存储在代码中的固定哈希进行比较,测试对象是否符合预期。
代码(一个独立的Python文件,在这里)。有两个函数:fixed_hash_eq,解决了您的问题,compute_fixed_hash,从结构中生成哈希。测试在这里 下面是一个测试:
obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)

1

0
在 @dbw 的基础上(感谢),以下方法插入到测试用例子类中对我很有帮助:
 def assertNumpyArraysEqual(self,this,that,msg=''):
    '''
    modified from https://dev59.com/cGYq5IYBdhLWcg3wui_F#15399475
    '''
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this,that):
        raise AssertionError("Elements don't match!")

我在我的测试用例方法中将其称为self.assertNumpyArraysEqual(this,that),并且它的效果非常好。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接