NumPy数组比较的快捷评估

9
在numpy中,如果我想比较两个数组,比如说我想测试A中的所有元素是否都小于B中的值,我会使用if (A < B).all():。但实际上这需要分配和评估完整的数组C = A < B,然后在其上调用C.all()。这有点浪费时间。有没有任何方法可以“快捷”比较,即逐个元素地直接评估A < B(不需要分配和计算临时变量C),并在发现第一个无效元素比较时停止并返回False

1
这个“浪费”会明显影响性能吗? - mfitzp
在我的情况下影响不大。但是一般来说,分配额外的1000x1000x1000数组可能会损害您的性能。我只是出于好奇而问,我相信这种快捷评估会非常方便。有一个名为allclose()的函数,我相信它会在找到第一个不符合要求的元素时停止。所以我很惊讶我没有找到allequal(a,b)allless(a,b)allgreater(a,b),我相信它们都会在使用中非常受欢迎,因为if (A<B).all():和类似的模式相当普遍。 - HiFile.app - best file manager
在查看了numpy源代码后,我认为不支持在第一个“True”上进行A<B的快捷方式,而.all()使用布尔AND检查应该可以快捷方式。但这对内存分配没有帮助。 - mfitzp
http://docs.scipy.org/doc/numpy-1.10.1/user/c-info.how-to-extend.html 请继续;-) - Dima Tisnek
我刚刚查看了numpy源代码,发现我们有一个allequal(a,b)函数(部分更正我的先前评论),但它不会快速评估,这很遗憾。 - HiFile.app - best file manager
显示剩余6条评论
2个回答

1
你的数组有多大?我想它们非常大,例如A.shape = (1000000)或更大,性能才会成为问题。你是否考虑使用numpy 视图
与其比较(A < B).all()(A < B).any(),你可以尝试定义一个视图,例如(A[:10] < B[:10]).all()。这是一个可能有效的简单循环:
k = 0
while( (A[k*10: (k+1)*10] < B[k*10: (k+1)*10] ).all() ):
    k += 1

你可以使用100或10**3作为分段大小,而不是10。显然,如果您的分段大小为1,则意味着:

k = 0
while ( A[k] < B[k] ):
    k+= 1

有时,比较整个数组可能会占用大量内存。如果AB的长度为10000,并且我需要比较每一对元素,那么我将耗尽空间。

使用视图来分割大数组是一个非常聪明的想法,所以我会点赞。但这是一种可怕的黑客行为,我认为这应该自然地存在于numpy库中。 - HiFile.app - best file manager

1

普通的Python中,andor会使用快捷求值,但是numpy不会。

(A < B).all()

使用numpy构建块,广播,逐元素比较与<all约简。 <像其他二进制运算一样工作,加,乘,和,或,gt,le等。 all就像其他约简方法一样,anymaxsummean,可以对整个数组或按行或按列进行操作。

可能编写一个将all<组合成一个迭代的函数,但很难获得我刚才描述的通用性。

但是,如果您必须实现迭代解决方案,并进行快速的快捷操作,则建议使用nditer开发该想法,然后使用cython进行编译。

http://docs.scipy.org/doc/numpy/reference/arrays.nditer.html是一个关于如何使用nditer的好教程,它会带你学习如何在cython中使用它。 nditer负责广播和迭代,让您可以集中精力进行比较和任何快捷方式。

这里是一个草图,可以将其转换为cython迭代器:

import numpy as np

a = np.arange(4)[:,None]
b = np.arange(2,5)[None,:]
c = np.array(True)
it = np.nditer([a, b, c], flags=['reduce_ok'],
    op_flags = [['readonly'], ['readonly'],['readwrite']])
for x, y, z in it:
    z[...] = x<y
    if not z:
        print('>',x,y)
        break
    else:
        print(x,y)
print(z)

带有一个样例运行:
1420:~/mypy$ python stack34852272.py 
(array(0), array(2))
(array(0), array(3))
(array(0), array(4))
(array(1), array(2))
(array(1), array(3))
(array(1), array(4))
('>', array(2), array(2))
False

以默认值False开始,并使用不同的break条件,您将获得一个快捷的any。将测试泛化以处理<<=等需要更多的工作。
在Python中使其正常运行,然后尝试在Cython中运行。如果遇到问题,请提出新问题。SO拥有很多Cython用户。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接