Python,Pairwise 'distance',需要快速完成,有什么好办法吗?

7
作为我的博士论文的一个副项目,我参与了使用Python对某些系统进行建模的任务。就效率而言,我的程序在以下问题上遇到了瓶颈,我将在最小工作示例中公开这个问题。我处理大量由它们的三维起点和终点编码的片段,因此每个片段由6个标量表示。我需要计算成对的最小片段距离。两个片段之间最小距离的解析表达式可以在source中找到。关于MWE:
import numpy as np
N_segments = 1000
List_of_segments = np.random.rand(N_segments, 6)

Pairwise_minimal_distance_matrix = np.zeros( (N_segments,N_segments) )
for i in range(N_segments):
    for j in range(i+1,N_segments): 

        p0 = List_of_segments[i,0:3] #beginning point of segment i
        p1 = List_of_segments[i,3:6] #end point of segment i
        q0 = List_of_segments[j,0:3] #beginning point of segment j
        q1 = List_of_segments[j,3:6] #end point of segment j
        #for readability, some definitions
        a = np.dot( p1-p0, p1-p0)
        b = np.dot( p1-p0, q1-q0)
        c = np.dot( q1-q0, q1-q0)
        d = np.dot( p1-p0, p0-q0)
        e = np.dot( q1-q0, p0-q0)
        s = (b*e-c*d)/(a*c-b*b)
        t = (a*e-b*d)/(a*c-b*b)
        #the minimal distance between segment i and j
        Pairwise_minimal_distance_matrix[i,j] = sqrt(sum( (p0+(p1-p0)*s-(q0+(q1-q0)*t))**2)) #minimal distance

现在,我意识到这非常低效,这就是我在这里的原因。我广泛地研究了如何避免循环,但我遇到了一个问题。显然,这种计算最好使用python中的cdist。然而,它可以处理的自定义距离函数必须是二元函数。在我的情况下,这是一个问题,因为我的向量具有特定的长度为6,并且必须被拆分为它们的前三个和后三个组件。我不认为我可以将距离计算转换为二进制函数。

欢迎任何意见。


