更快的方法根据阈值分割numpy数组

4
假设我有一个随机的numpy数组:
X = np.arange(1000)

以及一个阈值:

thresh = 50

我希望把X分成两个部分X_lX_r,使得X_l中的每个元素都小于或等于thresh,而X_r中的每个元素都大于thresh。然后将这两个分区传递给递归函数。
使用numpy创建一个布尔数组,并用它来对X进行分区:
Z = X <= thresh
X_l, X_r = X[Z == 0], X[Z == 1]
recursive_call(X_l, X_r)

这个步骤需要重复执行多次,有没有办法让它更快?是否可以避免在每次调用时创建分区的副本?


你的代码将小值(<= thresh)放入 X_l 中,而将大值放入 X_r 中。这是有意为之的吗? - askewchan
2个回答

7

X[~Z]X[Z==0]更快:

In [13]: import numpy as np

In [14]: X = np.random.random_integers(0, 1000, size=1000)

In [15]: thresh = 50

In [18]: Z = X <= thresh

In [19]: %timeit X_l, X_r = X[Z == 0], X[Z == 1]
10000 loops, best of 3: 23.9 us per loop

In [20]: %timeit X_l, X_r = X[~Z], X[Z]
100000 loops, best of 3: 16.4 us per loop

你是否已经进行过分析以确定这确实是你代码中的瓶颈?如果你的代码只有1%的时间用于执行此分割操作,那么无论你如何优化此操作,对整体性能的影响也不会超过1%。

与其优化此一操作,你可能更受益于重新思考你的算法或数据结构。如果这真的是瓶颈,那么你可能通过在C语言Cython中重新编写此 {{code}} 片段来实现更好的效果...

当你使用大小为1000的numpy数组时,使用Python列表/集合/字典可能更快,但必须要看情况。NumPy数组的速度优势有时只有在数组非常大时才会显现。你可能希望用纯Python重写你的代码,并用timeit对两个版本进行基准测试。

嗯,让我重新表述一下。NumPy 数组的大小并不是影响其速度快慢的关键因素。只是如果你创建了很多小的 NumPy 数组,那么拥有小的 NumPy 数组有时会成为这种情况的标志,而创建 NumPy 数组的速度比创建 Python 列表等数据结构要慢得多:
In [21]: %timeit np.array([])
100000 loops, best of 3: 4.31 us per loop

In [22]: %timeit []
10000000 loops, best of 3: 29.5 ns per loop

In [23]: 4310/295.
Out[23]: 14.610169491525424

此外,当你使用纯Python编码时,你可能更倾向于使用字典和集合,而这些在NumPy中没有直接的等价物。这可能会导致你使用一种更快的替代算法。

谢谢。这至少加快了速度。 - blueSurfer
小心!这会将更大的值赋给X_l,较小的值赋给X_r。我认为@bluenot20的意图是相反的,尽管问题本身也是相反的。 - askewchan
那个Cython链接已经过时了。那个技术在Cython 2.0中可能会被删除。你现在应该使用memoryviews了。请参考:memoryviews - Veedrac

1
你的数组是否总是已经排序好了?在你的例子中,你使用了arange,它是已经排序好的,因此不需要进行布尔索引,只需在适当位置将数组切成两半即可。这样可以避免使用“高级索引”,从而无需复制数组。
X = np.arange(0, 2*thresh)
i = X.searchsorted(thresh, side='right') # side='right' for `<=`
X_l, X_r = X[:i], X[i:]

这对于已排序的数组可以节省很多时间,但显然在其他情况下无法使用:
thresh = 500
X = np.arange(2*thresh)

%%timeit
i = X.searchsorted(thresh, side='right')
X_l, X_r = X[:i], X[i:]
100000 loops, best of 3: 5.16 µs per loop

%%timeit
Z = X <= thresh                         
X_l, X_r = X[Z], X[~Z]
100000 loops, best of 3: 12.1 µs per loop

你需要计时的是两种解决方案如何依赖于数组的大小。在切片的情况下,X[:i] 应该是 X 的一个“视图”,因此操作应该花费更多或更少的恒定时间。你只需要创建一个新的数组对象,然后指向原始向量的子集,实际上并没有复制任何数据。 - Bas Swinckels

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接