AVX加速下最快的指数函数实现

14
我正在寻找一个在AVX元素(单精度浮点数)上快速处理指数函数的高效近似方法,即__m256 _mm256_exp_ps(__m256 x),不使用SVML。
相对精度应该达到大约1e-6或约20个尾数位(2的20次方中的1个部分)。
如果它是用英特尔内在函数以C风格编写的,我会很高兴。
代码应该可移植(Windows、macOS、Linux、MSVC、ICC、GCC等)。
这类似于Fastest Implementation of Exponential Function Using SSE,但那个问题寻找的是低精度非常快的实现(当前答案提供约1e-3精度)。
此外,这个问题要求AVX / AVX2(和FMA)。但请注意,两个问题的答案都可以轻松地在SSE4__m128或AVX2__m256之间移植,因此未来的读者应根据所需的精度/性能权衡进行选择。

vml 应该没问题:https://bitbucket.org/eschnett/vecmathlib/wiki/Home - Regis Portalez
1
请查看来自avx_mathfun的AVX2优化指数函数。 - wim
1
@Royi 为什么你不能将你的 SEE 和 AVX 函数移动到不同的源文件中,并使用 -msse2 编译其中一个,另一个使用 -mavx 编译呢? - Z boson
1
@Zboson 注意,Agner Fog声称vectormath_exp.h中的exp性能较差,请参见文档的第43页avx_mathfun的优点是它使用类似于Chebyshev逼近多项式而不是VCL使用的Taylor展开。因此,avx_mathfun应该在性能和精度之间有更好的平衡。 - wim
3
据@wim称,GCC只对"double"类型的"exp"函数进行向量化,而不是"float"类型的。奇怪。https://godbolt.org/g/mN14F7 - Z boson
显示剩余25条评论
5个回答

12

exp函数来自avx_mathfun,它结合范围缩减和类似Chebyshev逼近的多项式,利用AVX指令并行计算8个exp值。使用正确的编译器设置,确保在可能的情况下将addpsmulps融合为FMA指令。

很容易将原始的exp代码从avx_mathfun调整为可移植(适用于不同编译器)的C / AVX2内部代码。原始代码使用gcc风格的对齐属性和巧妙的宏。修改后的代码使用标准的_mm256_set1_ps()。以下是小型测试代码和表格下的修改后的代码。修改后的代码需要AVX2。

以下代码用于简单测试:

int main(){
    int i;
    float xv[8];
    float yv[8];
    __m256 x = _mm256_setr_ps(1.0f, 2.0f, 3.0f ,4.0f ,5.0f, 6.0f, 7.0f, 8.0f);
    __m256 y = exp256_ps(x);
    _mm256_store_ps(xv,x);
    _mm256_store_ps(yv,y);

    for (i=0;i<8;i++){
        printf("i = %i, x = %e, y = %e \n",i,xv[i],yv[i]);
    }
    return 0;
}

输出看起来没问题:

i = 0, x = 1.000000e+00, y = 2.718282e+00 
i = 1, x = 2.000000e+00, y = 7.389056e+00 
i = 2, x = 3.000000e+00, y = 2.008554e+01 
i = 3, x = 4.000000e+00, y = 5.459815e+01 
i = 4, x = 5.000000e+00, y = 1.484132e+02 
i = 5, x = 6.000000e+00, y = 4.034288e+02 
i = 6, x = 7.000000e+00, y = 1.096633e+03 
i = 7, x = 8.000000e+00, y = 2.980958e+03 
修改后的代码(AVX2)为:
#include <stdio.h>
#include <immintrin.h>
/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell  expc.c    */

