尽可能快地比较形如(a + sqrt(b))的两个值?

45
作为我正在编写的程序的一部分,我需要比较形式为a + sqrt(b)的两个值,其中ab是无符号整数。因为这是紧密循环的一部分,所以我希望这个比较尽可能快地运行。(如果有影响的话,我在x86-64机器上运行代码,并且无符号整数不大于10^6。另外,我知道a1<a2的事实。)
这是我试图优化的独立函数。我的数字是足够小的整数,以至于double(甚至float)可以精确地表示它们,但sqrt的舍入误差结果必须不改变结果。
// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}

测试用例: is_smaller(900000, 1000000, 900001, 998002)应该返回true,但是根据@wim的评论显示,使用sqrtf()将返回false。使用(int)sqrt()截断为整数也会如此。

a1 + sqrt(b1) =90100a2+sqrt(b2)=901000.00050050037512481206。最接近这个结果的浮点数是精确等于90100。


由于即使在完全内联作为sqrtsd指令时,在现代x86-64上sqrt()函数通常也非常昂贵,因此我尽可能地避免调用sqrt()

通过平方来消除平方根,还有可能避免任何四舍五入误差,从而使所有计算都变得准确无误。

如果函数像这样...

bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}

...那么我可以这样做:return x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

但是现在由于有两个sqrt(...)项,我无法进行相同的代数运算。

我可以使用这个公式将值平方两次:

      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2

由于无符号除以4只是一个位移操作,所以它很便宜,但是由于我要将数字平方两次,所以我需要使用128位整数,并且我需要引入一些>=0检查(因为我比较的是不等式而不是等式)。

感觉可能有更好的代数方法来更快地解决这个问题。是否有更快的方法?


8
观察到一点:如果a1+sqrt(b1)<a2成立,那么可以跳过计算sqrt(b2) - 500 - Internal Server Error
4
如果b <= 10^6,则最大值为max(sqrt(b)) = 1000。因此,只有当abs(a1-a2) <= 1000时需要进一步调查。否则,两者之间总是不相等的。 - StPiere
2
@StPiere:如果输入数据分布相对均匀,使用LUT进行sqrt计算在现代x86上将会非常糟糕。4MiB的缓存占用比L2缓存大小(通常为256kiB)要大得多,因此您最多只能获得L3命中,例如Skylake上的45个周期延迟。但即使在非常旧的Core 2上,单精度sqrt的最坏情况下延迟也为29个周期。(还需要几个周期来转换为FP)。在Skylake上,FP sqrt延迟~= L2缓存命中延迟,并且以吞吐量=延迟/4进行流水线处理。更不用说对其他代码的缓存污染的影响了。 - Peter Cordes
3
既然 a1 < a2,那么你可以直接排除所有满足 b1 < b2 条件的情况。 - kvantour
显示剩余18条评论
5个回答

19

这里有一个没有使用sqrt的版本,但我不确定它是否比只有一个sqrt的版本更快(这可能取决于值的分布)。

这是数学公式(如何移除两个平方根):

ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)

这里,右侧总是负数。如果左侧为正数,则我们必须返回true。

如果左侧为负数,则可以平方不等式:

ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2

需要注意的关键点是,如果 a2>=a1+1000,那么is_smaller始终返回true(因为sqrt(b1)的最大值为1000)。如果a2<=a1+1000,那么ad是一个较小的数,因此ad^4总是适合64位(没有必要使用128位算术)。以下是代码:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

编辑:正如Peter Cordes所指出的那样,第一个if是不必要的,因为第二个if已经处理了它,所以代码变得更小更快:

EDIT: 此处有误,第一个if并非必要,因为第二个if语句已经包含了它。这样修改后,代码更加简洁高效。

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

