假设我有一个包含10个随机浮点数的Numpy数组:
import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)
out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963
5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]
我想找出所有唯一的数对,这些数对相差不超过容差tol = 1
。(即绝对差小于等于1)。具体来说,我想获取所有唯一的索引数对。每个相邻数对的索引应该按顺序排序,并且所有相邻数对都应该按照第一个索引排序。我已经编写了下面可行的代码:
def all_close_pairs(arr, tol=1.):
res = set()
for i, x1 in enumerate(arr):
for j, x2 in enumerate(arr):
if i == j:
continue
if np.isclose(x1, x2, rtol=0., atol=tol):
res.add(tuple(sorted([i, j])))
res = np.array(list(res))
return res[res[:,0].argsort()]
print(all_close_pairs(arr, tol=1.))
out[2]: [[1 5]
[2 4]
[3 7]
[3 9]
[7 9]]
然而,实际上我有一个包含N = 1000
个数字的数组,在我的代码中由于嵌套的for循环而变得极其缓慢。我认为使用Numpy向量化可以更加高效地完成这个任务。有没有人知道在Numpy中最快的做法是什么?
O(N^2)
的算法。最终,有多达O(N^2)
对,即使你对其进行排序,也无法避免构建输出对。不过像 numpy 和 numba 这样的向量化工具可以提供帮助。 - rchomeN*(N+1)/2
^^"。我编辑了问题以修正这一点,并提供了一个解决方案来降低复杂度(假设输出格式可以适应)。我可能会检查算法的性能 ;)。 - Jérôme Richard