从数组中选择最小的n个元素的最快方法是什么?

3

我正在愉快地编写一个使用numba编写的快速选择算法,并希望分享结果。

考虑数组x

np.random.seed([3,1415])
x = np.random.permutation(np.arange(10))
x

array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])

如何快速获取最小的n个元素。

我尝试了
np.partition

np.partition(x, 5)[:5]

array([0, 1, 2, 3, 4])

pd.Series.nsmallest

pd.Series(x).nsmallest(5).values

array([0, 1, 2, 3, 4])
2个回答

4

总的来说,我不建议尝试击败NumPy。对于长数组而言,很少有人能够竞争,更难得的是找到一个更快的实现方式。即使速度更快,也可能最多只比NumPy快2倍。所以这种情况很少值得一试。

然而,最近我自己尝试了类似的事情,因此我可以分享一些有趣的结果。

我并不是自己想出来的。我基于 numba (re-) 实现的np.median 的方法。他们可能知道他们在做什么。

我最终得到的结果是:

import numba as nb
import numpy as np

@nb.njit
def _partition(A, low, high):
    """copied from numba source code"""
    mid = (low + high) >> 1
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        if A[mid] < A[low]:
            A[low], A[mid] = A[mid], A[low]
    pivot = A[mid]

    A[high], A[mid] = A[mid], A[high]

    i = low
    for j in range(low, high):
        if A[j] <= pivot:
            A[i], A[j] = A[j], A[i]
            i += 1

    A[i], A[high] = A[high], A[i]
    return i

@nb.njit
def _select_lowest(arry, k, low, high):
    """copied from numba source code, slightly changed"""
    i = _partition(arry, low, high)
    while i != k:
        if i < k:
            low = i + 1
            i = _partition(arry, low, high)
        else:
            high = i - 1
            i = _partition(arry, low, high)
    return arry[:k]

@nb.njit
def _nlowest_inner(temp_arry, n, idx):
    """copied from numba source code, slightly changed"""
    low = 0
    high = n - 1
    return _select_lowest(temp_arry, idx, low, high)

@nb.njit
def nlowest(a, idx):
    """copied from numba source code, slightly changed"""
    temp_arry = a.flatten()  # does a copy! :)
    n = temp_arry.shape[0]
    return _nlowest_inner(temp_arry, n, idx)

在计时之前,我加入了一些热身调用。热身是为了不将编译时间算入计时中:

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

由于我的计算机速度较慢,我稍微改变了元素数量和重复次数。但结果似乎表明,我(或者说numba开发人员)已经击败了NumPy:

results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.set_value(i, j, timeit(stmt, setp, number=100))

print(results)

Method   nsmall_np nsmall_pd  nsmall_pir      nlowest
Size                                                 
100     0.00343059  0.561372  0.00190855  0.000935566
500     0.00428461   1.79398  0.00326862   0.00187225
1000    0.00560669   3.36844  0.00432595   0.00364284
5000     0.0132515  0.305471   0.0142569    0.0108995
10000    0.0255161  0.340215    0.024847    0.0248285
50000     0.105937  0.543337    0.150277     0.118294
100000      0.2452  0.835571    0.333697     0.248473
500000     1.75214   3.50201     2.20235      1.44085

enter image description here


你需要更改多少代码才能让它与 njit 协同工作? - piRSquared
1
_partition函数只是被简单地复制了,_select函数只在最后一行进行了更改(将arry[k]替换为arry[:k])。另外两个函数进行了更多的更改:我更改了函数名称,用新的idx参数替换了mid部分,并删除了处理偶数长度数组中位数的部分。nlowest函数最初是median_impl函数。此外,我将@register_jitable更改为@njit,并且不需要(也不想要)@overload。说实话,这条注释写起来可能比更改numba源代码还要花费更多时间。:D - MSeifert
是的,看了你分享的代码,他们似乎已经是 numba 的高级用户了。谢谢分享 :-) - piRSquared
那就是实际的“numba”源代码。希望他们是熟练的用户 :) - MSeifert
1
/摊手...我以为是numpy。我没有注意链接。 - piRSquared

2
更新
评论区的@user2357112指出我的函数进行了原地操作。结果发现,这就是我获得性能提升的地方。最终,我们使用numbaquickselect进行了简单的实现,性能也非常相似。虽然这个性能还算不错,但并不是我所期望的。

如我在问题中所说,我在尝试使用numba,想要分享一下我发现的东西。

请注意,我已导入njit而不是jit。这是一个装饰器,它自动防止自己回退到本机Python对象。这意味着当它执行速度优化时,它只会使用它可以真正提高速度的东西。这反过来意味着我的函数在我确定允许什么和不允许什么的时候会经常失败。

到目前为止,我认为使用numbajitnjit编写程序非常棘手和困难,但当你看到性能有所提升时,还是值得的。

这是我粗略实现的quickselect函数。

import numpy as np
from numba import njit
import pandas as pd
import numexpr as ne

@njit
def rselect(a, k):
    n = len(a)
    if n <= 1:
        return a
    elif k > n:
        return a
    else:
        p = np.random.randint(n)
        pivot = a[p]
        a[0], a[p] = a[p], a[0]
        i = j = 1
        while j < n:
            if a[j] < pivot:
                a[j], a[i] = a[i], a[j]
                i += 1
            j += 1
        a[i-1], a[0] = a[0], a[i-1]
        if i - 1 <= k <= i:
            return a[:k]
        elif k > i:
            return np.concatenate((a[:i], rselect(a[i:], k - i)))
        else:
            return rselect(a[:i-1], k)

你会注意到,它返回与问题中的方法相同的元素。
rselect(x, 5)

array([2, 1, 0, 3, 4])

关于速度呢?

def nsmall_np(x, n):
    return np.partition(x, n)[:n]

def nsmall_pd(x, n):
    pd.Series(x).nsmallest().values

def nsmall_pir(x, n):
    return rselect(x.copy(), n)


from timeit import timeit


results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method')
)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.set_value(
            i, j, timeit(stmt, setp, number=1000)
        )

results

Method   nsmall_np  nsmall_pd  nsmall_pir
Size                                     
100       0.003873   0.336693    0.002941
1000      0.007683   1.170193    0.011460
3000      0.016083   0.309765    0.029628
6000      0.050026   0.346420    0.059591
10000     0.106036   0.435710    0.092076
100000    1.064301   2.073206    0.936986
1000000  11.864195  27.447762   12.755983

results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6))

[1]: https://i.stack.imgur.com/hKo2o.png


2
你似乎在改变输入,而 numpy.partition 则会创建一个副本。你有测试过 ndarray.partition 方法的性能吗?它可以原地操作。 - user2357112
@user2357112和PooF,所有性能优势都消失了。谢谢...摆弄教会了我一些东西。 - piRSquared
你可能还想测试一下 numpy.partition 在复制数据时的表现。我认为目前大部分的分区运行都是针对已经分区的数据进行的,这可能会影响性能特征。 - user2357112
@user2357112 谢谢你的建议... 没有明显的区别。 - piRSquared
你知道为什么 pandas 处理小数组的速度比较慢吗?它们有没有针对小数组进行优化(比如排序)但是不起作用呢? - MSeifert
1
我的pandas版本是一个稻草人。它需要先创建一个系列对象,这会增加额外的开销。 - piRSquared

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接