比较包含NaN的numpy数组

100

在我的单元测试中,我想检查两个数组是否相同。以下是简化示例:

a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])

if np.all(a==b):
    print 'arrays are equal'

这不起作用是因为nan != nan。 应该如何继续?

10个回答

63

在numpy 1.19之前的版本中,对于不涉及单元测试的情况,这可能是最佳方法:

>>> ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
True

然而,现代版本提供了一个新的关键字参数equal_nan,并使用array_equal函数完美地实现了这一目的。

这是由flyingdutchman首先指出的;有关详细信息,请参见下面他的回答


+1 这个解决方案似乎比我用掩码数组发布的解决方案要快一些,尽管如果您正在创建用于代码其他部分的掩码,则从创建掩码中产生的开销将在ma策略的整体效率中变得不那么重要。 - JoshAdel
谢谢。你的解决方案确实有效,但我更喜欢Avaris建议的numpy内置测试。 - saroele
1
我真的很喜欢这个简单易懂的解决方案。而且,它似乎比@Avaris的解决方案更快。将其转换为lambda函数,并使用Ipython的%timeit进行测试,结果为23.7微秒对比1.01毫秒。 - AllanLRH
@NovicePhysicist,时间选择得很有趣!我想知道这是否与异常处理有关。你测试过正面和负面结果吗?速度可能会因为是否抛出异常而有显著差异。 - senderle
不,我只是进行了一个简单的测试,涉及到我的问题。比较了2D数组和1D向量 - 所以我猜它是按行比较的。但我想在Ipython笔记本中可以很容易地进行大量测试。此外,我为您的解决方案使用了lambda函数,但我认为如果我使用常规函数(通常情况下似乎是这样),它应该会更快一些。 - AllanLRH

50

或者你可以使用numpy.testing.assert_equalnumpy.testing.assert_array_equaltry/except一起使用:

In : import numpy as np

In : def nan_equal(a,b):
...:     try:
...:         np.testing.assert_equal(a,b)
...:     except AssertionError:
...:         return False
...:     return True

In : a=np.array([1, 2, np.NaN])

In : b=np.array([1, 2, np.NaN])

In : nan_equal(a,b)
Out: True

In : a=np.array([1, 2, np.NaN])

In : b=np.array([3, 2, np.NaN])

In : nan_equal(a,b)
Out: False

编辑

因为您正在使用它进行单元测试,裸的 assert(而不是将其包装以获取 True/False)可能更自然。


非常好,这是最优雅和内置的解决方案。我只需在我的单元测试中添加np.testing.assert_equal(a,b),如果它引发异常,则测试失败(无错误),我甚至可以得到一个漂亮的打印输出,其中包含差异和不匹配。谢谢。 - saroele
4
请注意,这种解决方案之所以有效,是因为 numpy.testing.assert_* 没有遵循 Python assert 的相同语义。在普通的Python中,只有在 __debug__ 为 True 时(即在没有使用 -O 标志运行脚本时)才会引发 AssertionError 异常,请参见 文档。因此,我强烈不建议将 AssertionErrors 包装为流程控制。当然,由于我们在测试套件中,最好的解决方案是保持 numpy.testing.assert 不变。 - Stefano M
numpy.testing.assert_equal() 的文档并没有明确说明它认为 NaN 等于 NaN(而 numpy.testing.assert_array_equal() 则有):这个信息在其他地方有记录吗? - Eric O. Lebigot
@EricOLebigot numpy.testing.assert_equal()函数是否考虑nan = nan的情况?即使数组相同,包括dtype,我仍然得到一个“AssertionError: Arrays are not equal”的错误提示。 - thinwybk
当前的官方文档和上面的示例都表明它确实认为 NaN == NaN。我认为最好的方法是您提出一个新的 StackOverflow 问题并提供详细信息。 - Eric O. Lebigot

49
最简单的方法是使用numpy.allclose()方法,它允许在存在nan值时指定行为。那么你的示例将如下所示:
a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])

if np.allclose(a, b, equal_nan=True):
    print('arrays are equal')

那么将打印“arrays are equal”。

您可以在此处找到相关文档。


