numpy.testing.assert_array_equal无法比较两个相同的不规则数组。

5
我有两个NumPy数组,并且想要测试它们是否相等。
下面的代码可以正确地实现:
# this works
x = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
y = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
assert np.testing.assert_array_equal(x,y)

如果其中一个内部数组是不规则的,则比较会失败:
# this works
x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
y = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
np.testing.assert_array_equal(x,y)

Traceback (most recent call last):
  File "/home/.../test.py", line 12, in <module>
    np.testing.assert_array_equal(x,y)
  File "/home/.../lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 932, in assert_array_equal
    assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
  File "/home/.../lib/python3.9/site-packages/numpy/testing/_private/utils.py", line 842, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 1 (100%)
 x: array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
      dtype=object)
 y: array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
      dtype=object)

更新:

为了让故事更加晦涩,以下内容也可以正常工作:

x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
y = x
np.testing.assert_array_equal(x,y)

这是正确的行为吗?


1
在两种情况下显示 x==y(如果出现错误,则显示回溯信息)。 - hpaulj
1
如果你设置 x2 = x 然后运行 np.testing.assert_array_equal(x,x2),它会通过测试。但是,如果你重新初始化 x 为相同的不规则数组,则 np.testing.assert_array_equal(x,x2) 将失败。 - convolutionBoy
你是对的:我会更新问题。这真的很奇怪。 - Angelo
1个回答

0
在第一种情况下,数组是(2,2)(尽管是对象数据类型):
In [20]: x = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
    ...: y = np.array([np.array(['a', 'b']), np.array(['c', 'd'])], dtype='object')
In [21]: x
Out[21]: 
array([['a', 'b'],
       ['c', 'd']], dtype=object)
In [22]: x.shape
Out[22]: (2, 2)
In [23]: x==y
Out[23]: 
array([[ True,  True],
       [ True,  True]])

断言只需验证此比较的所有元素均为True

第二种情况:

In [24]: x = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
    ...: y = np.array([np.array(['a', 'b']), np.array(['c'])], dtype='object')
In [25]: x
Out[25]: 
array([array(['a', 'b'], dtype='<U1'), array(['c'], dtype='<U1')],
      dtype=object)
In [26]: x.shape
Out[26]: (2,)
In [27]: x==y
<ipython-input-27-051436df861e>:1: DeprecationWarning: elementwise comparison failed; 
 this will raise an error in the future.
  x==y
Out[27]: False

结果是一个标量,而不是一个(2,)数组。x==x产生True,带有相同的警告。

数组元素可以逐对比较:

In [30]: [i==j for i,j in zip(x,y)]
Out[30]: [array([ True,  True]), array([ True])]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接