如何加速距离计算:建议

6
考虑以下类:
class SquareErrorDistance(object):
    def __init__(self, dataSample):
        variance = var(list(dataSample))
        if variance == 0:
            self._norm = 1.0
        else:
            self._norm = 1.0 / (2 * variance)

    def __call__(self, u, v): # u and v are floats
        return (u - v) ** 2 * self._norm

我用它来计算矢量中两个元素之间的距离。基本上,我为使用此距离度量的矢量的每个维度创建该类的一个实例(有使用其他距离度量的维度)。分析显示,该类的__call__函数占据了我的knn实现运行时间的90%(谁会想到)。我不认为有任何纯Python方法可以加速此过程,但如果我在C中实现它呢?
如果我运行一个简单的C程序,只是使用上述公式计算随机值之间的距离,它比Python快几个数量级。因此,我尝试使用ctypes并调用执行计算的C函数,但显然参数和返回值的转换太昂贵了,因为生成的代码要慢得多。
当然,我可以在C中实现整个knn并调用它,但问题是,正如我所描述的,我对一些向量的维度使用不同的距离函数,将这些函数转换为C将是太多的工作。

那么我的选择是什么?使用Python C-API编写C函数是否可以摆脱开销?还有其他方法可以加快这个计算吗?


我建议使用Cython(带有示例实现的答案可能会在几分钟内提供)。我假设你的算法已经尽可能地调整好了? - user395760
@delnan:我已经在可能和适当的情况下使用了缓存,因此我不认为有任何节省距离计算的方法。 - Björn Pollex
有点离题:你知道__call__()表达式返回的计算结果就像这样写的一样吗(u - v) ** (2 * self._norm)?请参考这里的运算符优先级表。 - martineau
@martineau: 不是的。直接从我的解释器 (2.6.5) 运行:3**2*4 得到的是 36,而 3**(2*4) 得到的是 6561。这与你提供的链接所描述的是一致的。 - Björn Pollex
@Space_C0wb0y:我的错误,我现在看到表格是按照从低到高的优先级排序,而不是我习惯看到的相反方式呈现信息。抱歉。 - martineau
显示剩余2条评论
2个回答

2
以下是Cython代码(我意识到__init__的第一行不同,我用随机东西替换了它,因为我不知道var,而且无论如何都没有关系-你说__call__是瓶颈):
cdef class SquareErrorDistance:
    cdef double _norm

    def __init__(self, dataSample):
        variance = round(sum(dataSample)/len(dataSample))
        if variance == 0:
            self._norm = 1.0
        else:
            self._norm = 1.0 / (2 * variance)

    def __call__(self, double u, double v): # u and v are floats
        return (u - v) ** 2 * self._norm

通过一个简单的 setup.py 编译(只是将文件名改变的 文档示例),它在一个简单的虚构 timeit 基准测试中比等效的纯 Python 快了近 20 倍。请注意,唯一的更改是 _norm 字段和 __call__ 参数的 cdef。我认为这相当令人印象深刻。

这太棒了。非常感谢你。我实际上还可以将这个(指的是Cython)应用到许多其他热点领域。你真是给我带来了好心情 :) - Björn Pollex
1
@Space_C0wb0y:非常乐意帮忙 :) 如果您经常使用numpy,请参考http://docs.cython.org/src/tutorial/numpy.html。 - user395760
你也可以将变量声明为double类型。这可能不会有太大的区别,但为什么不呢? - Justin Peel

0

这可能帮助不大,但您可以使用嵌套函数重写它:

def SquareErrorDistance(dataSample):
    variance = var(list(dataSample))
    if variance == 0:
        def f(u, v):
            x = u - v
            return x * x
    else:
        norm = 1.0 / (2 * variance)
        def f(u, v):
            x = u - v
            return x * x * norm
    return f

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接