2
最好省略ad>1000分支;我认为ad*ad+bd>0分支覆盖了所有情况。对于分支预测来说,大多数情况下成立的一个分支比两个分支各自有时成立要好。除非ad>1000检查捕获了大部分输入,否则值得做1次额外的subimul。(还有可能是64位的movzx - Peter Cordes
@PeterCordes:很好的观察(像往常一样 :)),谢谢! - geza
哦,你是为了确保ad * ad不会溢出int32_t而这么做的吗?在内联之后,x86-64编译器可以优化掉对64位的零扩展(因为将32位寄存器写入其中隐含地执行了该操作),因此我们可以将无符号输入提升为uint64_t,然后进行减法运算以得到int64_t。(32位的减法需要一个movsxd符号扩展可能为负数的结果,因此要避免使用。) - Peter Cordes
@PeterCordes: 差不多没错 :) 我加了这个,所以 ad^4 肯定不会溢出 64 位。但是,正如你所说,将 ad*ad 乘法作为 64 位轻松处理该情况。 - geza
1
@wim:我已经确认我的解决方案没问题。但是,我发现 Brendan 的第二个版本(也许第一个版本也有)存在一些问题 :) - geza
显示剩余6条评论

4

我有些疲倦,可能会犯错;但我相信如果我真的犯了错,会有人指出来的..

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a1-a2;   // May be negative

    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
//  return a_diff+sqrt(b1) < sqrt(b2);

    temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}

如果您知道 a1 < a2 ,则可以表示为:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a2-a1;    // Will be positive

    if(b1 > b2) {
        return false;
    }
    if(b1 >= a_diff*a_diff) {
        return false;
    }
    temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}

2
我们知道 a1 < a2,所以不需要测试 a_diff < 0。但是也许值得测试它是否大于 1000(反向)。 - Acorn
1
你需要一个有符号差异来定义 int a_diff,如果你想要对其进行平方以检查条件而不进行实际的平方根运算,你需要使用 int64_t - Peter Cordes
嘿 - 我犯了一个错误(如果a_diff + sqrt(b1);为负,则无法对其进行平方而不破坏符号) - 已修复。还添加了“如果您知道a1 < a2”。 - Brendan
@kvantour:你说得对(而且 a_diff = a2 - a1 让它变得简洁了很多)。我有点不明白 b1/a_diff < a_diff 所以我用了另外一种方法,可能搞错了。大约10小时后再回来看看。 - Brendan
@Brendan 请看一下这个评论。另外,您可以先对b进行测试,然后计算a_diff - kvantour
显示剩余3条评论

2
还有一种计算整数平方根的方法是使用牛顿法,如这里所述。另一种方法是通过二分查找来寻找floor(sqrt(n)),因为在10^6以下只有1000个完全平方数。虽然这种方法可能性能不佳,但它是一种有趣的方法。我没有测量过其中任何一个的性能,但以下是示例:
#include <iostream>
#include <array>
#include <algorithm>        // std::lower_bound
#include <cassert>          


bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt(b1) < a2 + sqrt(b2);
}

static std::array<int, 1001> squares;

template <typename C>
void squares_init(C& c)
{
    for (int i = 0; i < c.size(); ++i)
        c[i] = i*i;
}

inline bool greater(const int& l, const int& r)
{
    return r < l;
}

inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    // return a1 + sqrt(b1) < a2 + sqrt(b2)

    // find floor(sqrt(b1)) - binary search withing 1000 elems
    auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();

    // find floor(sqrt(b2)) - binary search withing 1000 elems
    auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();

    return (a2 - a1) > (it_b1 - it_b2);
}

unsigned int sqrt32(unsigned long n)
{
    unsigned int c = 0x8000;
    unsigned int g = 0x8000;

    for (;;) {
        if (g*g > n) {
            g ^= c;
        }

        c >>= 1;

        if (c == 0) {
            return g;
        }

        g |= c;
    }
}

bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}

int main()
{
    squares_init(squares);

    // now can use is_smaller
    assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
    assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
    assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
    assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}