2
+1 是因为你的解决方案没有重复造轮子。但是,这只适用于类似数字的项目。否则,你会遇到令人讨厌的“TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''”错误。 - MLguy
这是许多情境下的好答案!值得注意的是,即使数组不严格相等,它也会返回true。但大多数情况下都不会有影响。 - senderle
1
+1,因为它返回一个bool而不是引发AssertionError。我需要这个来实现一个带有数组属性的类的__eq__(...) - Bas Swinckels
2
仅作为后续答案的指针:https://dev59.com/x2gv5IYBdhLWcg3wSe62#58709110。添加 rtol=0, atol=0 以避免它将接近的值视为相等的问题(由 @senderle 提到)。因此:np.allclose(a, b, equal_nan=True, rtol=0, atol=0) - Claude

18
numpy的array_equal函数非常适合该问题,可以使用1.19版本中添加的equal_nan参数。 示例代码如下:
a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
assert np.array_equal(a, b, equal_nan=True)

但需要注意的是,如果一个元素的数据类型是object,那么这种方法将无法生效。不确定是否存在bug


10
你可以使用numpy的掩码数组,掩盖NaN值,然后使用numpy.ma.allnumpy.ma.allclose
例如:
a=np.array([1, 2, np.NaN])
b=np.array([1, 2, np.NaN])
np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) #True

2
感谢让我了解到掩码数组的用法。不过,我更喜欢 Avaris 的解决方案。 - saroele
你应该使用 np.ma.masked_where(np.isnan(a), a),否则你将无法比较无限值。 - John Zwinck
5
我用 a=np.array([1, 2, np.NaN])b=np.array([1, np.NaN, 2]) 进行了测试,它们显然不相等,但是 np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) 仍然返回 True,所以如果你使用这种方法,请注意这一点。 - tavo
1
该方法仅测试两个没有NaN值的数组是否相同,但不测试NaN是否出现在相同的位置... 使用时可能存在危险。 - WillZ
使用它可能是危险的,这是一个有道理的观点。然而...这是在所有提到的建议中对我有效的唯一解决方案。如果您想比较可能以不同方式掩盖但包含基本相同信息的数据,则这是一个不错的方法。 - jpolly

8

补充一下@Luis Albert Centeno的回答,你可能更愿意使用以下方式:

np.allclose(a, b, rtol=0, atol=0, equal_nan=True)

rtolatol控制相等测试的容差。简而言之,allclose()返回:

all(abs(a - b) <= atol + rtol * abs(b))

默认情况下它们不会被设置为0,因此如果您的数字接近但不完全相等,则该函数可能返回True
注: "我想检查两个数组是否相同" >> 实际上,你要寻找的是相等性而不是恒等性。它们在Python中不同,我认为让每个人都理解这种差别以便分享相同的专业词汇更好。(https://www.blog.pythonlibrary.org/2017/02/28/python-101-equality-vs-identity/
你可以通过关键字is进行恒等性测试:
a is b

7
当我使用上述答案时:
 ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()

在评估字符串列表时,它给我一些错误。

这更具有类型通用性:

def EQUAL(a,b):
    return ((a == b) | ((a != a) & (b != b)))

2
截至v1.19,numpy的array_equal函数支持equal_nan参数:
assert np.array_equal(a, b, equal_nan=True)

flyingdutchman已经发布了这个。我只是为了完整性添加了版本号。(顺便修复了你的答案中的版本号) - wjandrea

0

对我来说,这个方法很好用:

a = numpy.array(float('nan'), 1, 2)
b = numpy.array(2, float('nan'), 2)
numpy.equal(a, b, where = 
    numpy.logical_not(numpy.logical_or(
        numpy.isnan(a), 
        numpy.isnan(b)
    ))
).all()

PS. 当存在NaN时忽略比较


-1

如果您在像单元测试这样的事情上这样做,那么您不太关心性能和所有类型的“正确”行为,您可以使用此来拥有适用于所有类型的数组而不仅仅是数字的东西

a = np.array(['a', 'b', None])
b = np.array(['a', 'b', None])
assert list(a) == list(b)

ndarray 转换为 list 有时可以在某些测试中获得所需的行为。 (但不要在生产代码或较大的数组中使用此功能!)


这对数字实际上不起作用。例如,尝试将ab设置为np.array([1, np.nan]) - wjandrea

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接