计算欧几里得距离平方的可能优化方法

4

我需要在Python项目中每天进行几亿次欧几里得距离计算。

以下是我最初的代码:

def euclidean_dist_square(x, y):
    diff = np.array(x) - np.array(y)
    return np.dot(diff, diff)

这非常快,而且我已经不再进行平方根计算,因为我只需要对项目进行排名(最近邻搜索)。虽然如此,它仍然是脚本的瓶颈。因此,我编写了一个C扩展程序,用于计算距离。计算始终使用128维向量。

#include "euclidean.h"
#include <math.h>

double euclidean(double x[128], double y[128])
{
    double Sum;
    for(int i=0;i<128;i++)
    {
        Sum = Sum + pow((x[i]-y[i]),2.0);
    }
    return Sum;
}

这个扩展的完整代码在这里:https://gist.github.com/herrbuerger/bd63b73f3c5cf1cd51de

现在与numpy版本相比,这确实可以提供很好的加速。

但是是否有任何方法可以进一步加速(这是我第一次使用C扩展,所以我假设有)?由于该函数每天使用的次数很多,每微秒实际上都会带来好处。

你们中的一些人可能会建议将这个函数从Python完全移植到另一种语言,不幸的是,这是一个更大的项目,不是一个选项 :(

谢谢。

编辑

我已经在CodeReview上发布了这个问题:https://codereview.stackexchange.com/questions/52218/possible-optimizations-for-calculating-squared-euclidean-distance

如果有人已经开始回答,我将在一个小时内删除此问题。


如果您在问题中包含代码,只要代码能够正常运行,这将被视为Code Review的主题。 - syb0rg
要点实际上包含了所需的一切。现在最好的方法是什么?关闭这个并在代码审查中发布一个新条目? - herrherr
4
我建议您删除这个问题,然后在Code Review上重新提问,并将所有相关的代码都包含在问题中,而不是提供一个链接。如果您没有这样做,我保证问题将在那里被关闭。 - syb0rg
可能是更有效的numpy计算距离的方法?的重复问题(对于关闭和重新打开,很抱歉,因为我有金徽章,所以甚至不能再投票了)。 - Fred Foo
也非常类似于这个问题(答案肯定是一样的)。 - Fred Foo
1
@MikeDunlavey,甚至在 -O1 的情况下,GCC也会消除那个 pow 调用。 - Fred Foo
1个回答

12

我知道的在NumPy中计算欧几里得距离最快的方法是scikit-learn中的方法,它可以总结为:

def squared_distances(X, Y):
    """Return a distance matrix for each pair of rows i, j in X, Y."""
    # https://dev59.com/qGIk5IYBdhLWcg3wp_iM#19094808
    X_row_norms = np.einsum('ij,ij->i', X, X)
    Y_row_norms = np.einsum('ij,ij->i', Y, Y)
    distances = np.dot(X, Y)
    distances *= -2
    distances += X_row_norms
    distances += Y_row_norms

    np.maximum(distances, 0, distances)  # get rid of negatives; optional
    return distances

这段代码的瓶颈是矩阵乘法(np.dot),因此请确保您的NumPy链接到一个良好的BLAS实现;通过在多核机器上使用多线程的BLAS和足够大的输入矩阵,它应该比您在C中编写的任何东西更快。请注意,它依赖于二项式公式。

||x - y||² = ||x||² + ||y||² - 2 x⋅y    

并且可以在k-NN使用情况下缓存X_row_normsY_row_norms中的任何一个。

(我是这段代码的共同作者,并花了相当多的时间对其和SciPy实现进行优化;scikit-learn更快,但牺牲了一些准确性,但对于k-NN来说,这不应该太重要。SciPy实现,在scipy.spatial.distance可用,实际上是您刚刚编写的代码的优化版本,而且更准确。)


+1 给“马嘴”。你在代码审查中得不到这个。 - david.pfx
你能提供一个例子吗?特别是X和Y应该长什么样子? - herrherr
@herrherr:X 是一个 n×128 的矩阵,Y 是一个 m×128 的矩阵。结果是一个 n×m 的距离矩阵。 - Fred Foo
我现在正在使用你从这里给出的建议,我必须说它比我之前使用的numpy版本快得多,而且比C版本也要快得多。谢谢! - herrherr
1
这可能是不稳定的:https://github.com/scikit-learn/scikit-learn/issues/9354 - Andreas Mueller

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接