Numpy第一个出现大于已有值的数值

224

我在numpy中有一个一维数组,想要找到某个值超过numpy数组中的值的索引位置。

例如:

aa = range(-10,10)

查找在 aa 中,第一次出现大于 5 的数的位置。


3
应清楚地了解是否可能不存在解决方案(例如,在这种情况下,argmax答案将无法起作用(即(0,0,0,0)的最大值= 0),正如ambrus所指出的那样)。 - seanv507
1
我同意这一点,并且我已经包括了一个答案(即使有一个被接受的答案,我认为仍然存在歧义)。我认为代码的正确性比仅仅性能更重要。 - Eduardo Gomes
8个回答

291

这会快一点(而且看起来更好)

np.argmax(aa>5)

argmax 会在第一个True处停止("In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned."),并且不会保存另一个列表。

In [2]: N = 10000

In [3]: aa = np.arange(-N,N)

In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop

In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop

In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop

169
请注意:如果输入数组中没有True值,np.argmax会返回0(这不是你想要的结果)。 - ambrus
20
结果是正确的,但我觉得解释有点可疑。argmax 似乎不会在第一个 True 处停止。(可以通过创建只有一个 True 的布尔数组来测试这一点。)速度可能是因为 argmax 不需要创建输出列表所解释的。 - DrV
2
我认为你是对的,@DrV。我的解释是关于为什么它能够给出正确的结果,尽管最初的意图实际上并不是寻求最大值,而不是为什么它更快,因为我不能声称理解argmax的内部细节。 - askewchan
8
@DrV,我刚刚在使用NumPy 1.11.2时对一百万个元素的布尔数组运行了argmax函数,这些数组在不同位置只有一个True,但是True所在的位置会影响结果。因此,似乎1.11.2版本的argmax在处理布尔数组时存在"短路"现象。 - Ulrich Stern
6
我重复了 @UlrichStern 的实验,使用一个有 2^30 个元素的数组(首先用1填充数组,然后再用0填充,最后添加一个真值以消除空页面技巧、页面错误噪声等)。当仅有一个真值在数组开头时,np.argmax 的速度比在数组末尾时快了 1e5 倍。这是在使用 numpy 1.16.5 版本的情况下得出的。 - Mr Fooz
显示剩余5条评论

125

如果你的数组已经排序,那么有一种更快的方法:searchsorted

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]

# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop

30
假设数组已经排序(实际上问题并没有明确说明),这确实是最佳答案。您可以使用np.searchsorted(..., side='right')避免笨拙的+1 - askewchan
3
如果排序数组中存在重复值,我认为side参数才会起到作用。它不会改变返回的索引的含义,该索引始终是插入查询值并将所有后续条目向右移动以保持有序数组的索引位置。请记住,不要改变原文意思,并尽可能使翻译通俗易懂。 - Gus
2
@Gus,“side”在排序和插入数组中都有相同的值时会产生影响,无论其中任何一个重复值。排序数组中的重复值只是夸大了这种影响(两侧之间的差异是要插入的值在排序数组中出现的次数)。 “side”确实改变了返回索引的含义,但它不会改变将值插入到这些索引处的排序数组的结果。这是微妙但重要的区别;事实上,如果“N/2”不在“aa”中,则此答案会给出错误的索引。 - askewchan
1
正如上面的评论所暗示的那样,如果N/2不在aa中,则此答案会偏差1。正确的形式应该是np.searchsorted(aa, N/2, side='right')(没有+1)。否则,两种形式都会给出相同的索引。考虑N为奇数的测试用例(如果使用python 2,请将N/2.0强制转换为浮点数)。 - askewchan

40

我对此也很感兴趣,我已经将所有建议的答案与 perfplot 进行了比较。(免责声明:我是 perfplot 的作者。)

如果您知道您正在查找的数组 已经排序,那么

numpy.searchsorted(a, alpha)

这对你来说非常有用。它是O(log(n))操作,也就是说,速度几乎不取决于数组的大小。你无法比这更快。

如果您不知道有关数组的任何信息,则不会出错

numpy.argmax(a > alpha)

已排序:

这里输入图片描述

未排序:

这里输入图片描述

生成图表的代码:

import numpy
import perfplot


alpha = 0.5
numpy.random.seed(0)


def argmax(data):
    return numpy.argmax(data > alpha)


def where(data):
    return numpy.where(data > alpha)[0][0]


def nonzero(data):
    return numpy.nonzero(data > alpha)[0][0]


def searchsorted(data):
    return numpy.searchsorted(data, alpha)


perfplot.save(
    "out.png",
    # setup=numpy.random.rand,
    setup=lambda n: numpy.sort(numpy.random.rand(n)),
    kernels=[argmax, where, nonzero, searchsorted],
    n_range=[2 ** k for k in range(2, 23)],
    xlabel="len(array)",
)

4
np.searchsorted不是常数时间复杂度,实际上是O(log(n))。但你的测试用例实际上对searchsorted的最佳情况进行了基准测试(这是O(1))。 - MSeifert
@MSeifert 你需要什么样的输入数组/alpha才能看到O(log(n))? - Nico Schlömer
1
获取索引为sqrt(length)的项会导致非常糟糕的性能。我还在这里写了一个答案(https://dev59.com/pmQo5IYBdhLWcg3wMs-2#49927020),其中包括基准测试。 - MSeifert
我怀疑searchsorted(或任何算法)能否击败二分查找对于已排序的均匀分布数据的O(log(n))。编辑:searchsorted 一个二分查找。 - Mateen Ulhaq
如果你知道均匀分布,你可以用O(1)击败二分查找。如果我有0-1000之间的单调数字,并且你想找到值748,你可以去第784个位置。这是一个排序的均匀分布数据集,有一种算法可以击败它。 - Tatarize

19
In [34]: a=np.arange(-10,10)

In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)

