NumPy - 1D数组的最快惰性字典序比较

4

我有两个 NumPy 1D 数组 ab

我如何对它们进行 字典序 比较?也就是说,应该像 Python 比较元组一样比较这些 1D 数组。

主要的问题在于这应该是惰性完成的,也就是说,函数应该在已知结果的最左侧出现时立即返回结果。

此外,我正在寻找 numpy 数组最快的解决方案。可能需要使用其他 numpy 函数来实现矢量化计算。

否则,非惰性简单实现可能如下:

i = np.flatnonzero((a < b) != (a > b))
print('a ' + ('==' if i.size == 0 else '<' if a[i[0]] < b[i[0]] else '>') + ' b')

或者使用简单但速度较慢的懒惰变体,因为它使用纯Python类型:

ta, tb = tuple(a), tuple(b)
print('a ' + ('<' if ta < tb else '==' if ta == tb else '>') + ' b')

另一种解决方案是使用 np.lexsort,但问题是它是否针对仅具有两列(两个1D数组)进行了优化,或者根本不懒惰?此外,问题是lexsort的结果可能不足以具有三种答案可能性</==/>,可能只足以告诉是否<=。另外,lexsort需要一些非懒惰的预处理,如np.stack和反转行顺序。
print('a ' + ('<=' if np.lexsort(np.stack((a, b), 1)[::-1])[0] == 0 else '>') + ' b')

但是它能在numpy中实现懒惰且快速吗?我需要懒惰的行为,因为1D数组可能非常大,但在大多数情况下比较结果非常接近开头。

2个回答

3
在Python中,您可以对压缩的列表进行迭代:
def lazy_compare(a, b):
    for x, y in zip(a, b):
        if x < y:
            return 'a < b'
        if x > y:
            return 'a > b'
    return 'a == b'

e.g.

print(lazy_compare(['a', 'b', 'c', 'd', 'e'], ['a', 'b', 'b', 'd', 'e']))
print(lazy_compare(['a', 'b', 'c', 'd', 'e'], ['a', 'b', 'c', 'd', 'f']))
print(lazy_compare(['a', 'b', 'c', 'd', 'e'], ['a', 'b', 'c', 'd', 'e']))

输出:

a > b
a < b
a == b

自从zip返回一个迭代器仅在使用时生成值,这是一种懒惰的方式,并且只要它找到结果就会返回结果,所以只有当两个列表相等时,才需要遍历整个列表。

评论不是用来进行长时间讨论的;此对话已被转移到聊天室(https://chat.stackoverflow.com/rooms/222543/discussion-on-answer-by-nick-numpy-fastest-lazy-lexicographical-comparing-of-1)。 - Samuel Liew

0

有人可能会猜想使用循环和索引数组可能比zip更快,但事实并非如此。

以这些定义为比较基础。

def lex_leq_zip(a, b):
    for x, y in zip(a, b):
        if x > y:
            return False
    return True

def lex_leq_index(x,y):
    for i in np.arange(x.size):
        if x[i] > y[i]:
            return False
    return True

然后我们扫描不同大小的数组以收集有关更改的数据:

for L in range(1,100000, 1000):
    for rep in range(10):
        x = np.random.random(size=L)
        y = np.random.random(size=L)
        z = timeit('lex_leq_zip(x,y)',
              globals={'lex_leq_zip':lex_leq_zip,
                       'x':x,
                       'y':y},
              number=1)
        i = timeit('lex_leq_index(x,y)',
              globals={'lex_leq_index':lex_leq_index,
                       'x':x,
                       'y':y},
              number=1)
        plt.scatter([L], [z], color='k')
        plt.scatter([L], [i], color='b')
plt.show()

放大结果图,我得到了这个: enter image description here

从上面的代码中回想一下,垂直轴是以秒为单位的时间,水平轴是数组的长度,蓝色因子是基于索引的实现,黑色因子是基于 zip 的实现。虽然我们考虑的是非常小的几分之一秒(在某些情况下可能很宝贵),但很明显 zip 的方法更快。

注:我还尝试在基于索引的实现上使用 Numba 的 @jit(nopython=True) 装饰器,但它显示出了类似的模式。

注:我还尝试在两种实现上都使用 NumPy 的 np.vectorize,但实际上都会导致与尝试索引数字有关的错误。


我使用了$\leq$而不是$<$,但结果应该是相似的。 - Galen

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接