我有一个numpy数组:
arr = [0.23, 2.32, 4.04, 5.02, 6.84, 10.12, 10.34, 11.93,12.44]
我希望能够获取最接近我输入的整数的索引。例如,如果我输入10,则应该返回索引5(10.12),如果我输入12,则应该返回索引7(11.93)。
如果你的列表没有排序,你需要使用abs
+argmin
来得到一个线性时间复杂度的解决方案:
>>> np.abs(np.array(arr) - 12).argmin()
7
然而,如果你的列表已经排序(升序或降序),你可以使用二分查找来实现次线性时间的解决方案(非常快):
# https://ideone.com/aKEpI2 — improved by @user2357112
def binary_search(arr, val):
# val must be in the closed interval between arr[i-1] and arr[i],
# unless one of i-1 or i is beyond the bounds of the array.
i = np.searchsorted(arr, val)
if i == 0:
# Smaller than the smallest element
return i
elif i == len(arr):
# Bigger than the biggest element
return i - 1
elif val - arr[i - 1] <= arr[i] - val:
# At least as close to arr[i - 1] as arr[i]
return i - 1
# Closer to arr[i] than arr[i - 1]
return i
cases = [10, 12, 100, 10.12] # 5, 7, 8, 5
print(*[binary_search(arr, c) for c in cases], sep=',')