一个八叉树(在维基百科上提到最近邻搜索的常见用途)能帮助这里吗? - Antti Haapala -- Слава Україні
1
你读过Lumelsky的关于线段之间快速计算距离的论文吗?你的实现与它相比如何?(更通用的方法可以在这里找到:http://graphics.stanford.edu/courses/cs164-09-spring/Handouts/paper_GJKoriginal.pdf) - Michael Foukarakis
谢谢提供链接,我会查看它们。 - Mathusalem
@Michael Foukarakis,我快速阅读了这篇论文,这是我自己通过分析得出的结论。它并没有具体说明如何加速计算。它只是概述了一种处理特殊情况的聪明方法。不过还是很值得一读的。 - Mathusalem
3个回答

5
你可以使用numpy的向量化功能来加速计算。我的版本一次计算距离矩阵的所有元素,然后将对角线和下三角设置为零。
def pairwise_distance2(s):
    # we need this because we're gonna divide by zero
    old_settings = np.seterr(all="ignore")

    N = N_segments # just shorter, could also use len(s)

    # we repeat p0 and p1 along all columns
    p0 = np.repeat(s[:,0:3].reshape((N, 1, 3)), N, axis=1)
    p1 = np.repeat(s[:,3:6].reshape((N, 1, 3)), N, axis=1)
    # and q0, q1 along all rows
    q0 = np.repeat(s[:,0:3].reshape((1, N, 3)), N, axis=0)
    q1 = np.repeat(s[:,3:6].reshape((1, N, 3)), N, axis=0)

    # element-wise dot product over the last dimension,
    # while keeping the number of dimensions at 3
    # (so we can use them together with the p* and q*)
    a = np.sum((p1 - p0) * (p1 - p0), axis=-1).reshape((N, N, 1))
    b = np.sum((p1 - p0) * (q1 - q0), axis=-1).reshape((N, N, 1))
    c = np.sum((q1 - q0) * (q1 - q0), axis=-1).reshape((N, N, 1))
    d = np.sum((p1 - p0) * (p0 - q0), axis=-1).reshape((N, N, 1))
    e = np.sum((q1 - q0) * (p0 - q0), axis=-1).reshape((N, N, 1))

    # same as above
    s = (b*e-c*d)/(a*c-b*b)
    t = (a*e-b*d)/(a*c-b*b)

    # almost same as above
    pairwise = np.sqrt(np.sum( (p0 + (p1 - p0) * s - ( q0 + (q1 - q0) * t))**2, axis=-1))

    # turn the error reporting back on
    np.seterr(**old_settings)

    # set everything at or below the diagonal to 0
    pairwise[np.tril_indices(N)] = 0.0

    return pairwise

现在让我们试试它。使用您的示例,N = 1000,我得到了以下时间:

%timeit pairwise_distance(List_of_segments)
1 loops, best of 3: 10.5 s per loop

%timeit pairwise_distance2(List_of_segments)
1 loops, best of 3: 398 ms per loop

当然,结果也是一样的:
(pairwise_distance2(List_of_segments) == pairwise_distance(List_of_segments)).all()

返回 True。我还相当确定算法中隐藏着一个矩阵乘法,因此应该有进一步加速(以及清理)的潜力。

顺便说一句:我已经尝试过先使用numba,但没有成功。不过原因我也不确定。


非常感谢。当我读到你的第一句话时,我想“当然”,但看到实现方式后,我认为你为我节省了几个小时的工作:为此我感谢你!据我所知,Numba 是基于 GPU 计算的。在您的经验中,对于类似的向量大小(N_segments ~= 1E3-E4),将数据推送到显卡上的开销是否被加速所抵消,或者在实施之前无法确定? - Mathusalem
我很快就遇到了内存问题(N_segments ~= 10k)。我会调查一下是什么导致内存占用过高,并回复您(我可以在Matlab上运行多达30k)。 - Mathusalem
@Mathusalem 一个由双精度浮点数组成的 N*N 矩阵,其中 N=10000,大约需要占用0.75GB的内存。这个函数会创建许多这样的矩阵。我的方法是(因为我还没有阅读足够的关于该问题的资料来设计更好的算法),将问题分成可管理大小的块,并逐个计算它们(如果你愿意,甚至可以并行计算)。 - Carsten
我也尝试从程序中提取循环部分并进行JIT编译,但是我没有获得任何速度提升。你有同样的问题吗?我怀疑... - Mathusalem
针对您的numba评论,我已经深入研究了一下。我认为加速不足的原因在于numba对numpy的支持还不够充分。特别是,numba在处理许多numpy函数(如numpy.dot、numpy.sqrt)时存在困难,因为它不知道这些函数的返回类型,从而无法进行优化。链接:https://github.com/numba/numba/issues/251 - Mathusalem

2
这更像是一个元回答,至少起步阶段是这样。你的问题可能已经在“我的程序遇到瓶颈”和“我意识到这极其低效”的地方了。
极其低效?按什么标准来衡量?你有比较吗?你的代码运行时间太慢了吗?对于你来说,什么是合理的运行时间?你能否投入更多的计算能力来解决问题?同样重要的是——你是否使用适当的基础设施来运行你的代码(numpy/scipy编译器与供应商编译器一起编译,可能带有OpenMP支持)?
然后,如果你已经回答了上述所有问题并需要进一步优化你的代码——你当前的代码瓶颈在哪里确切地?你进行过分析吗?循环体本身可能比循环的评估更加耗费资源吗?如果是这样的话,“循环”就不是你的瓶颈,你也不需要担心首先嵌套的循环。首先优化循环体,可能通过想出数据的非正统矩阵表示,以便你可以通过矩阵乘法等方式一次性执行所有这些单一计算。如果你的问题不能通过高效的线性代数运算来解决,你可以开始编写C扩展或使用Cython或使用PyPy(最近刚刚获得了一些基本的numpy支持!)。有无数的优化可能性——真正的问题是:你距离实际解决方案有多近,你需要多少优化,你愿意投入多少努力。

免责声明:我在我的博士学位中也使用了scipy/numpy进行非规范的成对距离处理。对于一种特定的距离度量,我最终在简单的Python中编写了“成对”部分(即我也使用了双重嵌套循环),但是花费了一些精力使主体尽可能高效(通过 i)我的问题的加密矩阵乘法表示和 ii)使用bottleneck的组合)。


0

您可以像下面这样使用它:

def distance3d (p, q):
    if (p == q).all ():
        return 0

    p0 = p[0:3]
    p1 = p[3:6]
    q0 = q[0:3]
    q1 = q[3:6]

    ...  # Distance computation using the formula above.

print (distance.cdist (List_of_segments, List_of_segments, distance3d))

看起来并没有更快,因为它在内部执行相同的循环。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接