__m256 exp256_ps(__m256 x) {
/* Modified code. The original code is here: https://github.com/reyoung/avx_mathfun

   AVX implementation of exp
   Based on "sse_mathfun.h", by Julien Pommier
   http://gruntthepeon.free.fr/ssemath/
   Copyright (C) 2012 Giovanni Garberoglio
   Interdisciplinary Laboratory for Computational Science (LISC)
   Fondazione Bruno Kessler and University of Trento
   via Sommarive, 18
   I-38123 Trento (Italy)
  This software is provided 'as-is', without any express or implied
  warranty.  In no event will the authors be held liable for any damages
  arising from the use of this software.
  Permission is granted to anyone to use this software for any purpose,
  including commercial applications, and to alter it and redistribute it
  freely, subject to the following restrictions:
  1. The origin of this software must not be misrepresented; you must not
     claim that you wrote the original software. If you use this software
     in a product, an acknowledgment in the product documentation would be
     appreciated but is not required.
  2. Altered source versions must be plainly marked as such, and must not be
     misrepresented as being the original software.
  3. This notice may not be removed or altered from any source distribution.
  (this is the zlib license)
*/
/* 
  To increase the compatibility across different compilers the original code is
  converted to plain AVX2 intrinsics code without ingenious macro's,
  gcc style alignment attributes etc. The modified code requires AVX2
*/
__m256   exp_hi        = _mm256_set1_ps(88.3762626647949f);
__m256   exp_lo        = _mm256_set1_ps(-88.3762626647949f);

__m256   cephes_LOG2EF = _mm256_set1_ps(1.44269504088896341);
__m256   cephes_exp_C1 = _mm256_set1_ps(0.693359375);
__m256   cephes_exp_C2 = _mm256_set1_ps(-2.12194440e-4);

__m256   cephes_exp_p0 = _mm256_set1_ps(1.9875691500E-4);
__m256   cephes_exp_p1 = _mm256_set1_ps(1.3981999507E-3);
__m256   cephes_exp_p2 = _mm256_set1_ps(8.3334519073E-3);
__m256   cephes_exp_p3 = _mm256_set1_ps(4.1665795894E-2);
__m256   cephes_exp_p4 = _mm256_set1_ps(1.6666665459E-1);
__m256   cephes_exp_p5 = _mm256_set1_ps(5.0000001201E-1);
__m256   tmp           = _mm256_setzero_ps(), fx;
__m256i  imm0;
__m256   one           = _mm256_set1_ps(1.0f);

        x     = _mm256_min_ps(x, exp_hi);
        x     = _mm256_max_ps(x, exp_lo);

  /* express exp(x) as exp(g + n*log(2)) */
        fx    = _mm256_mul_ps(x, cephes_LOG2EF);
        fx    = _mm256_add_ps(fx, _mm256_set1_ps(0.5f));
        tmp   = _mm256_floor_ps(fx);
__m256  mask  = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);    
        mask  = _mm256_and_ps(mask, one);
        fx    = _mm256_sub_ps(tmp, mask);
        tmp   = _mm256_mul_ps(fx, cephes_exp_C1);
__m256  z     = _mm256_mul_ps(fx, cephes_exp_C2);
        x     = _mm256_sub_ps(x, tmp);
        x     = _mm256_sub_ps(x, z);
        z     = _mm256_mul_ps(x,x);

__m256  y     = cephes_exp_p0;
        y     = _mm256_mul_ps(y, x);
        y     = _mm256_add_ps(y, cephes_exp_p1);
        y     = _mm256_mul_ps(y, x);
        y     = _mm256_add_ps(y, cephes_exp_p2);
        y     = _mm256_mul_ps(y, x);
        y     = _mm256_add_ps(y, cephes_exp_p3);
        y     = _mm256_mul_ps(y, x);
        y     = _mm256_add_ps(y, cephes_exp_p4);
        y     = _mm256_mul_ps(y, x);
        y     = _mm256_add_ps(y, cephes_exp_p5);
        y     = _mm256_mul_ps(y, z);
        y     = _mm256_add_ps(y, x);
        y     = _mm256_add_ps(y, one);

  /* build 2^n */
        imm0  = _mm256_cvttps_epi32(fx);
        imm0  = _mm256_add_epi32(imm0, _mm256_set1_epi32(0x7f));
        imm0  = _mm256_slli_epi32(imm0, 23);
__m256  pow2n = _mm256_castsi256_ps(imm0);
        y     = _mm256_mul_ps(y, pow2n);
        return y;
}

int main(){
    int i;
    float xv[8];
    float yv[8];
    __m256 x = _mm256_setr_ps(1.0f, 2.0f, 3.0f ,4.0f ,5.0f, 6.0f, 7.0f, 8.0f);
    __m256 y = exp256_ps(x);
    _mm256_store_ps(xv,x);
    _mm256_store_ps(yv,y);

    for (i=0;i<8;i++){
        printf("i = %i, x = %e, y = %e \n",i,xv[i],yv[i]);
    }
    return 0;
}


正如@Peter Cordes指出的, 可以将_mm256_floor_ps(fx + 0.5f)替换为_mm256_round_ps(fx)。 此外,mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);和接下来的两行代码似乎是多余的。 通过将cephes_exp_C1cephes_exp_C2合并为inv_LOG2EF,可以进一步优化。 这导致以下代码,但尚未经过彻底测试!

