计算距离平方的最快方法

14

我的代码在计算3D空间中两点之间的距离时非常依赖于计算。 为了避免昂贵的平方根,我使用了平方距离。 但它仍然占据了计算时间的主要部分,我想用更快的函数替换我的简单函数。 现在我的代码如下:

double distance_squared(double *a, double *b)
{
  double dx = a[0] - b[0];
  double dy = a[1] - b[1];
  double dz = a[2] - b[2];

  return dx*dx + dy*dy + dz*dz;
}

我也尝试使用宏来避免函数调用,但它并没有太大帮助。

#define DISTANCE_SQUARED(a, b) ((a)[0]-(b)[0])*((a)[0]-(b)[0]) + ((a)[1]-(b)[1])*((a)[1]-(b)[1]) + ((a)[2]-(b)[2])*((a)[2]-(b)[2])

我考虑使用SIMD指令,但找不到一个好的示例或完整的指令列表(理想情况下是一些矢量乘加)。

由于每次函数调用只知道一组点,因此GPU不是一个选择。

计算距离平方的最快方法是什么?


注:SIMD指单指令多数据流技术(Single Instruction Multiple Data),可以提高程序运行速度。

1
如果你的点是静止不动的,预计算(所有?)距离对可能会有益处。 - schot
3
你需要在其他地方进行优化,这里应该已经是最优的了(正如David所说)。也许你可以提供更广阔的问题视角,也许你不必重新计算所有内容或使用其他技巧。 - duedl0r
你想要计算什么?显然这是某个算法的底部,但你是在尝试找到一组n个点中两个点之间的最短距离还是其他什么。如果你描述一下你实际想要计算的内容,可能会有替代算法而不是蛮力算法。 - Martin York
这些点不是静态的,因此无法进行预计算。我已经在考虑其他算法。这不是暴力代码(kdtree实现)的一部分,但仍然需要经常计算,因此优化它将有所帮助。我只是想知道通常情况下最快的方法是什么,因为我找不到任何关于这个主题的信息。 - Pim Schellart
7个回答

11

一个好的编译器将会进行优化,其效果通常是比你手动优化还要好的。如果好的编译器认为使用SIMD指令会有益,则会使用这样的指令。请确保为您的编译器打开所有此类可能的优化选项。不幸的是,维度为3的向量往往与SIMD单元不兼容。

我怀疑您只能接受编译器生成的代码已经非常接近最优了,并且不会有明显的提升。


4
确切地说,使用gcc和选项“-O3 -march=native”,通常可以生成不错的代码。在优化前,请使用“-S”检查汇编代码以避免浪费大量时间。 - Jens Gustedt

8
第一件明显的事情就是使用restrict关键字。目前来看,ab是可别名的(因此,从编译器的角度来看,它们是别名)。没有任何编译器会自动将其矢量化,因为这样做是错误的。更糟糕的是,不仅编译器不能将这样的循环矢量化,在你存储这些值的情况下(幸运的是在你的例子中没有),还必须每次重新加载值。始终要清楚地了解别名,因为它会极大地影响编译器。
接下来,如果您可以接受这一点,请使用float而不是double并填充到4个浮点数,即使其中一个未使用,这是大多数CPU的更“自然”的数据布局(这在某种程度上是特定于平台的,但对于大多数平台来说,4个浮点数是一个好的选择 - 3个Double,即1.5 SIMD寄存器,在任何地方都不是最优的)。
(对于手写的SIMD实现(比您想象的要难),首先确保具有对齐的数据。接下来,查看目标机器上指令的延迟,并首先执行最长的指令。例如,在Pre-Prescott Intel上,首先将每个组件洗牌到一个寄存器中,然后与自身相乘是有意义的,即使使用了3个乘法,因为Shuffle具有很长的延迟。 在后面的模型中,洗牌需要一个周期,因此这将是一个完全的反优化。再次表明,将其留给编译器并不是一个坏主意。)

