在数组B中找出与数组A最匹配的元素索引

5
我有两个数组 AB。现在暂时把它们都当作一维数组。
对于 A 中的每个元素,我都需要找到在 B 中最匹配该元素的元素索引。
我可以使用列表表达式来解决这个问题。
import numpy as np

A = np.array([ 1, 3, 1, 5 ])
B = np.array([ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1 ])

indices = np.array([ np.argmin(np.abs(B-a)) for a in A ])

print(indices)    # prints [0 2 0 4]
print(B[indices]) # prints [1.1 3.1 1.1 5.1]

但是对于大型数组,这种方法非常缓慢。
我在想是否有一种更快的方法,可以利用优化后的NumPy函数。

3个回答

3
你可以计算重塑后的A和B之间的绝对差,并在axis=1上使用argmin
np.argmin(np.abs(B-A[:,None]), axis=1)

输出结果:数组([0, 2, 0, 4])


@Bastian 注意中间数组的大小将是两个输入值的乘积,这取决于输入和资源,可能会占用大量内存 ;) - mozway

2
广播可能会反噬你(tmp数组的创建也将包含在时间内),下面的方法不使用太多tmp内存,因此具有内存效率。参考这里,当广播由于过多的内存使用而变慢时。
这里仅供参考。除此之外,您可以在cython numpy中编写自定义函数。Cython与numba使用不同的优化方式。因此需要实验哪种更好地优化。但对于numba,您可以留在Python中编写类似C的代码。
import numpy as np
import numba as nb

A = np.array([ 1, 3, 1, 5 ], dtype=np.float64)
B = np.array([ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1 ], dtype=np.float64)

# Convert to fast optimized machine code
@nb.njit(
    # Signature of Input
    (nb.float64[:], nb.float64[:]),
    # Optional
    parallel=True
)
def less_mem_ver(A, B):

    arg_mins = np.empty(A.shape, dtype=np.int64)

    # nb.prange is for parallel=True
    # what can be parallelized
    # Usage of for loop because, to prevent creation of tmp arrays due to broadcasting
    # It takes time to allocate tmp array
    # No loss in writing for loop as numba will vectorize this just like numpy
    for i in nb.prange(A.shape[0]):
        min_num = 1e+307
        min_index = -1
        for j in range(B.shape[0]):
            t = np.abs(A[i] - B[j])
            if t < min_num:
                min_index = j
                min_num = t
        arg_mins[i] = min_index
    return arg_mins
less_mem_ver(A, B)


0

除了使用广播的已经给出的答案之外,还有一种使用内部广播的方法。

我想将我的另一个答案单独保留,因为它使用的不是numpy。

np.argmin(np.abs(np.subtract.outer(A, B)), axis=1)

您可以在许多ufunc中调用outer。

参考


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接