#include <stdio.h>
#include <immintrin.h>
#include <math.h>
/*    gcc -O3 -m64 -Wall -mavx2 -march=broadwell  expc.c -lm     */

__m256 exp256_ps(__m256 x) {
/* Modified code from this source: https://github.com/reyoung/avx_mathfun

   AVX implementation of exp
   Based on "sse_mathfun.h", by Julien Pommier
   http://gruntthepeon.free.fr/ssemath/
   Copyright (C) 2012 Giovanni Garberoglio
   Interdisciplinary Laboratory for Computational Science (LISC)
   Fondazione Bruno Kessler and University of Trento
   via Sommarive, 18
   I-38123 Trento (Italy)
  This software is provided 'as-is', without any express or implied
  warranty.  In no event will the authors be held liable for any damages
  arising from the use of this software.
  Permission is granted to anyone to use this software for any purpose,
  including commercial applications, and to alter it and redistribute it
  freely, subject to the following restrictions:
  1. The origin of this software must not be misrepresented; you must not
     claim that you wrote the original software. If you use this software
     in a product, an acknowledgment in the product documentation would be
     appreciated but is not required.
  2. Altered source versions must be plainly marked as such, and must not be
     misrepresented as being the original software.
  3. This notice may not be removed or altered from any source distribution.
  (this is the zlib license)

*/
/* 
  To increase the compatibility across different compilers the original code is
  converted to plain AVX2 intrinsics code without ingenious macro's,
  gcc style alignment attributes etc.
  Moreover, the part "express exp(x) as exp(g + n*log(2))" has been significantly simplified.
  This modified code is not thoroughly tested!
*/


__m256   exp_hi        = _mm256_set1_ps(88.3762626647949f);
__m256   exp_lo        = _mm256_set1_ps(-88.3762626647949f);

__m256   cephes_LOG2EF = _mm256_set1_ps(1.44269504088896341f);
__m256   inv_LOG2EF    = _mm256_set1_ps(0.693147180559945f);

__m256   cephes_exp_p0 = _mm256_set1_ps(1.9875691500E-4);
__m256   cephes_exp_p1 = _mm256_set1_ps(1.3981999507E-3);
__m256   cephes_exp_p2 = _mm256_set1_ps(8.3334519073E-3);
__m256   cephes_exp_p3 = _mm256_set1_ps(4.1665795894E-2);
__m256   cephes_exp_p4 = _mm256_set1_ps(1.6666665459E-1);
__m256   cephes_exp_p5 = _mm256_set1_ps(5.0000001201E-1);
__m256   fx;
__m256i  imm0;
__m256   one           = _mm256_set1_ps(1.0f);

        x     = _mm256_min_ps(x, exp_hi);
        x     = _mm256_max_ps(x, exp_lo);

  /* express exp(x) as exp(g + n*log(2)) */
        fx     = _mm256_mul_ps(x, cephes_LOG2EF);
        fx     = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
__m256  z      = _mm256_mul_ps(fx, inv_LOG2EF);
        x      = _mm256_sub_ps(x, z);
        z      = _mm256_mul_ps(x,x);

__m256  y      = cephes_exp_p0;
        y      = _mm256_mul_ps(y, x);
        y      = _mm256_add_ps(y, cephes_exp_p1);
        y      = _mm256_mul_ps(y, x);
        y      = _mm256_add_ps(y, cephes_exp_p2);
        y      = _mm256_mul_ps(y, x);
        y      = _mm256_add_ps(y, cephes_exp_p3);
        y      = _mm256_mul_ps(y, x);
        y      = _mm256_add_ps(y, cephes_exp_p4);
        y      = _mm256_mul_ps(y, x);
        y      = _mm256_add_ps(y, cephes_exp_p5);
        y      = _mm256_mul_ps(y, z);
        y      = _mm256_add_ps(y, x);
        y      = _mm256_add_ps(y, one);

  /* build 2^n */
        imm0   = _mm256_cvttps_epi32(fx);
        imm0   = _mm256_add_epi32(imm0, _mm256_set1_epi32(0x7f));
        imm0   = _mm256_slli_epi32(imm0, 23);
__m256  pow2n  = _mm256_castsi256_ps(imm0);
        y      = _mm256_mul_ps(y, pow2n);
        return y;
}