In [37]: np.where(a>5)[0][0]
Out[37]: 16

17

元素之间有恒定步长的数组

对于一个range或任何其他线性递增的数组,您可以通过程序计算索引,而无需实际迭代整个数组:

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('no value greater than {}'.format(val))
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    # For linearly decreasing arrays or constant arrays we only need to check
    # the first element, because if that does not satisfy the condition
    # no other element will.
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

这个可能还有改进的余地。我已经确保它对一些示例数组和值进行了正确的处理,但这并不意味着里面没有错误,特别是考虑到它使用了浮点数...

>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16]  # double check
6

>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15

考虑到它可以在不进行任何迭代的情况下计算位置,因此时间复杂度为常数时间 (O(1)),可能会超过所有其他提到的方法。但是,它需要数组中的一个恒定步长,否则将产生错误结果。

使用numba的通用解决方案

一种更通用的方法是使用numba函数:

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

那种方法适用于任何数组,但它必须遍历整个数组,因此在平均情况下它的时间复杂度是 O(n)

>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16

基准测试

尽管Nico Schlömer已经提供了一些基准测试,但我认为包括我的新解决方案并测试不同的“值”可能会很有用。

测试设置:

import numpy as np
import math
import numba as nb

def first_index_using_argmax(val, arr):
    return np.argmax(arr > val)

def first_index_using_where(val, arr):
    return np.where(arr > val)[0][0]

def first_index_using_nonzero(val, arr):
    return np.nonzero(arr > val)[0][0]

def first_index_using_searchsorted(val, arr):
    return np.searchsorted(arr, val) + 1

def first_index_using_min(val, arr):
    return np.min(np.where(arr > val))

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('empty array')
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

funcs = [
    first_index_using_argmax, 
    first_index_using_min, 
    first_index_using_nonzero,
    first_index_calculate_range_like, 
    first_index_numba, 
    first_index_using_searchsorted, 
    first_index_using_where
]

from simple_benchmark import benchmark, MultiArgument

并且使用以下方式生成图表:

%matplotlib notebook
b.plot()

项目位于开头

b = benchmark(
    funcs,
    {2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

enter image description here

在性能方面,numba函数表现最佳,其次是calculate函数和searchsorted函数。其他解决方案的表现要差得多。

item is at the end

b = benchmark(
    funcs,
    {2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

enter image description here

对于小数组,numba函数的性能非常出色,但对于更大的数组,它的性能被calculate函数和searchsorted函数超越。

项在sqrt(len)处

b = benchmark(
    funcs,
    {2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

enter image description here

这更有趣。再次,numba和calculate函数表现出色,但是在这种情况下,实际上会触发searchsorted的最坏情况,这种情况下不起作用。

当没有值满足条件时,功能的比较

另一个有趣的点是,如果没有任何一个值的索引应该被返回,这些函数将如何行动:

arr = np.ones(100)
value = 2

for func in funcs:
    print(func.__name__)
    try:
        print('-->', func(value, arr))
    except Exception as e:
        print('-->', e)

有了这个结果:

first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0

Searchsorted、argmax和numba返回的值是错误的。但是searchsortednumba返回的索引不是数组的有效索引。

函数whereminnonzerocalculate会抛出异常。但是只有calculate的异常实际上提供了有用的信息。

这意味着如果你不确定值是否在数组中,你需要将这些调用包装在适当的包装函数中,以捕获异常或无效的返回值,并进行适当的处理。


注意:calculate和searchsorted选项仅在特定条件下起作用。 "calculate"函数需要具有恒定步长,而searchsorted需要数组已排序。因此,在正确的情况下这些可能是有用的,但不是此问题的通用解决方案。如果您正在处理已排序的Python列表,则可以考虑使用bisect模块而不是使用NumPy的searchsorted。

6

我想提议

np.min(np.append(np.where(aa>5)[0],np.inf))

这将返回满足条件的最小索引,如果条件从未满足,则返回无穷大(where 返回一个空数组)。


4
你应该使用np.where而不是np.argmax。后者即使没有找到值,也会返回位置0,这不是你期望的索引。
>>> aa = np.array(range(-10,10))
>>> print(aa)
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

如果条件成立,则返回索引的数组。
>>> idx = np.where(aa > 5)[0]
>>> print(idx)
array([16, 17, 18, 19], dtype=int64)

否则,如果不满足条件,则返回一个空数组。
>>> not_found = len(np.where(aa > 20)[0])
>>> print(not_found)
array([], dtype=int64)

这种情况下反对使用 argmax 的观点是:如果解决方案不含糊,那么越简单越好。因此,要检查某些内容是否符合条件,只需执行 if len(np.where(aa > value_to_search)[0]) > 0

1

我会选择

i = np.min(np.where(V >= x))

其中V是向量(1维数组),x是值,i是结果索引。


这个解决方案比 np.where(capacity < demand)[0][0] 慢。除了可读性更好的 np.min 之外,没有使用它的理由。 - Muhammad Yasirroni

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接