使用条件语句加速Python嵌套循环

3

我正在将MATLAB的代码转换为Python,以加快简单操作的速度。我编写了一个包含嵌套循环和条件语句的函数;循环的目的是返回与数组y相比,在数组x中最近元素的索引列表。我正在比较1e5个项目的顺序,需要大约30秒才能运行。非常感谢任何帮助来加速这个过程!我使用numba-pro自动即时编译器取得了部分成功:

@autojit()
def find_nearest(x,y,idx):
    idx_old = 0
    rng1 = range(y.shape[0])
    rng2 = range(x.shape[0])
    for i in rng1:
        prev = abs(x[idx_old]-y[i])
        for j in rng2:
            if abs(x[j]-y[i]) < prev:
                prev = abs(x[j]-y[i])
                idx_old = j
        idx[i] = idx_old
    return idx

非常抱歉,我是一个初学者,对Python一窍不通!


1
你能否更新你的脚本,包括find_nearest的示例数据输入,以便更清晰明了? - JoshAdel
例如x = np.array([1.1,2.3,5.9,8.5]), y = np.array([0.2, 5.5, 12])和idx = np.zeros(np.shape(y))应该返回idx = [0,2,3](y中项最接近x中项的索引)。在MATLAB中,我使用knnsearch来执行此操作,在我的大型数据集上,它需要约2.5秒才能解决;而我的实现需要约30秒。输入数组不需要按任何特定的排序顺序排列。感谢您的关注! - Chris Church
我尝试使用sci-kitlearn的k最近邻算法实现,但它返回的是在数据集上训练过的函数;而在我的完整数据集上进行训练是不可行的。我的搜索域包含1023848个项目,我正在尝试找到其中12325个最接近的项目。 - Chris Church
2个回答

4

您的Numba代码没有问题,除了算法不够高效。更好的方法是对x数组进行排序,并进行二分查找,与这个答案这个答案非常相似:

def find_nearest(x, y):
    indices = np.argsort(x)

    loc = np.searchsorted(x[indices], y)
    right = indices.take(loc, mode='clip')
    left = indices.take(loc-1, mode='clip')

    return np.where(abs(y-x[left]) < abs(y-x[right]), left, right)

在我的电脑上,这个方法比使用KDTree来处理x和y各有100万和10万个元素的情况快80倍左右。大约三分之二的时间用于对数组进行argsort排序,因此我认为在这里使用Numba并不能获得太多优势。


非常感谢您的回复,我已经成功实现了您的代码。在我的数据集上运行速度快了大约270倍(在大约1百万个项目的集合中找到了约1e4个索引)。您能否推荐一些教程/文章/网页,让我学习如何在Python中实现更高效的代码?我已经查看了您建议的类似答案。索引操作总是这么快吗? - Chris Church
@Chris,我主要是通过在这个网站上关注[Numpy]标签并进行试验/尝试来学习更高效的Python编程。您能否解释一下您所说的“索引操作总是更快”的含义? - user2379410
通过索引操作,我指的是返回布尔/索引数组以进行后续操作的函数。最初,我尝试将我的“find_nearest”函数转换为矩阵和向量操作的单行代码;然而,当我尝试在大规模数据集上运行它时,遇到了“MemoryError”的问题。 - Chris Church
@Chris 哦,我明白了。这与索引操作与矩阵/向量操作无关。您的原始算法复杂度为O(m*n),其中mnxy的长度。您可能还创建了一个大小为m*n的数组(非常大)。使用我的代码,复杂度为O((m+n)log(m)),附加内存使用量为O(m+n)(我认为..不确定排序)。 - user2379410

1

我已经找到了一个临时解决方案来解决我的问题。通过实现scipy.spatial的kdtree,我能够将运行时间从32秒缩短到不到10秒。这仍然比MATLAB knnsearch算法慢四倍;了解如何加速带有条件语句的循环仍然很重要。但目前来说,这个修订后的实现更快:

from scipy import spatial
from numpy import matrix

tree = spatial.KDTree(matrix(x).T)
(_, idxx) = tree.query(matrix(y).T)

数组x和y以1d格式存在;树需要查询以列向量形式进行。如有改进原实现的建议,将不胜感激!

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接