获取Numpy数组的索引

3

我有一个numpy数组:

arr = [0.23, 2.32, 4.04, 5.02, 6.84, 10.12, 10.34, 11.93,12.44]

我希望能够获取最接近我输入的整数的索引。例如,如果我输入10,则应该返回索引5(10.12),如果我输入12,则应该返回索引7(11.93)。


1
这个数组是否保证已经排序? - user2357112
1个回答

2

如果你的列表没有排序,你需要使用abs+argmin来得到一个线性时间复杂度的解决方案:

>>> np.abs(np.array(arr) - 12).argmin()
7

然而,如果你的列表已经排序(升序或降序),你可以使用二分查找来实现次线性时间的解决方案(非常快):

# https://ideone.com/aKEpI2 — improved by @user2357112
def binary_search(arr, val):
    # val must be in the closed interval between arr[i-1] and arr[i],
    # unless one of i-1 or i is beyond the bounds of the array.
    i = np.searchsorted(arr, val)

    if i == 0:
        # Smaller than the smallest element
        return i
    elif i == len(arr):
        # Bigger than the biggest element
        return i - 1
    elif val - arr[i - 1] <= arr[i] - val:
        # At least as close to arr[i - 1] as arr[i]
        return i - 1

    # Closer to arr[i] than arr[i - 1]
    return i

cases = [10, 12, 100, 10.12]   # 5, 7, 8, 5
print(*[binary_search(arr, c) for c in cases], sep=',')

1
你今天有点马虎,是吗?searchsorted的解决方案对于10给出了错误的答案(4而不是5)。 - Paul Panzer
@PaulPanzer 不好意思,最近睡眠不足。让我来修复它,我会确保它能正常工作。 - cs95
"我已经相当缺觉了" --- 我猜到了 ;-) - Paul Panzer
@cᴏʟᴅsᴘᴇᴇᴅ:不是很对。您的相邻选择考虑了右侧的邻居和右侧元素,而不是左侧邻居和右侧邻居。此外,len(arr)-1的特殊情况是错误的。 - user2357112
我相信这个编辑过的版本应该可以处理好事情,除了最近邻决定中的舍入误差,但我不确定自己是否有足够的睡眠。 - user2357112
显示剩余5条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接