为什么别名将阻止矢量化? - avakar
4
我认为restrict关键字在这里并没有什么帮助,因为据我所知,函数体内没有任何受益于额外别名信息的内容;如果你对代码进行了积极的注释,应该将参数标记为const double * restrict,因为它们从未被修改... - Christoph
向量化意味着一次读写多个独立的值,并在一个操作中处理它们。编译器只能在独立数据上执行此操作,如果没有其他指针/引用或者它可以毫无疑问地证明在同一范围内没有其他指针/引用访问任何数据(有时可能会发生),或者如果程序员明确“承诺”(restrict关键字)他已经考虑到了这一点并且“不会发生”。在任何其他情况下,这是不安全的,编译器通常会拒绝执行可能导致不正确结果的操作。 - Damon
正如克里斯多夫所指出的那样,在此函数中数据从未被修改过,尽管聪明的编译器可能确实会想出并且能够证明向量化是安全的(如果没有其他阻碍),但我不会打赌我的右手。无论如何,声明一个未别名化的指针restrict是正确的(声明常量指针const也是如此)。在声明上精确无误不仅仅是装饰性的,它也是一种“代码正确性”。 - Damon
你需要一个聪明的编译器才能从“restrict”中受益。该关键字通过消除某些优化障碍来发挥作用,而这些障碍在此处根本不存在。 - MSalters

4
这是使用SSE3的SIMD代码来完成此操作的:

movaps xmm0,a
movaps xmm1,b
subps xmm0,xmm1
mulps xmm0,xmm0
haddps xmm0,xmm0
haddps xmm0,xmm0

但是,为了让它正常工作,您需要四个值向量(x、y、z、0)。如果您只有三个值,则需要进行一些调整以获得所需的格式,这将抵消上述任何好处。

总的来说,由于CPU的超标量流水线架构,获得性能最佳的方法是对大量数据执行相同的操作,这样您就可以交错各种步骤并进行一些循环展开以避免管道停顿。基于“不能在直接修改后立即使用值”的原则,上述代码最后三条指令肯定会停顿——第二条指令必须等待前一条指令的结果完成,这在流水线系统中不利。

在同时对两个或更多不同点集的点进行计算时,可以消除上述瓶颈——在等待一个计算的结果时,您可以开始下一个点的计算:

movaps xmm0,a1
                  movaps xmm2,a2
movaps xmm1,b1
                  movaps xmm3,b2
subps xmm0,xmm1
                  subps xmm2,xmm3
mulps xmm0,xmm0
                  mulps xmm2,xmm2
haddps xmm0,xmm0
                  haddps xmm2,xmm2
haddps xmm0,xmm0
                  haddps xmm2,xmm2

你能否添加使用SSE的C代码来处理第一行,假设每个三维点的第四个值为零?谢谢。 - Royi

3

如果您想要优化某个东西,首先请对代码进行分析并检查汇编输出。

在使用gcc -O3(4.6.1)编译后,我们将得到带有SIMD的优美反汇编输出:

movsd   (%rdi), %xmm0
movsd   8(%rdi), %xmm2
subsd   (%rsi), %xmm0
movsd   16(%rdi), %xmm1
subsd   8(%rsi), %xmm2
subsd   16(%rsi), %xmm1
mulsd   %xmm0, %xmm0
mulsd   %xmm2, %xmm2
mulsd   %xmm1, %xmm1
addsd   %xmm2, %xmm0
addsd   %xmm1, %xmm0

2
尽管它使用SSE2指令,但这不是SIMD代码。所有指令都作用于单个值。mulsd表示“标量双精度浮点乘法”,多数据版本为mulpd“打包双精度浮点乘法”。 - Giacomo Verticale
1
我已经检查过了,这也是我得到的答案。因此问题变成:"我该如何编写代码,使得gcc可以将其编译为SIMD代码,还是需要手动编写,如果需要,该如何编写?" - Pim Schellart

