使用NumPy查找大于前n个元素的元素

3

我希望能够在numpy数组中识别比索引5处开始的前5个元素大的元素。我已经使用'for'循环编写了解决此问题的解决方案。我的问题是如何在不迭代的情况下解决这类问题?是否有针对此问题的特定numpy函数?

import numpy as np
values = np.array([160, 140, 152, 142, 143, 186, 152, 145, 165, 152, 143, 148, 196, 152, 145, 157, 152])
indices = []
for i in range(5, len(values)):
    if np.all(values[(i-5):i]<values[i]):
        indices.append(i)
3个回答

2

一个技巧是在数组长度上滑动窗口计算最大值,不包括当前元素,并将其与当前元素进行比较。如果当前元素更大,则我们有一个赢家,否则我们没有。

为了获得滑动的最大值,我们可以利用 Scipy's 1D max 过滤器的服务,从而实现如下 -

from scipy.ndimage.filters import maximum_filter1d as maxf

def greater_than_all_prev(values, W=5):
    hW = (W-1)//2
    maxv = maxf(values,W, origin=hW)
    mask = values[1:] > maxv[:-1]
    mask[:W-1] = 0
    return np.flatnonzero(mask)+1

样例运行 -

In [336]: values
Out[336]: 
array([160, 140, 152, 142, 143, 186, 152, 145, 165, 152, 143, 148, 196,
       152, 145, 157, 152])

In [337]: greater_than_all_prev(values, W=5)
Out[337]: array([ 5, 12])

谢谢!你的解决方案是最快的!我应该学习scipy模块和numpy进行数据分析吗? - Elgin Cahangirov
@ElginCahangirov 这应该是一个很好的想法,我认为! - Divakar

1

Erik Rigtorp发布了一个关于使用NumPy进行高效滚动统计的技巧:

A loop in Python are however very slow compared to a loop in C code. Fortunately there is a trick to make NumPy perform this looping internally in C code. This is achieved by adding an extra dimension with the same size as the window and an appropriate stride:

def rolling_window(a, window):
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
使用此函数,您可以做到以下事情:
winlen = 5

values = np.array([160, 140, 152, 142, 143, 186, 152, 145, 165, 152, 143, 148, 196, 152, 145, 157, 152])

rolling_values = rolling_window(values, winlen + 1)
rolling_indices = np.arange(winlen, values.shape[0])

mask = np.all(rolling_values[:, [-1]] >  rolling_values[:, :-1], axis=1)
indices = rolling_indices[mask]
print(indices)

说明:

rolling_window将值转换为以下形式的数组:

print(rolling_values)
array([[160, 140, 152, 142, 143, 186],
       [140, 152, 142, 143, 186, 152],
       [152, 142, 143, 186, 152, 145],
       [142, 143, 186, 152, 145, 165],
       [143, 186, 152, 145, 165, 152],
       [186, 152, 145, 165, 152, 143],
       [152, 145, 165, 152, 143, 148],
       [145, 165, 152, 143, 148, 196],
       [165, 152, 143, 148, 196, 152],
       [152, 143, 148, 196, 152, 145],
       [143, 148, 196, 152, 145, 157],
       [148, 196, 152, 145, 157, 152]])

每行包含一个元素(从第六个元素开始)和前面的五个元素。由于步幅技巧,这种表示不需要比原始数组更多的内存。
现在,我们可以比较每行中最后一个元素是否大于前面的元素,并查找相应的索引。

谢谢你的解决方案!我其实知道并使用了 striding,但在这个问题上没想到要用它。 - Elgin Cahangirov

0
一个简单的方法是这样的:
import numpy as np
values = np.array([160, 140, 152, 142, 143, 186, 152, 145, 165, 152, 143, 148, 196, 152, 145, 157, 152])

prod = np.ones_like(values)
for n in range(1,6):
    prod *= values > np.roll(values, n)
print(prod)

如果在prod中的某个索引处找到了1,则在values中的此索引处满足大于前五个元素的条件。您可以使用np.where(prod == 1)找到这些索引。请注意,np.roll会绕过数组的边界。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接