在Numpy数组中查找所有接近数字对的最快方法

4

假设我有一个包含10个随机浮点数的Numpy数组:

import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)

out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963 
         5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]

我想找出所有唯一的数对,这些数对相差不超过容差tol = 1。(即绝对差小于等于1)。具体来说,我想获取所有唯一的索引数对。每个相邻数对的索引应该按顺序排序,并且所有相邻数对都应该按照第一个索引排序。我已经编写了下面可行的代码:

def all_close_pairs(arr, tol=1.):
    res = set()
    for i, x1 in enumerate(arr):
        for j, x2 in enumerate(arr):
            if i == j:
                continue
            if np.isclose(x1, x2, rtol=0., atol=tol):
                res.add(tuple(sorted([i, j])))
    res = np.array(list(res))
    return res[res[:,0].argsort()]

print(all_close_pairs(arr, tol=1.))

out[2]: [[1 5]
         [2 4]
         [3 7]
         [3 9]
         [7 9]]

然而,实际上我有一个包含N = 1000个数字的数组,在我的代码中由于嵌套的for循环而变得极其缓慢。我认为使用Numpy向量化可以更加高效地完成这个任务。有没有人知道在Numpy中最快的做法是什么?

5个回答

6
一种高效的解决方法是首先使用index = np.argsort()对输入值进行排序。然后,您可以使用arr[index]生成已排序的数组,并在快速连续的数组上以准线性时间迭代接近值,如果成对数量较小。如果成对数量很大,则复杂度是二次的,因为会生成二次数量的成对。得到的复杂度是:O(n log n + m)其中n是输入数组的大小,m是产生的成对数。
要找到彼此接近的值,一个高效的方法是使用Numba迭代值。虽然在Numpy中可能是可能的,但由于需要比较的值的数量不同,因此可能不高效。以下是一种实现:
import numba as nb

@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
    res = []
    for i in range(arr.size):
        val = arr[i]
        # Iterate over the close numbers (only once)
        for j in range(i+1, arr.size):
            # Sadly neither np.isclose or np.abs are implemented in Numba so far
            if max(val, arr[j]) - min(val, arr[j]) >= tol:
                break
            res.append((i, j))
    if len(res) == 0: # No pairs: we need to help Numpy to know the shape
        return np.empty((0, 2), dtype=np.int32)
    return np.array(res, dtype=np.int32)

最后,索引需要更新以引用未排序数组中的索引而不是排序后的数组。您可以使用index[result]来完成。

这是最终代码:

index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])

这是结果(顺序可能与问题中的不同,但如果需要,您可以对其进行排序):

array([[9, 3],
       [9, 7],
       [3, 7],
       [1, 5],
       [4, 2]])

改进算法的复杂度

如果您需要更快的算法,则可以使用另一种输出格式:您可以为每个输入值提供靠近目标输入值的最小/最大值范围。为了找到这个范围,您可以在排序后的数组中使用二分搜索(参见:np.searchsorted)。得到的算法的运行时间为O(n log n)。但是,由于范围不连续,您无法获取未排序数组中的索引。

基准测试

以下是在我的计算机上对1000个随机项目进行容差为1.0的性能测试结果:

Reference implementation:   ~17000 ms             (x 1)
Angelicos' implementation:    1773 ms           (x ~10)
Rivers' implementation:        122 ms           (x 139)
Rchome's implementation:        20 ms           (x 850)
Chris' implementation:           4.57 ms       (x 3720)
This implementation:             0.67 ms      (x 25373)

1
我不知道你是否能够击败一个 O(N^2) 的算法。最终,有多达 O(N^2) 对,即使你对其进行排序,也无法避免构建输出对。不过像 numpy 和 numba 这样的向量化工具可以提供帮助。 - rchome
确实。好观点。我忘记了在最坏情况下对数对数量的计算是 N*(N+1)/2 ^^"。我编辑了问题以修正这一点,并提供了一个解决方案来降低复杂度(假设输出格式可以适应)。我可能会检查算法的性能 ;)。 - Jérôme Richard

4
有点晚了,但这里有一个全 numpy 的解决方案:
import numpy as np

def close_enough( arr, tol = 1 ): 
    result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol ), 1)) 
    return np.swapaxes( result, 0, 1 ) 

扩展以解释正在发生的事情

def close_enough( arr, tol = 1 ):
    bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
    # is_close generates a square array after comparing all elements with all elements.  

    bool_arr = np.triu( bool_arr, 1 ) 
    # Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal 
    # and all elements below and to the left.

    result = np.where( bool_arr )  # Return the row and column indices for Trues
    return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns 

当 N = 1000 时,arr 是一个浮点数数组

%timeit close_enough( arr, tol = 1 )                                                                              
14.1 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [19]: %timeit all_close_pairs( arr, tol = 1 )                                                                           
54.3 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()                                            
# True

3
很好,我喜欢这个答案比我的更好。 我不知道np.triu,这正是我试图用pair_coords[pair_coords[:,:,0]<pair_coords[:,:,1]]来完成的。此外,使用新轴和广播技巧来获取所有成对数据是个很巧妙的方法。 - rchome

3
这是一个使用纯numpy操作的解决方案。在我的电脑上看起来非常快,但我不知道我们要寻找什么样的速度。
def all_close_pairs(arr, tol=1.):
    N = arr.shape[0]
    # get indices in the array to consider using meshgrid
    pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
    # filter out pairs so we get indices in increasing order
    pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
    # compare indices in your array for closeness
    is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
    return pair_coords[is_close, :]

1
您可以先使用itertools.combinations创建组合:
def all_close_pairs(arr, tolerance):
    pairs = list(combinations(arr, 2))
    indexes = list(combinations(range(len(arr)), 2))
    all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <=  tolerance]
    return all_close_pairs_indexes

现在,对于N=1000,你只需要比较499500对而不是一百万。
工作原理如下:
- 我们首先通过itertools.combinations创建这些配对。 - 然后,我们创建它们索引的列表。 - 为了加快速度,我们使用列表推导式而不是for循环。 - 在这个推导式中,我们遍历所有的配对,使用enumerate来获取配对的索引,计算配对数字的绝对差,并检查是否小于等于容差。 - 如果绝对差小于等于容差,我们通过索引列表获取配对数字的索引,并将它们添加到我们的最终列表中。

1
问题在于你的代码具有O(n*n)(二次)复杂度。 为了降低复杂度,你可以尝试首先对项目进行排序:
def all_close_pairs(arr, tol=1.):
    res = set()
    arr = sorted(enumerate(arr), key=lambda x: x[1])
    for (idx1, (i, x1)) in enumerate(arr):
        for idx2 in range(idx1-1, -1, -1):
            j, x2 = arr[idx2]
            if not np.isclose(x1, x2, rtol=0., atol=tol):
                break
            indices = sorted([i, j])
            res.add(tuple(indices))
    return np.array(sorted(res))

然而,这只有在您的值范围远大于公差时才有效。

您可以通过使用双指针策略进一步改进,但总体复杂度仍将保持不变。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接