int main(){
    int i;
    float xv[8];
    float yv[8];
    __m256 x = _mm256_setr_ps(11.0f, -12.0f, 13.0f ,-14.0f ,15.0f, -16.0f, 17.0f, -18.0f);
    __m256 y = exp256_ps(x);
    _mm256_store_ps(xv,x);
    _mm256_store_ps(yv,y);

 /* compare exp256_ps with the double precision exp from math.h, 
    print the relative error             */
    printf("i      x                     y = exp256_ps(x)      double precision exp        relative error\n\n");
    for (i=0;i<8;i++){ 
        printf("i = %i  x =%16.9e   y =%16.9e   exp_dbl =%16.9e   rel_err =%16.9e\n",
           i,xv[i],yv[i],exp((double)(xv[i])),
           ((double)(yv[i])-exp((double)(xv[i])))/exp((double)(xv[i])) );
    }
    return 0;
}

通过将exp256_ps与math.h中的双精度exp进行比较,下表给出了某些点上的准确性印象。相对误差在最后一列中。

i      x                     y = exp256_ps(x)      double precision exp        relative error

i = 0  x = 1.000000000e+00   y = 2.718281746e+00   exp_dbl = 2.718281828e+00   rel_err =-3.036785947e-08
i = 1  x =-2.000000000e+00   y = 1.353352815e-01   exp_dbl = 1.353352832e-01   rel_err =-1.289636419e-08
i = 2  x = 3.000000000e+00   y = 2.008553696e+01   exp_dbl = 2.008553692e+01   rel_err = 1.672817689e-09
i = 3  x =-4.000000000e+00   y = 1.831563935e-02   exp_dbl = 1.831563889e-02   rel_err = 2.501162103e-08
i = 4  x = 5.000000000e+00   y = 1.484131622e+02   exp_dbl = 1.484131591e+02   rel_err = 2.108215155e-08
i = 5  x =-6.000000000e+00   y = 2.478752285e-03   exp_dbl = 2.478752177e-03   rel_err = 4.380257261e-08
i = 6  x = 7.000000000e+00   y = 1.096633179e+03   exp_dbl = 1.096633158e+03   rel_err = 1.849522682e-08
i = 7  x =-8.000000000e+00   y = 3.354626242e-04   exp_dbl = 3.354626279e-04   rel_err =-1.101575118e-08

1
无论如何,OP希望找到一些快速的解决方案,相对误差高达1e-6是可以接受的,这与需要所有23位有效数字正确相距甚远,因此应该放弃额外精度技巧(如果有)。除非存在在指数步骤之间的截止区域或其他危险区域。 - Peter Cordes
2
@PeterCordes 我明白了。确实,0.693359375 的浮点表示以 15 个零位结尾。因此,将 0.693359375 乘以一个小整数应该是准确的,而这个大/小分割可能有助于提高精度。 - wim
2
那可能就是这样了。当然,如果FMA可用,您在临时变量中根本不需要舍入。 :) - Peter Cordes
@Royi 谢谢。请注意,优化代码的先前版本仍包含行 fx = _mm256_add_ps(fx, _mm256_set1_ps(0.5f));,这是/曾经是错误的。当前答案应该是正确的。 - wim
我相当确定 _mm256_slli_epi32 需要 AVX2。在回答中特别指出这一点可能是值得的。 - njuffa
显示剩余12条评论

10
由于快速计算exp()需要操作IEEE-754浮点数操作数的指数字段,因此AVX不适用于此计算,因为它缺乏整数运算。因此,我将重点放在AVX2上。融合乘加支持技术上是与AVX2分开的功能,因此我提供两个代码路径,一个使用FMA,一个不使用,由宏USE_FMA控制。
下面的代码将exp()计算到接近所需的10-6的精度。在这里使用FMA并没有提供任何显着的改进,但它应该在支持它的平台上提供性能优势。
在先前的答案中使用的算法用于较低精度的SSE实现,不能完全扩展到相当准确的实现,因为它包含一些具有较差数值特性的计算,在那种情况下并不重要。与其计算ex= 2i* 2f,其中f在[0,1]或f在[-½,½],计算ex= 2i* ef,其中f在更窄的区间[-½log 2,½log 2]中是有优势的,其中log表示自然对数。
为此,我们首先计算i = rint(x * log2(e)),然后计算f = x - log(2) * i。重要的是,后面的计算需要使用比本机精度更高的精度,以提供传递给核心近似的准确缩小参数。为此,我们使用Cody-Waite方案,该方案最初发表于W.J. Cody和W.Waite的“基本函数软件手册”,Prentice Hall 1980年版。常数log(2)分成具有足够尾随零位的“高”部分和具有远小于数学常数的“低”部分,该常数保持两者之间的差异。
选择具有足够尾随零位的“高”部分,使得i与“高”部分的乘积在本机精度中可以完全表示。在这里,我选择了具有八个尾随零位的“高”部分,因为i肯定适合八位。