你的整数sqrt32在那个循环中运行了固定的16次迭代,我想。你可以根据位扫描找到n中最高位并除以2来从较小的位置开始。或者只是从一个较低的固定起点开始,因为已知n的最大值为100万,而不是约40亿。所以我们可以节省大约12/2=6次迭代。但这可能仍然比将其转换为单精度float进行sqrtss和返回要慢。也许如果你在整数循环中并行执行两个平方根,那么c更新和循环开销就会被摊销,并且会有2个依赖链。 - Peter Cordes
二分查找反转表是一个有趣的想法,但在现代x86-64上可能仍然很糟糕,因为硬件sqrt并不是非常慢,但相对于具有更短/更简单流水线设计的分支预测错误非常昂贵。也许这个答案中的一些内容对于在微控制器上遇到同样问题的人会有用。 - Peter Cordes

2
我不确定代数操作与整数算术结合是否一定能够导致最快的解决方案。在这种情况下,您将需要许多标量乘法(这并不是很快),和/或者分支预测可能会失败,从而降低性能。显然,您必须进行基准测试,以查看哪种解决方案在您的特定情况下最快。
使sqrt更快的一种方法是向gcc或clang添加-fno-math-errno选项。在这种情况下,编译器无需检查负输入。对于icc来说,这是默认设置。
通过使用矢量化的sqrt指令sqrtpd,而不是标量的sqrt指令sqrtsd,可以实现更高的性能改进。Peter Cordes 已经表明,clang能够自动矢量化此代码,从而生成此sqrtpd
然而,自动向量化的成功程度很大程度上取决于正确的编译器设置和所使用的编译器(如clang、gcc、icc等)。使用-march=nehalem或更早版本的clang不会进行矢量化。
以下内嵌代码提供了更可靠的矢量化结果。为了可移植性,我们只假设支持SSE2,这是x86-64的基线。
/* gcc -m64 -O3 -fno-math-errno smaller.c                      */
/* Adding e.g. -march=nehalem or -march=skylake might further  */
/* improve the generated code                                  */
/* Note that SSE2 in guaranteed to exist with x86-64           */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>

int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    uint64_t a64    =  (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
    uint64_t b64    =  (((uint64_t)b2)<<32) | ((uint64_t)b1); 
    __m128i ax      = _mm_cvtsi64_si128(a64);         /* Move integer from gpr to xmm register                  */
    __m128i bx      = _mm_cvtsi64_si128(b64);         
    __m128d a       = _mm_cvtepi32_pd(ax);            /* Convert 2 integers to double                           */
    __m128d b       = _mm_cvtepi32_pd(bx);            /* We don't need _mm_cvtepu32_pd since a,b < 1e6          */
    __m128d sqrt_b  = _mm_sqrt_pd(b);                 /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction   */
    __m128d sum     = _mm_add_pd(a, sqrt_b);
    __m128d sum_lo  = sum;                            /* a1 + sqrt(b1) in the lower 64 bits                     */
    __m128d sum_hi  =  _mm_unpackhi_pd(sum, sum);     /* a2 + sqrt(b2) in the lower 64 bits                     */
    return _mm_comilt_sd(sum_lo, sum_hi);
}


int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);
}


int main(){
    unsigned a1; unsigned b1; unsigned a2; unsigned b2;
    a1 = 11; b1 = 10; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 11; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 11; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 10; b2 = 11;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));

    return 0;
}


在Intel Skylake上进行简单的吞吐量测试,使用编译器选项gcc -m64 -O3 -fno-math-errno -march=nehalem,我发现is_smaller_v5()的吞吐量比原来的is_smaller()好了2.6倍:包括循环开销在内,分别为6.8个CPU周期和18个CPU周期。然而,在一个(可能过于)简单的延迟测试中,其中输入a1、a2、b1、b2取决于先前的is_smaller(_v5)的结果,我没有看到任何改进。(39.7个周期与39个周期相同)。请参见this Godbolt link以获取生成的汇编代码。

