找到一个numpy数组中最接近某个值的所有索引。

3
在numpy数组中需要找到所有最接近给定常量的值的索引。 背景是数字信号处理。数组保存滤波器的幅度函数(np.abs(np.fft.rfft(h))),并搜索幅度为0.5或另一种情况下为0的某些频率(=索引)。 大多数时候,所需的值不完全包含在序列中。应该找到最接近该值的索引。
迄今为止,我想出了以下的方法,其中我查看序列与常量之间差异的符号变化。但是,这仅适用于在问题点处单调递增或递减的序列。有时它也会错过1。
def findvalue(seq, value):
    diffseq = seq - value
    signseq = np.sign(diffseq)
    signseq[signseq == 0] = 1
    return np.where(np.diff(signseq))[0]

我想知道是否有更好的解决方案。这仅适用于1D实数浮点数组,而且在我的情况下计算效率的要求不是很高。

作为一个数值例子,以下代码应返回[8, 41]。这里为了简单起见,我用半波替换了滤波器幅度响应。

f=np.sin(np.linspace(0, np.pi))
findvalue(f, 0.5)

我找到的类似问题只返回第一个或第二个索引,以下是:
查找最接近值的第二个索引
在numpy数组中查找最接近的值
4个回答

1
以下函数将返回一个分数索引,显示大致何时交叉该值:
def FindValueIndex(seq, val):
    r = np.where(np.diff(np.sign(seq - val)) != 0)
    idx = r + (val - seq[r]) / (seq[r + np.ones_like(r)] - seq[r])
    idx = np.append(idx, np.where(seq == val))
    idx = np.sort(idx)
    return idx

逻辑:找到序列-值符号变化的位置。取过渡时一下标和上标的值进行插值计算。将该值等于实际值的索引加入其中。
如果需要整数索引,只需使用np.round。您还可以选择np.floor或np.ceil将索引四舍五入为所需值。
def FindValueIndex(seq, val):
    r = np.where(np.diff(np.sign(seq - val)) != 0)
    idx = r + (val - seq[r]) / (seq[r + np.ones_like(r)] - seq[r])
    idx = np.append(idx, np.where(seq == val))
    idx = np.sort(idx)
    return np.round(idx)

谢谢,插值是我也考虑过的一种改进结果的方法。一个问题:为什么你写 r + np.ones_like(r) 而不是简单地写 r + 1 - Martin Scharrer
这是因为np.where返回一个元组。或者我可能可以这样做r = np.where(np.diff(np.sign(seq - val)) != 0)[0],然后r + 1就可以工作了。 - Aguy
谢谢,根据您的输入,我最终得到了以下方法。我将sort更改为unique以删除重复项,并使四舍五入变成可选项。def argvalue(seq, val, intidx=True): r = np.where(np.diff(np.sign(seq - val)) != 0) idx = r + (val - seq[r]) / (seq[r + np.ones_like(r)] - seq[r]) idx = np.append(idx, np.where(seq == val)) if intidx: idx = np.round(idx).astype(int) idx = np.unique(idx) return idx - Martin Scharrer

1
def findvalue(seq, value):
    diffseq = seq - value
    signseq = np.sign(diffseq)
    zero_crossings = signseq[0:-2] != signseq[1:-1]
    indices = np.where(zero_crossings)[0]
    for i, v in enumerate(indices):
        if abs(seq[v + 1] - value) < abs(seq[v] - value):
            indices[i] = v + 1
    return indices

更多解释

def print_vec(v):
    for i, f in enumerate(v):
        print("[{}]{:.2f} ".format(i,f), end='')
    print('')

def findvalue_loud(seq, value):
    diffseq = seq - value
    signseq = np.sign(diffseq)
    print_vec(signseq)
    zero_crossings = signseq[0:-2] != signseq[1:-1]
    print(zero_crossings)

    indices = np.where(zero_crossings)[0]
    # indices contains the index in the original vector
    # just before the seq crosses the value [8 40]
    # this may be good enough for you
    print(indices)

    for i, v in enumerate(indices):
        if abs(seq[v + 1] - value) < abs(seq[v] - value):
            indices[i] = v + 1
    # now indices contains the closest [8 41]
    print(indices)
    return indices

1
我认为你有两个选择。一个是对形状进行一些假设,并寻找seqval之间差异的零交叉点(就像@ColonelFazackerley他们的答案中所做的那样)。另一个选择是说明你想要考虑何种相对容差范围内的值是足够接近的。
在后一种情况下,你可以使用numpy.isclose
import numpy as np

def findvalue(seq, val, rtol=0.05):    # value that works for your example
    return np.where(np.isclose(seq, val, rtol=rtol))[0]

例子:

x = np.sin(np.linspace(0, np.pi))
print(findvalue(x, 0.5))
# array([ 8, 41])

这种方法的缺点在于它取决于rtol的值。如果将其设置得太大(例如0.1),则会得到多个接近交叉点的值;如果设置得太低,则不会得到任何值。

0

这可能远非最佳方法(我仍在学习numpy),但我希望它能帮助您找到一个。

min_distance = np.abs(your_array - your_constant).min()
# These two tuples contain number closest to your constant from each side.
np.where(bar == val - min_distance)  # Closest, < your_constant
np.where(bar == val + min_distance)  # Closest, > your_constant

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接