本质上,我们计算 f = x - i * log(2)high - i * log(2)low。这个简化后的参数被传递到核心近似函数中,该函数是一个极小值最大化逼近算法多项式,结果乘以 2i,与前面的答案相同。

#include <immintrin.h>

#define USE_FMA 0

/* compute exp(x) for x in [-87.33654f, 88.72283] 
   maximum relative error: 3.1575e-6 (USE_FMA = 0); 3.1533e-6 (USE_FMA = 1)
*/
__m256 faster_more_accurate_exp_avx2 (__m256 x)
{
    __m256 t, f, p, r;
    __m256i i, j;

    const __m256 l2e = _mm256_set1_ps (1.442695041f); /* log2(e) */
    const __m256 l2h = _mm256_set1_ps (-6.93145752e-1f); /* -log(2)_hi */
    const __m256 l2l = _mm256_set1_ps (-1.42860677e-6f); /* -log(2)_lo */
    /* coefficients for core approximation to exp() in [-log(2)/2, log(2)/2] */
    const __m256 c0 =  _mm256_set1_ps (0.041944388f);
    const __m256 c1 =  _mm256_set1_ps (0.168006673f);
    const __m256 c2 =  _mm256_set1_ps (0.499999940f);
    const __m256 c3 =  _mm256_set1_ps (0.999956906f);
    const __m256 c4 =  _mm256_set1_ps (0.999999642f);

    /* exp(x) = 2^i * e^f; i = rint (log2(e) * x), f = x - log(2) * i */
    t = _mm256_mul_ps (x, l2e);      /* t = log2(e) * x */
    r = _mm256_round_ps (t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); /* r = rint (t) */

#if USE_FMA
    f = _mm256_fmadd_ps (r, l2h, x); /* x - log(2)_hi * r */
    f = _mm256_fmadd_ps (r, l2l, f); /* f = x - log(2)_hi * r - log(2)_lo * r */
#else // USE_FMA
    p = _mm256_mul_ps (r, l2h);      /* log(2)_hi * r */
    f = _mm256_add_ps (x, p);        /* x - log(2)_hi * r */
    p = _mm256_mul_ps (r, l2l);      /* log(2)_lo * r */
    f = _mm256_add_ps (f, p);        /* f = x - log(2)_hi * r - log(2)_lo * r */
#endif // USE_FMA

    i = _mm256_cvtps_epi32(t);       /* i = (int)rint(t) */

    /* p ~= exp (f), -log(2)/2 <= f <= log(2)/2 */
    p = c0;                          /* c0 */
#if USE_FMA
    p = _mm256_fmadd_ps (p, f, c1);  /* c0*f+c1 */
    p = _mm256_fmadd_ps (p, f, c2);  /* (c0*f+c1)*f+c2 */
    p = _mm256_fmadd_ps (p, f, c3);  /* ((c0*f+c1)*f+c2)*f+c3 */
    p = _mm256_fmadd_ps (p, f, c4);  /* (((c0*f+c1)*f+c2)*f+c3)*f+c4 ~= exp(f) */
#else // USE_FMA
    p = _mm256_mul_ps (p, f);        /* c0*f */
    p = _mm256_add_ps (p, c1);       /* c0*f+c1 */
    p = _mm256_mul_ps (p, f);        /* (c0*f+c1)*f */
    p = _mm256_add_ps (p, c2);       /* (c0*f+c1)*f+c2 */
    p = _mm256_mul_ps (p, f);        /* ((c0*f+c1)*f+c2)*f */
    p = _mm256_add_ps (p, c3);       /* ((c0*f+c1)*f+c2)*f+c3 */
    p = _mm256_mul_ps (p, f);        /* (((c0*f+c1)*f+c2)*f+c3)*f */
    p = _mm256_add_ps (p, c4);       /* (((c0*f+c1)*f+c2)*f+c3)*f+c4 ~= exp(f) */
#endif // USE_FMA

    /* exp(x) = 2^i * p */
    j = _mm256_slli_epi32 (i, 23); /* i << 23 */
    r = _mm256_castsi256_ps (_mm256_add_epi32 (j, _mm256_castps_si256 (p))); /* r = p * 2^i */

    return r;
}