clang已经像这样自动向量化了:P https://godbolt.org/z/GvNe2B看一下double和signed int版本。但仅适用于double,而不是float。对于吞吐量,您应该绝对使用此策略的float,因为打包转换只有1个uop,并且sqrtps具有更好的吞吐量。 OP的数字都是100万或更少,因此可以由float精确表示,其平方根也可以。顺便说一句,看起来您忘记设置-mtune=haswell,因此您的gcc选择了存储/重新加载策略_mm_set_epi32而不是ALUmovd - Peter Cordes
@PeterCordes:单精度不够准确,请参见我的评论此处。我们只知道目标是x86-64。即使使用-march=nehalem,某种情况下clang也无法矢量化。实际上,使用4个movd指令,clang生成更好的汇编代码。 - wim
1
@PeterCordes:请注意,在吞吐量测试中,自动向量化函数可能很容易在端口5上成为瓶颈。如果我数对了的话,Clang 生成 9个p5微操作(Skylake)。 - wim
我没有仔细看。不惊讶它不是最优的。:P 有趣的是,我没有意识到(u)comisd有一个内在函数。当然有道理,但我以前从未注意过。如果您在没有AVX的情况下编译,则应该能够通过将movhlps存储到“死”变量中而不是使用unpckhpd来保存movaps。但是,这需要进行大量转换,因为内在函数使得帮助编译器优化其混洗方式变得不方便。 - Peter Cordes
@PeterCordes 实际上,我以前从来没有使用过(u)comisd内置函数,但在这里它似乎很有用。 - wim

1
也许不比其他答案更好,但使用了不同的思路(和大量的预先分析)。
// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
//   0 <= x <= 784 : x/28
//   784 < x <= 7056 : 21 + x/112
//   7056 < x <= 28224 : 56 + x/252
//   28224 < x <= 78400 : 105 + x/448
//   78400 < x <= 176400 : 168 + x/700
//   176400 < x <= 345744 : 245 + x/1008
//   345744 < x <= 614656 : 336 + x/1372
//   614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return 
        x <= 78400 ? 
            x <= 7056 ?
                x <= 764 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// known pre-conditions: a1 < a2, 
//                  0 <= b1 <= 1000000
//                  0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000, 
//    so is a1 + 1000 < a2 ?  
//    Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
//    a2 + pseudosqrt(b2) <= a2 + sqrt(b2), 
//    so is  a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
//    Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
//    Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
    unsigned ad = a2 - a1;
    return (ad > 1000)
           || (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
           || ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}

(我手边没有编译器,所以可能会有一两个错别字。)


@PeterCordes:除非现在long long操作与unsigned操作一样快,否则你的“much faster”声明将会令人惊讶。但事实似乎并非如此。[https://stackoverflow.com/questions/48779619/why-do-arithmetic-operations-on-long-long-int-take-more-time-than-for-an-int] - Eric Towers
1
除非分支预测完美无缺,否则比较和分支是很昂贵的。在现代x86-64上(如AMD Zen或自Nehalem以来的英特尔),比一个long long乘法(3个时钟周期延迟,1个周期吞吐量)要昂贵得多。对于现代x86-64上更宽的类型,只有除法的成本更高,其他操作不依赖于数据或类型宽度。一些旧的x86-64 CPU(如Bulldozer系列或Silvermont)具有较慢的64位乘法。https://agner.org/optimize/。(当然,我们正在谈论标量;使用SIMD进行自动向量化使得窄类型有价值,因为您可以每个向量做更多事情) - Peter Cordes
@PeterCordes:那么Geza的早期退出是4个算术运算和一个比较,而我的是1个算术运算和一个比较。据我所知,1仍然比4少得多。 - Eric Towers
@PeterCordes:我并不完全相信Geza的数学是完全正确的。可能很容易修复,但我没有深入研究细节。 - wim
1
@wim:是的,无论OP做什么,都需要进行相当详尽的单元测试!(1M ^ 4太昂贵了,因此需要修剪搜索空间以查看一些大值和使不等式两侧几乎相等的一些值。) - Peter Cordes
显示剩余9条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接