1
这种类型的问题通常在分子动力学模拟中经常出现。通常通过截断和邻居列表来减少计算量,从而减少计算次数。然而,实际计算平方距离的过程是完全按照您提出的方式进行的(使用编译器优化和固定类型float[3])。
因此,如果您想减少平方计算的数量,您应该告诉我们更多关于这个问题的信息。

0
也许直接将这6个double作为参数传递可能会更快(因为可以避免数组解引用):
inline double distsquare_coord(double xa, double ya, double za, 
                               double xb, double yb, double zb) 
{ 
  double dx = xa-yb; double dy=ya-yb; double dz=za-zb;
  return dx*dx + dy*dy + dz*dz; 
}

或者,如果您在附近有许多点,则可以通过线性逼近其他附近点的距离来计算到同一固定点的距离。


4
你仍然需要进行数组解引用,只不过是将成本转移到函数外部。 - Martin York
同意@Loki Astari的观点,并且有可能,6个参数无法适应CPU寄存器... - tensai_cirno
在使用Linux的AMD64上,我认为寄存器可以容纳6个参数(至少可以容纳6个整数或指针参数,但我忘记了浮点数)。 - Basile Starynkevitch
而且很可能,调用代码可能已经将这6个值保存在寄存器中了。 - Basile Starynkevitch

0

如果您可以重新排列数据以同时处理两对输入向量,则可以使用此代码(仅限SSE2)

// @brief Computes two squared distances between two pairs of 3D vectors
// @param a
//  Pointer to the first pair of 3D vectors.
//  The two vectors must be stored with stride 24, i.e. (a + 3) should point to the first component of the second vector in the pair.
//  Must be aligned by 16 (2 doubles).
// @param b
//  Pointer to the second pairs of 3D vectors.
//  The two vectors must be stored with stride 24, i.e. (a + 3) should point to the first component of the second vector in the pair.
//  Must be aligned by 16 (2 doubles).
// @param c
//  Pointer to the output 2 element array.
//  Must be aligned by 16 (2 doubles).
//  The two distances between a and b vectors will be written to c[0] and c[1] respectively.
void (const double * __restrict__ a, const double * __restrict__ b, double * __restrict c) {
    // diff0 = ( a0.y - b0.y, a0.x - b0.x ) = ( d0.y, d0.x )
    __m128d diff0 = _mm_sub_pd(_mm_load_pd(a), _mm_load_pd(b));
    // diff1 = ( a1.x - b1.x, a0.z - b0.z ) = ( d1.x, d0.z )
    __m128d diff1 = _mm_sub_pd(_mm_load_pd(a + 2), _mm_load_pd(b + 2));
    // diff2 = ( a1.z - b1.z, a1.y - b1.y ) = ( d1.z, d1.y )
    __m128d diff2 = _mm_sub_pd(_mm_load_pd(a + 4), _mm_load_pd(b + 4));

    // prod0 = ( d0.y * d0.y, d0.x * d0.x )
    __m128d prod0 = _mm_mul_pd(diff0, diff0);
    // prod1 = ( d1.x * d1.x, d0.z * d0.z )
    __m128d prod1 = _mm_mul_pd(diff1, diff1);
    // prod2 = ( d1.z * d1.z, d1.y * d1.y )
    __m128d prod2 = _mm_mul_pd(diff1, diff1);

    // _mm_unpacklo_pd(prod0, prod2) = ( d1.y * d1.y, d0.x * d0.x )
    // psum = ( d1.x * d1.x + d1.y * d1.y, d0.x * d0.x + d0.z * d0.z )
    __m128d psum = _mm_add_pd(_mm_unpacklo_pd(prod0, prod2), prod1);

    // _mm_unpackhi_pd(prod0, prod2) = ( d1.z * d1.z, d0.y * d0.y )
    // dotprod = ( d1.x * d1.x + d1.y * d1.y + d1.z * d1.z, d0.x * d0.x + d0.y * d0.y + d0.z * d0.z )
    __m128d dotprod = _mm_add_pd(_mm_unpackhi_pd(prod0, prod2), psum);

    __mm_store_pd(c, dotprod);
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接