如果需要更高的准确度,可以将多项式逼近度数提高一级,使用以下系数:

/* maximum relative error: 1.7428e-7 (USE_FMA = 0); 1.6586e-7 (USE_FMA = 1) */
const __m256 c0 =  _mm256_set1_ps (0.008301110f);
const __m256 c1 =  _mm256_set1_ps (0.041906696f);
const __m256 c2 =  _mm256_set1_ps (0.166674897f);
const __m256 c3 =  _mm256_set1_ps (0.499990642f);
const __m256 c4 =  _mm256_set1_ps (0.999999762f);
const __m256 c5 =  _mm256_set1_ps (1.000000000f);

这是一个低次多项式的相当精确的近似值!通过对我的答案中的代码进行一些实验,我发现当不需要高精度时,特别是在使用FMA计算减少的参数时,Cody-Waite技巧并不总是必要的。仅使用AVX2(无FMA),在[-84,84]范围内的所有浮点数x上,第二个exp256_ps(x)函数(在我的答案底部)具有最大相对误差为4.1e-6。启用编译器的FMA后,x在[-84,84]范围内的最大相对误差为3.0e-7。 - wim
@wim 我使用了极小化近似,这可能与其他代码中使用的近似不同。我同意使用FMA有时可以使Cody-Waite技巧不必要。我仍在尝试使用FMA增强版本的上述代码。初步迹象表明,在这里它似乎并没有带来更多的精度。 - njuffa
1
i = (int)r 这个注释与代码不符:cvtps_epi32 是一个四舍五入的转换,而不是 C 的向零截断(那应该是 cvttps_epi32,多了一个 t 表示 truncate)。C 没有任何好的方法来表示四舍五入并转换为 intlrint 返回一个 long),但最准确的表示操作的方式是 (int) rint(r)。哦,我看到 rroundps 的结果。通过使用 _mm256_cvtps_epi32(t) 而不是 r 来缩短 dep 链(但不是关键路径)。 - Peter Cordes
@PeterCordes 我同意你的观察结果。我不确定微小的ILP增加是否会导致实际性能上的差异,因为这是如注释中所述的那样偏离关键路径;我也不确定来自该更改的端口使用潜在差异对性能的影响(我对多个x86 CPU世代的微妙细节远没有你了解得多 :-) 如果性能数据显示使用 _mm256_cvtps_epi32(t) 更优,我很乐意应用这个更改。 - njuffa
它给CPU提供了在同一端口上处理FP加法和舍入指令的空闲周期时更早运行它的选项。这可能会导致更多或更少的资源冲突(其中转换从关键路径中窃取一个周期)。嗯,在Haswell / Skylake上,vroundps是2个uops(对于p1 / p01),因此将其转换为整数并返回与Haswell +上的实际vroundps相同的延迟和吞吐量。以这种方式执行可以免费获得四舍五入的整数结果!(我认为在Bulldozer / Ryzen和Sandybridge上也是如此) - Peter Cordes

6

我很喜欢尝试各种方法,最终发现了这个相对准确率约为~1-07e的算法,并且转换为向量指令非常简单。 该算法仅需4个常数、5次乘法和1次除法,比内置的exp()函数快两倍。

float fast_exp(float x)
{
    const float c1 = 0.007972914726F;
    const float c2 = 0.1385283768F;
    const float c3 = 2.885390043F;
    const float c4 = 1.442695022F;      
    x *= c4; //convert to 2^(x)
    int intPart = (int)x;
    x -= intPart;
    float xx = x * x;
    float a = x + c1 * xx * x;
    float b = c3 + c2 * xx;
    float res = (b + a) / (b - a);
    reinterpret_cast<int &>(res) += intPart << 23; // res *= 2^(intPart)
    return res;
}

转换为AVX1(更新)


