比较NumPy数组以使NaN相等

22
有没有一种惯用的方法来比较两个NumPy数组,将NaN视为相互等同(但不等同于除NaN之外的任何其他值)?
例如,我希望以下两个数组相等:
np.array([1.0, np.NAN, 2.0])
np.array([1.0, np.NAN, 2.0])

同时比较以下两个数组,检查它们是否不相等:

np.array([1.0, np.NAN, 2.0])
np.array([1.0, 0.0, 2.0])
我正在寻找一种能够产生标量布尔结果的方法。 以下代码可以实现:
np.all((a == b) | (np.isnan(a) & np.isnan(b)))

但是这种方法不够优雅,并且会创建所有这些中间数组。

有没有更加简洁并更好地利用内存的方法呢?

另外,如果有帮助的话,这些数组已知具有相同的形状和数据类型。


1
@DanielRoseman:我理解。我有两种生成NumPy数组的方法,我需要知道它们是否生成了相同的数组。 - NPE
1
你已经排除了这个问题中的一个答案;你也要排除另外两个答案吗? - senderle
2
如果你正在使用numpy的当前git版本,则有一个numpy.isclose函数(https://github.com/numpy/numpy/blob/master/numpy/core/numeric.py#L2039),它带有一个`equal_nan`关键字参数(默认为`False`以实现兼容性)。但它不太适合内存。 - Joe Kington
2
如果没有相等但具有不同二进制表示的数字(例如0.0和-0.0),那么memoryview(a0)== memoryview(a1)就可以了。 - DSM
1
@DSM:谢谢你的回答,它确实符合我的使用情况。你介意把它写成一个答案吗? - NPE
显示剩余5条评论
4个回答

18

如果您真的关心内存使用(例如有非常大的数组),那么您应该使用numexpr,并且以下表达式将适用于您:

np.all(numexpr.evaluate('(a==b)|((a!=a)&(b!=b))'))

我已经在长度为3e8的非常大的数组上进行了测试,代码在我的机器上具有相同的性能

np.all(a==b)

并且使用相同数量的内存


9

顺便说一下,这不适用于字符串。比较数组和字符串会抛出以下错误:TypeError("ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''") - Ian

8

免责声明:我不建议经常使用此方法,我自己也不会使用它,但我可以想象在某些罕见情况下它可能有用。

如果数组具有相同的形状和数据类型,则可以考虑使用低级别的memoryview

>>> import numpy as np
>>> 
>>> a0 = np.array([1.0, np.NAN, 2.0])
>>> ac = a0 * (1+0j)
>>> b0 = np.array([1.0, np.NAN, 2.0])
>>> b1 = np.array([1.0, np.NAN, 2.0, np.NAN])
>>> c0 = np.array([1.0, 0.0, 2.0])
>>> 
>>> memoryview(a0)
<memory at 0x85ba1bc>
>>> memoryview(a0) == memoryview(a0)
True
>>> memoryview(a0) == memoryview(ac) # equal but different dtype
False
>>> memoryview(a0) == memoryview(b0) # hooray!
True
>>> memoryview(a0) == memoryview(b1)
False
>>> memoryview(a0) == memoryview(c0)
False

但要注意这样微妙的问题:

>>> zp = np.array([0.0])
>>> zm = -1*zp
>>> zp
array([ 0.])
>>> zm
array([-0.])
>>> zp == zm
array([ True], dtype=bool)
>>> memoryview(zp) == memoryview(zm)
False

这种情况发生是因为二进制表示不同,尽管它们比较相等(当然必须如此:这就是它知道打印负号的方式)。

>>> memoryview(zp)[0]
'\x00\x00\x00\x00\x00\x00\x00\x00'
>>> memoryview(zm)[0]
'\x00\x00\x00\x00\x00\x00\x00\x80'

好消息是,它会按照你所希望的方式进行短路:

In [47]: a0 = np.arange(10**7)*1.0
In [48]: a0[-1] = np.NAN    
In [49]: b0 = np.arange(10**7)*1.0    
In [50]: b0[-1] = np.NAN     
In [51]: timeit memoryview(a0) == memoryview(b0)
10 loops, best of 3: 31.7 ms per loop
In [52]: c0 = np.arange(10**7)*1.0    
In [53]: c0[0] = np.NAN   
In [54]: d0 = np.arange(10**7)*1.0    
In [55]: d0[0] = 0.0    
In [56]: timeit memoryview(c0) == memoryview(d0)
100000 loops, best of 3: 2.51 us per loop

作为比较:

In [57]: timeit np.all((a0 == b0) | (np.isnan(a0) & np.isnan(b0)))
1 loops, best of 3: 296 ms per loop
In [58]: timeit np.all((c0 == d0) | (np.isnan(c0) & np.isnan(d0)))
1 loops, best of 3: 284 ms per loop

这太棒了,感谢您抽出时间写下它。 - NPE
@aix:实际上我以前也需要过类似的东西(考虑nan相等),尽管性能和内存不是问题,所以我手动实现了它。也许值得提出一个功能请求。 - DSM

0

不确定这是否更好,但是有一个想法...

import numpy
class FloatOrNaN(numpy.float_):
    def __eq__(self, other):
        return (numpy.isnan(self) and numpy.isnan(other)) or super(FloatOrNaN,self).__eq__(other)

a = [1., np.nan, 2.]
one = numpy.array([FloatOrNaN(val) for val in a], dtype=object)
two = numpy.array([FloatOrNaN(val) for val in a], dtype=object)
print one == two   # yields  array([ True,  True,  True], dtype=bool)

这将丑陋的部分推到了dtype中,代价是内部工作变成了Python而不是C(Cython / etc可以解决此问题)。但是,它确实大大降低了内存成本。

尽管如此,仍然有点丑 :(


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接