在我的单元测试中,我想检查两个数组是否相同。以下是简化示例:
a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
if np.all(a==b):
print 'arrays are equal'
这不起作用是因为nan != nan
。
应该如何继续?
在我的单元测试中,我想检查两个数组是否相同。以下是简化示例:
a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
if np.all(a==b):
print 'arrays are equal'
这不起作用是因为nan != nan
。
应该如何继续?
在numpy 1.19之前的版本中,对于不涉及单元测试的情况,这可能是最佳方法:
>>> ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
True
然而,现代版本提供了一个新的关键字参数equal_nan
,并使用array_equal
函数完美地实现了这一目的。
这是由flyingdutchman首先指出的;有关详细信息,请参见下面他的回答。
或者你可以使用numpy.testing.assert_equal
或numpy.testing.assert_array_equal
与try/except
一起使用:
In : import numpy as np
In : def nan_equal(a,b):
...: try:
...: np.testing.assert_equal(a,b)
...: except AssertionError:
...: return False
...: return True
In : a=np.array([1, 2, np.NaN])
In : b=np.array([1, 2, np.NaN])
In : nan_equal(a,b)
Out: True
In : a=np.array([1, 2, np.NaN])
In : b=np.array([3, 2, np.NaN])
In : nan_equal(a,b)
Out: False
编辑
因为您正在使用它进行单元测试,裸的 assert
(而不是将其包装以获取 True/False
)可能更自然。
np.testing.assert_equal(a,b)
,如果它引发异常,则测试失败(无错误),我甚至可以得到一个漂亮的打印输出,其中包含差异和不匹配。谢谢。 - saroelenumpy.testing.assert_*
没有遵循 Python assert
的相同语义。在普通的Python中,只有在 __debug__
为 True 时(即在没有使用 -O 标志运行脚本时)才会引发 AssertionError
异常,请参见 文档。因此,我强烈不建议将 AssertionErrors
包装为流程控制。当然,由于我们在测试套件中,最好的解决方案是保持 numpy.testing.assert 不变。 - Stefano Mnumpy.testing.assert_equal()
的文档并没有明确说明它认为 NaN 等于 NaN(而 numpy.testing.assert_array_equal()
则有):这个信息在其他地方有记录吗? - Eric O. Lebigotnan = nan
的情况?即使数组相同,包括dtype,我仍然得到一个“AssertionError: Arrays are not equal”的错误提示。 - thinwybknumpy.allclose()
方法,它允许在存在nan值时指定行为。那么你的示例将如下所示:a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])
if np.allclose(a, b, equal_nan=True):
print('arrays are equal')
那么将打印“arrays are equal
”。
您可以在此处找到相关文档。
bool
而不是引发AssertionError
。我需要这个来实现一个带有数组属性的类的__eq__(...)
。 - Bas Swinckelsrtol=0, atol=0
以避免它将接近的值视为相等的问题(由 @senderle 提到)。因此:np.allclose(a, b, equal_nan=True, rtol=0, atol=0)
。 - Claudeequal_nan
参数。
示例代码如下:a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
assert np.array_equal(a, b, equal_nan=True)
但需要注意的是,如果一个元素的数据类型是object
,那么这种方法将无法生效。不确定是否存在bug。
numpy.ma.all
或numpy.ma.allclose
:a=np.array([1, 2, np.NaN])
b=np.array([1, 2, np.NaN])
np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) #True
np.ma.masked_where(np.isnan(a), a)
,否则你将无法比较无限值。 - John Zwincka=np.array([1, 2, np.NaN])
和 b=np.array([1, np.NaN, 2])
进行了测试,它们显然不相等,但是 np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b))
仍然返回 True,所以如果你使用这种方法,请注意这一点。 - tavo补充一下@Luis Albert Centeno的回答,你可能更愿意使用以下方式:
np.allclose(a, b, rtol=0, atol=0, equal_nan=True)
rtol
和atol
控制相等测试的容差。简而言之,allclose()
返回:
all(abs(a - b) <= atol + rtol * abs(b))
True
。
is
进行恒等性测试:a is b
((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
在评估字符串列表时,它给我一些错误。
这更具有类型通用性:
def EQUAL(a,b):
return ((a == b) | ((a != a) & (b != b)))
对我来说,这个方法很好用:
a = numpy.array(float('nan'), 1, 2)
b = numpy.array(2, float('nan'), 2)
numpy.equal(a, b, where =
numpy.logical_not(numpy.logical_or(
numpy.isnan(a),
numpy.isnan(b)
))
).all()
PS. 当存在NaN时忽略比较
如果您在像单元测试这样的事情上这样做,那么您不太关心性能和所有类型的“正确”行为,您可以使用此来拥有适用于所有类型的数组而不仅仅是数字的东西:
a = np.array(['a', 'b', None])
b = np.array(['a', 'b', None])
assert list(a) == list(b)
将 ndarray
转换为 list
有时可以在某些测试中获得所需的行为。 (但不要在生产代码或较大的数组中使用此功能!)
a
和b
设置为np.array([1, np.nan])
。 - wjandrea
%timeit
进行测试,结果为23.7微秒对比1.01毫秒。 - AllanLRH