__m256 _mm256_exp_ps(__m256 _x)
{
    __m256 c1 = _mm256_set1_ps(0.007972914726F);
    __m256 c2 = _mm256_set1_ps(0.1385283768F);
    __m256 c3 = _mm256_set1_ps(2.885390043F);
    __m256 c4 = _mm256_set1_ps(1.442695022F);
    __m256 x = _mm256_mul_ps(_x, c4); //convert to 2^(x)
    __m256 intPartf = _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
    x = _mm256_sub_ps(x, intPartf);
    __m256 xx = _mm256_mul_ps(x, x);
    __m256 a = _mm256_add_ps(x, _mm256_mul_ps(c1, _mm256_mul_ps(xx, x))); //can be improved with FMA
    __m256 b = _mm256_add_ps(c3, _mm256_mul_ps(c2, xx));
    __m256 res = _mm256_div_ps(_mm256_add_ps(b, a), _mm256_sub_ps(b, a));
    __m256i intPart = _mm256_cvtps_epi32(intPartf); //res = 2^intPart. Can be improved with AVX2!
    __m128i ii0 = _mm_slli_epi32(_mm256_castsi256_si128(intPart), 23);
    __m128i ii1 = _mm_slli_epi32(_mm256_extractf128_si256(intPart, 1), 23);     
    __m128i res_0 = _mm_add_epi32(ii0, _mm256_castsi256_si128(_mm256_castps_si256(res)));
    __m128i res_1 = _mm_add_epi32(ii1, _mm256_extractf128_si256(_mm256_castps_si256(res), 1));
    return _mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(res_0)), _mm_castsi128_ps(res_1), 1);
}

一种AVX2版本可以使用_mm256_slli_epi32(intPart,23)等,而无需将其拆分为128位整数。并且可以手动使用_mm256_fmadd_ps 来处理多项式的一部分;并非所有编译器都默认缩减,特别是不跨语句缩减。 几乎所有具有AVX2的CPU都具有FMA功能,唯一的例外是一种不常用的Via型号。

1
“_mm256_insertf128_ps(_mm256_setzero_ps(), _mm_castsi128_ps(res_0), 0)”是多余的;在您已经要插入新的高半部分时,您不需要通过插入到零向量中进行零扩展。只需强制转换即可。另外,大多数现代CPU也具有AVX2,因此您根本不需要解包为128位,除非您需要在Bulldozer系列和SnB/IvB上运行。Haswell及更高版本以及Ryzen(甚至Excavator APU)都具有AVX2。或者低功耗Intel(Silvermont系列)甚至没有AVX,现代Pentium/Celeron芯片也没有。因此,确实存在一些仅支持AVX1的CPU,但其越来越少了。” - Peter Cordes
1
这是相当优雅的方法。干得好!我没有看到 ~1e-7 对于超过 ±6.5 的幂数成立,但在 ±88 之外下降是线性的,并且不会变得比 ~1.3e-6 更糟。出于好奇,你用什么进行系数优化?几千次 BFGS 运行为我找到了不同的权衡,但没有任何更好的结果。 - Todd West
1
用这个公式优化系数: C[i] -= rate * D[i]; C- 系数 D- 导数 而"rate"也是动态的: if (cur_err < prev_err) { rate *= inc_ratio; dec_ratio = 0.7; inc_ratio *= inc_ratio; inc_ratio = Math.Min(inc_ratio, 10); } 如果(cur_err > prev_err) { inc_ratio = 1.2; rate *= dec_ratio; dec_ratio *= dec_ratio; dec_ratio = Math.Max(dec_ratio, 0.001); } - jenkas
1
还有一个重要的提示,您应该针对公式2^x进行系数优化,其中x的范围为-1...+1,因为幂的INT部分通过将其直接添加到浮点数的指数部分来计算,所以对于所有幂它都非常精确。 - jenkas
谢谢,@jenkas。我认为这里的分类取决于周围的逻辑,但我倾向于将其归类为某种梯度下降,尽管可能与某些Nelder-Mead实现的某些方面有重叠,也可能与粒子群有关。 (查阅一些优化文本可能会很有趣,例如http://users.iems.northwestern.edu/~nocedal/book/index.html等。) - Todd West
显示剩余2条评论

0

你可以使用泰勒级数自行近似指数

exp(z) = 1 + z + pow(z,2)/2 + pow(z,3)/6 + pow(z,4)/24 + ...

为此,您只需要从AVX中使用加法和乘法运算。像1/2、1/6、1/24等系数如果硬编码然后乘以比除以更快。

根据您的精度取尽可能多的序列成员。请注意,您将获得相对误差:对于小的z,绝对误差可能是1e-6,但对于大的z,绝对误差将超过1e-6,仍然abs(E-E1)/abs(E) - 11e-6小(其中E是精确指数,E1是您使用近似值得到的指数)。

更新:正如@Peter Cordes在评论中提到的那样,可以通过分离整数和小数部分的幂运算,通过操作二进制float表示的指数字段(基于2^x而不是e^x)来处理整数部分。然后,您的泰勒级数只需在一个小范围内最小化误差。


1
实际上,除非z非常接近于零(!),否则这种方法效率不高。此外,对于负值的z,可能会出现严重的精度问题。高效准确的exp逼近应该使用至少一种范围缩减形式。对于函数逼近,通常使用(有理)Chebyshev逼近而不是Taylor级数,因为它们具有更好的数值特性。 - wim
2
@Royi 只需尝试使用avx_mathfun,看看它是否适用于您的应用程序。令人惊讶的是,它使用范围缩减和简单的泰勒逼近! - wim
@RegisPortalez,你不会得到足够大的N,因为通常只使用前几个成员进行近似。然而,浮点数方法允许很好地表示接近0的数字。 - Serge Rogatch
1
不要使用宏来设置常量,而是使用标准的 _mm256_set1_ps()_mm256_set1_epi32() 等函数。 - wim
2
为了使泰勒展开不失效,您应该仅将其用于输入的小数部分。将整数部分放入float的指数字段中以获得2^x。(在某个地方进行额外的乘法可以解决2^x与e^x之间的差异,我想)。对于仅针对小数部分的多项式,您只需要在更小的范围内最小化误差。这是您用于log(x)的相同技巧的反向操作:提取输入的指数以获取log2(integer_part(x))。 - Peter Cordes
显示剩余7条评论

0

对于标准化的输入([-1,1]),您可以使用多项式逼近:

// compute Simd exp() at a time (only optimized for Type=float)
template<typename Type, int Simd>
inline
void expFast(float * const __restrict__ data, float * const __restrict__ result) noexcept
{

    alignas(64)
    Type resultData[Simd];

    
    for(int i=0;i<Simd;i++)
    {
        resultData[i] =    Type(0.0001972591916103993980868836)*data[i] + Type(0.001433947376170863208244555);
    }

    
    for(int i=0;i<Simd;i++)
    {
        resultData[i] =    resultData[i]*data[i] + Type(0.008338950118885968265658448);
    }

    
    for(int i=0;i<Simd;i++)
    {
        resultData[i] =    resultData[i]*data[i] + Type(0.04164162895364054151059463);
    }

    
    for(int i=0;i<Simd;i++)
    {
        resultData[i] =    resultData[i]*data[i] + Type(0.1666645212581130408580066);
    }

    
    for(int i=0;i<Simd;i++)
    {
        resultData[i] =    resultData[i]*data[i] + Type(0.5000045184212300597437206);
    }

    
    for(int i=0;i<Simd;i++)
    {
        resultData[i] =    resultData[i]*data[i] + Type(0.9999999756072401879691824);
    }

    
    for(int i=0;i<Simd;i++)
    {
        result[i] =    resultData[i]*data[i] + Type(0.999999818912344906607359);
    }

}

它的平均误差为0.5 ULPS,最大误差为10 ULPS,在-1和1之间选取了6400万个点。在AVX1(bulldozer)上,与std::exp相比,速度提高了10倍。

我认为您可以将此函数与整数乘法相结合以支持所有幂。但是简单的乘法部分需要为O(logN)而不是O(N),以便足够快地处理大幂次。例如,如果计算了x10,那么只需用自身执行1次操作即可获得x20,而不需要通过乘以x进行10次额外操作。

在循环中使用时,编译器会生成以下代码:

.L2:
    vmovaps zmm1, ZMMWORD PTR [rax]
    add     rax, 64
    vmovaps zmm0, zmm1
    vfmadd132ps     zmm0, zmm8, zmm9
    vfmadd132ps     zmm0, zmm7, zmm1
    vfmadd132ps     zmm0, zmm6, zmm1
    vfmadd132ps     zmm0, zmm5, zmm1
    vfmadd132ps     zmm0, zmm4, zmm1
    vfmadd132ps     zmm0, zmm3, zmm1
    vfmadd132ps     zmm0, zmm2, zmm1
    vmovaps ZMMWORD PTR [rax-64], zmm0
    cmp     rax, rdx
    jne     .L2

我认为它足够快,可以节省一些循环来处理输入的整数幂,可能高达浮点数的极限(1038)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接