优化const非整数幂的pow()函数?

61

我的代码中有一些热点,其中的pow()函数占用了大约10-20%的执行时间。

我传递给pow(x,y)函数的输入非常具体,所以我想知道是否有一种方法可以使用更高效的方法来进行两个pow()函数的近似(每个指数一个):

  • 我有两个固定的指数:2.4和1/2.4。
  • 当指数为2.4时,x将在范围(0.090473935, 1.0]内。
  • 当指数为1/2.4时,x将在范围(0.0031308, 1.0]内。
  • 我正在使用SSE / AVX float向量。如果可以利用平台特定性能技巧,那就更好了!

最理想的最大误差率约为0.01%,虽然我也对完全精度(针对float)的算法感兴趣。

我已经在使用快速的pow()函数近似函数,但它没有考虑到这些限制。是否可能做得更好?


5
需要多准确? - In silico
1
pow是本地CPU指令吗?否则我只能想象pow(a,b)计算exp(b * log(a)),因此你的指数并不那么恒定。好了,看到答案会很棒! - Kerrek SB
3
pow通常不是CPU的本地指令,而是一个库函数,由explog构建,并具有大量特殊情况,例如x^2x^-1等。 - Dietrich Epp
1
鉴于您明显正在进行伽马校正,我假设您对第一个参数可以采用的值范围也有一些先验知识; 它是什么?此外,您的目标平台是什么? - Stephen Canon
1
@CoryNelson,你能把你对于(1/2.4和2.4)的最终解决方案发表为一个答案吗? - Lilith River
显示剩余2条评论
10个回答

34

因为这个回答与我的先前回答非常不同,所以我提供另一个答案,而且这个方法非常快。相对误差为3e-8。想要更高的精度?再添加几个Chebychev项。最好将顺序保持奇数,因为这会在2^n-epsilon和2^n+epsilon之间产生小的不连续性。

#include <stdlib.h>
#include <math.h>

// Returns x^(5/12) for x in [1,2), to within 3e-8 (relative error).
// Want more precision? Add more Chebychev polynomial coefs.
double pow512norm (
   double x)
{
   static const int N = 8;

   // Chebychev polynomial terms.
   // Non-zero terms calculated via
   //   integrate (2/pi)*ChebyshevT[n,u]/sqrt(1-u^2)*((u+3)/2)^(5/12)
   //   from -1 to 1
   // Zeroth term is similar except it uses 1/pi rather than 2/pi.
   static const double Cn[N] = { 
       1.1758200232996901923,
       0.16665763094889061230,
      -0.0083154894939042125035,
       0.00075187976780420279038,
      // Wolfram alpha doesn't want to compute the remaining terms
      // to more precision (it times out).
      -0.0000832402,
       0.0000102292,
      -1.3401e-6,
       1.83334e-7};

   double Tn[N];

   double u = 2.0*x - 3.0;

   Tn[0] = 1.0;
   Tn[1] = u;
   for (int ii = 2; ii < N; ++ii) {
      Tn[ii] = 2*u*Tn[ii-1] - Tn[ii-2];
   }   

   double y = 0.0;
   for (int ii = N-1; ii >= 0; --ii) {
      y += Cn[ii]*Tn[ii];
   }   

   return y;
}


// Returns x^(5/12) to within 3e-8 (relative error).
double pow512 (
   double x)
{
   static const double pow2_512[12] = {
      1.0,
      pow(2.0, 5.0/12.0),
      pow(4.0, 5.0/12.0),
      pow(8.0, 5.0/12.0),
      pow(16.0, 5.0/12.0),
      pow(32.0, 5.0/12.0),
      pow(64.0, 5.0/12.0),
      pow(128.0, 5.0/12.0),
      pow(256.0, 5.0/12.0),
      pow(512.0, 5.0/12.0),
      pow(1024.0, 5.0/12.0),
      pow(2048.0, 5.0/12.0)
   };

   double s;
   int iexp;

   s = frexp (x, &iexp);
   s *= 2.0;
   iexp -= 1;

   div_t qr = div (iexp, 12);
   if (qr.rem < 0) {
      qr.quot -= 1;
      qr.rem += 12;
   }

   return ldexp (pow512norm(s)*pow2_512[qr.rem], 5*qr.quot);
}

附录:这是怎么回事?
根据要求,以下解释了上面的代码如何工作。

概览
上面的代码定义了两个函数,double pow512norm (double x)double pow512 (double x)。后者是套件的入口点;这是用户代码应该调用来计算 x^(5/12) 的函数。函数 pow512norm(x) 使用 Chebyshev 多项式来近似 x^(5/12),但仅适用于范围 [1,2] 中的 x(使用 pow512norm(x) 来处理范围外的值将生成垃圾结果)。

函数 pow512(x) 将输入的 x 分成一对 (double s, int n),使得 x = s * 2^n,并且 1≤s<2。将 n 进一步分成 (int q, unsigned int r),使得 n = 12*q + r,且 r 小于 12,让我把找到 x^(5/12) 的问题分成以下几个部分:

  1. x^(5/12)=(s^(5/12))*((2^n)^(5/12)),通过 (uv)^a=(u^a)(v^a)(其中 u、v 为正且 a 为实数)得到。
  2. s^(5/12) 通过 pow512norm(s) 计算。
  3. (2^n)^(5/12)=(2^(12*q+r))^(5/12),通过替换得到。
  4. 2^(12*q+r)=(2^(12*q))*(2^r),通过 u^(a+b)=(u^a)*(u^b)(其中 u 为正数且 a、b 为实数)得到。
  5. (2^(12*q+r))^(5/12)=(2^(5*q))*((2^r)^(5/12)),通过一些操作得到。
  6. (2^r)^(5/12) 由查找表 pow2_512 计算。
  7. 计算 pow512norm(s)*pow2_512[qr.rem],我们就快完成了。这里 qr.rem 是上面第 3 步中计算的 r 值。所需的全部就是将其乘以 2^(5*q) 以得到所需结果。
  8. 这正是数学库函数 ldexp 所做的。
函数逼近
这里的目标是找到一个易于计算且足够好的逼近函数f(x)=x^(5/12),使其在某种意义下接近f(x)。 修辞问题:'接近'是什么意思?两种竞争性的解释是最小化均方误差与最小化最大绝对误差。
我将用股市类比来描述它们之间的区别。 假设你想为最终退休存钱。 如果你还在二十多岁,最好的方法是投资股票或股票基金。 这是因为足够长的时间内,股票市场平均跑赢任何其他投资计划。 但是,我们都看到过投资股票时非常糟糕的情况。 如果你已经五十岁或六十岁了(或者四十岁想要提前退休),你需要更加保守地投资。 那些下降会对您的退休金组合造成严重影响。
回到函数逼近:作为逼近值的消费者,您通常会关注最坏情况误差而不是平均性能。 使用一些构建良好平均性能(例如最小二乘法)的逼近可能会导致您的程序花费大量时间使用逼近,而该逼近的性能远低于平均水平。 您需要一个极值逼近,即最小化某些区域内的最大绝对误差。 一个良好的数学库将采用极值逼近而不是最小二乘法,因为这让数学库的作者可以提供一些保证其库性能的保证。
数学库通常使用多项式或有理多项式来逼近某个定义在a≤x≤b区间内的函数f(x)。 假设函数f(x)在此区间上是解析的,并且您想通过某个N次多项式p(x)来逼近该函数。 对于给定的N次,存在某个神奇且独特的多项式p(x),使得p(x)-f(x)在[a,b]区间内具有N+2个极值点,而这些N+2个极值点的绝对值相等。 找到这个神奇的多项式p(x)是函数逼近的圣杯。
我没有为您找到那个圣杯。 我改用了Chebyshev逼近。 Chebyshev第一类多项式是一组正交(但不是标准正交)的多项式,在进行函数逼近时具有一些非常好的特性。 Chebyshev逼近通常非常接近那神奇的多项式p(x)。 (实际上,寻找该神奇多项式的Remez交换算法通常从Chebyshev逼近开始。)

pow512norm(x)
这个函数使用Chebyshev逼近来找到一个多项式p*(x),它近似于x^(5/12)。在这里,我使用p*(x)来区别于上面描述的神奇多项式p(x)。Chebyshev逼近p*(x)很容易找到;但是找到p(x)则非常困难。Chebyshev逼近p*(x)可以表示为sum_i Cn[i]*Tn(i,x),其中Cn[i]是Chebyshev系数,Tn(i,x)是在x处求值的Chebyshev多项式。

我使用Wolfram Alpha为我找到了Chebyshev系数Cn。例如,这个链接计算了Cn[1]。输入框后的第一个框中显示了所需答案,本例中为0.166658。这不是我想要的那么多位数。点击“更多数字”,你就能得到更多位数。Wolfram Alpha是免费的;它有计算量的限制,当计算高阶项时会达到这个限制(如果你购买或者有使用Mathematica的权限,你将能够计算高精度的高阶系数)。

Chebyshev多项式Tn(x)在数组Tn中计算。除了给出非常接近神奇多项式p(x)的东西之外,使用Chebyshev逼近的另一个原因是这些Chebyshev多项式的值很容易计算:从Tn[0]=1Tn[1]=x开始,然后迭代地计算Tn[i]=2*x*Tn[i-1] - Tn[i-2]。(在我的代码中,我使用'ii'作为索引变量而不是'i'。我从不使用'i'作为变量名。英语单词中有多少个字母'i'?有多少个连续的两个'i'?)

pow512(x)
pow512是用户代码应该调用的函数。我已经在上面描述了这个函数的基本原理。还有一些细节:数学库函数frexp(x)返回输入x的尾数s和指数iexp。(小问题:我希望s在1到2之间与pow512norm一起使用,但frexp返回一个在0.5到1之间的值。)数学库函数div一次性返回整数除法的商和余数。最后,我使用数学库函数ldexp将这三部分组合成最终答案。


哎呀,我希望我能理解这背后的数学原理。这个公式对x^2.4也适用吗? - Cory Nelson
1
@Cory: 关于数学问题,我会在明天添加一些细节。关于 x^2.4:从概念上讲,是的。切比雪夫多项式跨越了R[-1,1]上的一维实函数空间(希尔伯特空间)。也就是说,用于表示x^2.4所需的项数将比表示x^(5/12)所需的项数要多得多。但是,并不需要直接计算x^2.4。只需要能够提取第五个根号,x^(1/5),这将需要比x^(5/12)更少的切比雪夫项。有了x^(1/5),x^2.4 = (x*x^(1/5))^2。 - David Hammen
@Pascal:完全严谨地说,Chebyshev系数并不是实际计算中的“最佳”选择;极值多项式才是。尽管如此,正如David在他的回答中所指出的那样,对于行为良好的函数,这两者通常非常接近。 - Stephen Canon
我认为pow512没有正确处理x=0的情况。不过,特殊处理它应该很容易。 - Michael Anderson
通过FFT可以快速计算Chebyshev系数,而不是通过数值积分,因此可以自动化处理除5/12以外的其他值。[@Pascal:好文章。很高兴看到人们仍然在ENS Lyon做出了伟大的工作,那里是我毕业的地方:)] - Alexandre C.
显示剩余4条评论

24
在 IEEE 754 黑客风格中,这里有另一种更快、不那么"神奇"的解决方案。它在大约十几个时钟周期内实现了 0.08% 的误差率(对于 p=2.4,在 Intel Merom CPU 上)。
浮点数最初是作为对数的近似值而发明的,因此您可以使用整数值作为 log2 的近似值。通过将整数转换指令应用于浮点值,可以在某种程度上实现这一点,以获得另一个浮点值。
要完成 pow 计算,您可以乘以一个常数因子,并使用转换为整数指令将对数转换回去。在 SSE 上,相关指令是 cvtdq2ps 和 cvtps2dq。
然而,情况并不是那么简单。IEEE 754 中的指数字段是带符号的,偏置值为 127 表示指数为零。在乘以对数之前必须删除此偏差,并在指数化之前重新添加。此外,在零上进行偏差调整的减法将无法工作。幸运的是,可以通过事先乘以常数因子来实现两种调整。
x^p
= exp2( p * log2( x ) )
= exp2( p * ( log2( x ) + 127 - 127 ) - 127 + 127 )
= cvtps2dq( p * ( log2( x ) + 127 - 127 - 127 / p ) )
= cvtps2dq( p * ( log2( x ) + 127 - log2( exp2( 127 - 127 / p ) ) )
= cvtps2dq( p * ( log2( x * exp2( 127 / p - 127 ) ) + 127 ) )
= cvtps2dq( p * ( cvtdq2ps( x * exp2( 127 / p - 127 ) ) ) )
exp2( 127 / p - 127 )是常量因子。这个函数相当专业:它不能处理小的分数指数,因为常量因子随着指数的倒数呈指数增长,会溢出。它也不能处理负指数。大的指数会导致高误差,因为乘法将尾数位和指数位混合在一起。
但是,它只有4条快速指令。预先乘以一个常数系数,转换为“整数”(对数),幂乘,转换为“整数”(从对数)。在这个SSE实现中,转换非常快。我们还可以把额外的常数系数挤入第一次乘法中。
template< unsigned expnum, unsigned expden, unsigned coeffnum, unsigned coeffden >
__m128 fastpow( __m128 arg ) {
        __m128 ret = arg;
//      std::printf( "arg = %,vg\n", ret );
        // Apply a constant pre-correction factor.
        ret = _mm_mul_ps( ret, _mm_set1_ps( exp2( 127. * expden / expnum - 127. )
                * pow( 1. * coeffnum / coeffden, 1. * expden / expnum ) ) );
//      std::printf( "scaled = %,vg\n", ret );
        // Reinterpret arg as integer to obtain logarithm.
        asm ( "cvtdq2ps %1, %0" : "=x" (ret) : "x" (ret) );
//      std::printf( "log = %,vg\n", ret );
        // Multiply logarithm by power.
        ret = _mm_mul_ps( ret, _mm_set1_ps( 1. * expnum / expden ) );
//      std::printf( "powered = %,vg\n", ret );
        // Convert back to "integer" to exponentiate.
        asm ( "cvtps2dq %1, %0" : "=x" (ret) : "x" (ret) );
//      std::printf( "result = %,vg\n", ret );
        return ret;
}

使用指数为2.4进行几次试验,结果表明这一方法始终高估约5%。(该例程始终保证高估。)您可以简单地乘以0.95,但是再增加几个指令将使我们获得大约4位小数的精度,这应该足够用于图形处理。
关键是要将高估与低估匹配,并取平均值。
  • 计算x^0.8:四个指令,误差约为+3%。
  • 计算x^-0.4:一个rsqrtps。(这已经足够准确,但牺牲了与零一起工作的能力。)
  • 计算x^0.4:一个mulps
  • 计算x^-0.2:一个rsqrtps
  • 计算x^2:一个mulps
  • 计算x^3:一个mulps
  • x^2.4 = x^2 * x^0.4:一个mulps。这是高估。
  • x^2.4 = x^3 * x^-0.4 * x^-0.2:两个mulps。这是低估。
  • 对上述进行平均:一个addps,一个mulps

指令总数为14个,包括两个延迟为5的转换和两个吞吐量为4的倒数平方根估计。
为了正确取平均值,我们希望按其预期误差对估计值进行加权。与0.4相比,低估将误差提高到0.6的幂次,因此我们预计它会出现1.5倍的误差。加权不会增加任何指令;可以在前因子中完成。将系数称为a:a^0.5 = 1.5 a^-0.75,因此a = 1.38316186。
最终误差约为0.015%,比初始的fastpow结果好两个数量级。运行时间约为一打循环周期,具有volatile源和目标变量...虽然它重叠迭代,但实际使用也会看到指令级并行性。考虑到SIMD,这是每3个周期的吞吐量为一个标量结果!
int main() {
        __m128 const x0 = _mm_set_ps( 0.01, 1, 5, 1234.567 );
        std::printf( "Input: %,vg\n", x0 );

        // Approx 5% accuracy from one call. Always an overestimate.
        __m128 x1 = fastpow< 24, 10, 1, 1 >( x0 );
        std::printf( "Direct x^2.4: %,vg\n", x1 );

        // Lower exponents provide lower initial error, but too low causes overflow.
        __m128 xf = fastpow< 8, 10, int( 1.38316186 * 1e9 ), int( 1e9 ) >( x0 );
        std::printf( "1.38 x^0.8: %,vg\n", xf );

        // Imprecise 4-cycle sqrt is still far better than fastpow, good enough.
        __m128 xfm4 = _mm_rsqrt_ps( xf );
        __m128 xf4 = _mm_mul_ps( xf, xfm4 );

        // Precisely calculate x^2 and x^3
        __m128 x2 = _mm_mul_ps( x0, x0 );
        __m128 x3 = _mm_mul_ps( x2, x0 );

        // Overestimate of x^2 * x^0.4
        x2 = _mm_mul_ps( x2, xf4 );

        // Get x^-0.2 from x^0.4. Combine with x^-0.4 into x^-0.6 and x^2.4.
        __m128 xfm2 = _mm_rsqrt_ps( xf4 );
        x3 = _mm_mul_ps( x3, xfm4 );
        x3 = _mm_mul_ps( x3, xfm2 );

        std::printf( "x^2 * x^0.4: %,vg\n", x2 );
        std::printf( "x^3 / x^0.6: %,vg\n", x3 );
        x2 = _mm_mul_ps( _mm_add_ps( x2, x3 ), _mm_set1_ps( 1/ 1.960131704207789 ) );
        // Final accuracy about 0.015%, 200x better than x^0.8 calculation.
        std::printf( "average = %,vg\n", x2 );
}

抱歉我不能早些发布这篇文章。将其扩展到x^1/2.4留作练习;v)。


统计数据更新

我实现了一个小型测试工具,并添加了两个对应上述情况的x^(5/12)案例。

#include <cstdio>
#include <xmmintrin.h>
#include <cmath>
#include <cfloat>
#include <algorithm>
using namespace std;

template< unsigned expnum, unsigned expden, unsigned coeffnum, unsigned coeffden >
__m128 fastpow( __m128 arg ) {
    __m128 ret = arg;
//  std::printf( "arg = %,vg\n", ret );
    // Apply a constant pre-correction factor.
    ret = _mm_mul_ps( ret, _mm_set1_ps( exp2( 127. * expden / expnum - 127. )
        * pow( 1. * coeffnum / coeffden, 1. * expden / expnum ) ) );
//  std::printf( "scaled = %,vg\n", ret );
    // Reinterpret arg as integer to obtain logarithm.
    asm ( "cvtdq2ps %1, %0" : "=x" (ret) : "x" (ret) );
//  std::printf( "log = %,vg\n", ret );
    // Multiply logarithm by power.
    ret = _mm_mul_ps( ret, _mm_set1_ps( 1. * expnum / expden ) );
//  std::printf( "powered = %,vg\n", ret );
    // Convert back to "integer" to exponentiate.
    asm ( "cvtps2dq %1, %0" : "=x" (ret) : "x" (ret) );
//  std::printf( "result = %,vg\n", ret );
    return ret;
}

__m128 pow125_4( __m128 arg ) {
    // Lower exponents provide lower initial error, but too low causes overflow.
    __m128 xf = fastpow< 4, 5, int( 1.38316186 * 1e9 ), int( 1e9 ) >( arg );

    // Imprecise 4-cycle sqrt is still far better than fastpow, good enough.
    __m128 xfm4 = _mm_rsqrt_ps( xf );
    __m128 xf4 = _mm_mul_ps( xf, xfm4 );

    // Precisely calculate x^2 and x^3
    __m128 x2 = _mm_mul_ps( arg, arg );
    __m128 x3 = _mm_mul_ps( x2, arg );

    // Overestimate of x^2 * x^0.4
    x2 = _mm_mul_ps( x2, xf4 );

    // Get x^-0.2 from x^0.4, and square it for x^-0.4. Combine into x^-0.6.
    __m128 xfm2 = _mm_rsqrt_ps( xf4 );
    x3 = _mm_mul_ps( x3, xfm4 );
    x3 = _mm_mul_ps( x3, xfm2 );

    return _mm_mul_ps( _mm_add_ps( x2, x3 ), _mm_set1_ps( 1/ 1.960131704207789 * 0.9999 ) );
}

__m128 pow512_2( __m128 arg ) {
    // 5/12 is too small, so compute the sqrt of 10/12 instead.
    __m128 x = fastpow< 5, 6, int( 0.992245 * 1e9 ), int( 1e9 ) >( arg );
    return _mm_mul_ps( _mm_rsqrt_ps( x ), x );
}

__m128 pow512_4( __m128 arg ) {
    // 5/12 is too small, so compute the 4th root of 20/12 instead.
    // 20/12 = 5/3 = 1 + 2/3 = 2 - 1/3. 2/3 is a suitable argument for fastpow.
    // weighting coefficient: a^-1/2 = 2 a; a = 2^-2/3
    __m128 xf = fastpow< 2, 3, int( 0.629960524947437 * 1e9 ), int( 1e9 ) >( arg );
    __m128 xover = _mm_mul_ps( arg, xf );

    __m128 xfm1 = _mm_rsqrt_ps( xf );
    __m128 x2 = _mm_mul_ps( arg, arg );
    __m128 xunder = _mm_mul_ps( x2, xfm1 );

    // sqrt2 * over + 2 * sqrt2 * under
    __m128 xavg = _mm_mul_ps( _mm_set1_ps( 1/( 3 * 0.629960524947437 ) * 0.999852 ),
                                _mm_add_ps( xover, xunder ) );

    xavg = _mm_mul_ps( xavg, _mm_rsqrt_ps( xavg ) );
    xavg = _mm_mul_ps( xavg, _mm_rsqrt_ps( xavg ) );
    return xavg;
}

__m128 mm_succ_ps( __m128 arg ) {
    return (__m128) _mm_add_epi32( (__m128i) arg, _mm_set1_epi32( 4 ) );
}

void test_pow( double p, __m128 (*f)( __m128 ) ) {
    __m128 arg;

    for ( arg = _mm_set1_ps( FLT_MIN / FLT_EPSILON );
            ! isfinite( _mm_cvtss_f32( f( arg ) ) );
            arg = mm_succ_ps( arg ) ) ;

    for ( ; _mm_cvtss_f32( f( arg ) ) == 0;
            arg = mm_succ_ps( arg ) ) ;

    std::printf( "Domain from %g\n", _mm_cvtss_f32( arg ) );

    int n;
    int const bucket_size = 1 << 25;
    do {
        float max_error = 0;
        double total_error = 0, cum_error = 0;
        for ( n = 0; n != bucket_size; ++ n ) {
            float result = _mm_cvtss_f32( f( arg ) );

            if ( ! isfinite( result ) ) break;

            float actual = ::powf( _mm_cvtss_f32( arg ), p );

            float error = ( result - actual ) / actual;
            cum_error += error;
            error = std::abs( error );
            max_error = std::max( max_error, error );
            total_error += error;

            arg = mm_succ_ps( arg );
        }

        std::printf( "error max = %8g\t" "avg = %8g\t" "|avg| = %8g\t" "to %8g\n",
                    max_error, cum_error / n, total_error / n, _mm_cvtss_f32( arg ) );
    } while ( n == bucket_size );
}

int main() {
    std::printf( "4 insn x^12/5:\n" );
    test_pow( 12./5, & fastpow< 12, 5, 1059, 1000 > );
    std::printf( "14 insn x^12/5:\n" );
    test_pow( 12./5, & pow125_4 );
    std::printf( "6 insn x^5/12:\n" );
    test_pow( 5./12, & pow512_2 );
    std::printf( "14 insn x^5/12:\n" );
    test_pow( 5./12, & pow512_4 );
}

输出:

4 insn x^12/5:
Domain from 1.36909e-23
error max =      inf    avg =      inf  |avg| =      inf    to 8.97249e-19
error max =  2267.14    avg =  139.175  |avg| =  139.193    to 5.88021e-14
error max = 0.123606    avg = -0.000102963  |avg| = 0.0371122   to 3.85365e-09
error max = 0.123607    avg = -0.000108978  |avg| = 0.0368548   to 0.000252553
error max =  0.12361    avg = 7.28909e-05   |avg| = 0.037507    to  16.5513
error max = 0.123612    avg = -0.000258619  |avg| = 0.0365618   to 1.08471e+06
error max = 0.123611    avg = 8.70966e-05   |avg| = 0.0374369   to 7.10874e+10
error max =  0.12361    avg = -0.000103047  |avg| = 0.0371122   to 4.65878e+15
error max = 0.123609    avg =      nan  |avg| =      nan    to 1.16469e+16
14 insn x^12/5:
Domain from 1.42795e-19
error max =      inf    avg =      nan  |avg| =      nan    to 9.35823e-15
error max = 0.000936462 avg = 2.0202e-05    |avg| = 0.000133764 to 6.13301e-10
error max = 0.000792752 avg = 1.45717e-05   |avg| = 0.000129936 to 4.01933e-05
error max = 0.000791785 avg = 7.0132e-06    |avg| = 0.000129923 to  2.63411
error max = 0.000787589 avg = 1.20745e-05   |avg| = 0.000129347 to   172629
error max = 0.000786553 avg = 1.62351e-05   |avg| = 0.000132397 to 1.13134e+10
error max = 0.000785586 avg = 8.25205e-06   |avg| = 0.00013037  to 6.98147e+12
6 insn x^5/12:
Domain from 9.86076e-32
error max = 0.0284339   avg = 0.000441158   |avg| = 0.00967327  to 6.46235e-27
error max = 0.0284342   avg = -5.79938e-06  |avg| = 0.00897913  to 4.23516e-22
error max = 0.0284341   avg = -0.000140706  |avg| = 0.00897084  to 2.77556e-17
error max = 0.028434    avg = 0.000440504   |avg| = 0.00967325  to 1.81899e-12
error max = 0.0284339   avg = -6.11153e-06  |avg| = 0.00897915  to 1.19209e-07
error max = 0.0284298   avg = -0.000140597  |avg| = 0.00897084  to 0.0078125
error max = 0.0284371   avg = 0.000439748   |avg| = 0.00967319  to      512
error max = 0.028437    avg = -7.74294e-06  |avg| = 0.00897924  to 3.35544e+07
error max = 0.0284369   avg = -0.000142036  |avg| = 0.00897089  to 2.19902e+12
error max = 0.0284368   avg = 0.000439183   |avg| = 0.0096732   to 1.44115e+17
error max = 0.0284367   avg = -7.41244e-06  |avg| = 0.00897923  to 9.44473e+21
error max = 0.0284366   avg = -0.000141706  |avg| = 0.00897088  to 6.1897e+26
error max = 0.485129    avg = -0.0401671    |avg| = 0.048422    to 4.05648e+31
error max = 0.994932    avg = -0.891494 |avg| = 0.891494    to 2.65846e+36
error max = 0.999329    avg =      nan  |avg| =      nan    to       -0
14 insn x^5/12:
Domain from 2.64698e-23
error max =  0.13556    avg = 0.00125936    |avg| = 0.00354677  to 1.73472e-18
error max = 0.000564988 avg = 2.51458e-06   |avg| = 0.000113709 to 1.13687e-13
error max = 0.000565065 avg = -1.49258e-06  |avg| = 0.000112553 to 7.45058e-09
error max = 0.000565143 avg = 1.5293e-06    |avg| = 0.000112864 to 0.000488281
error max = 0.000565298 avg = 2.76457e-06   |avg| = 0.000113713 to       32
error max = 0.000565453 avg = -1.61276e-06  |avg| = 0.000112561 to 2.09715e+06
error max = 0.000565531 avg = 1.42628e-06   |avg| = 0.000112866 to 1.37439e+11
error max = 0.000565686 avg = 2.71505e-06   |avg| = 0.000113715 to 9.0072e+15
error max = 0.000565763 avg = -1.56586e-06  |avg| = 0.000112415 to 1.84467e+19

我怀疑更准确的5/12的精度受到了rsqrt操作的限制。


@Cory:让我知道它的结果如何。你在这里看到的是我所做的全部测试。 - Potatoswatter
2
@David:各有所好。我只是想比其他“FP-hack”答案做得更好,因为那真的很糟糕。但是我不明白,随手使用手头的工具比深入基础的精通更糟糕...这只是聪明与美丽的区别。 - Potatoswatter
@David 好的。非常感谢您提供的所有答案,测试完成后我会更新我的问题并进行比较。 - Cory Nelson
@Cory:更新了另一个指数、测试工具和微调系数。 - Potatoswatter
3
cvtdq2ps/cvtps2dq也可以用内建函数来实现,例如 _mm_cvtepi32_ps(_mm_castps_si128(x))_mm_castsi128_ps(_mm_cvtps_epi32(x)) - Cory Nelson
显示剩余6条评论

20

Ian Stephenson写了这段代码,他声称它的性能优于pow()。他描述了这个想法如下:

pow基本上是使用log实现的:pow(a,b)=x(logx(a)*b)。所以我们需要一个快速的log和快速的指数 - x的值并不重要,所以我们用2。技巧在于浮点数已经是以log形式表示的格式:

a=M*2E

对两边取对数得到:

log2(a)=log2(M)+E
更简单地说:
log2(a)~=E

换言之,如果我们采用一个数的浮点表示,并提取指数作为其对数的良好起点。事实证明,通过调整位模式,尾数可以给出误差的良好近似值,这样做非常有效。

这足以满足简单的照明计算需求,但如果您需要更好的精度,则可以提取尾数,并使用它来计算二次校正因子,该方法相当精确。


不错。使用浮点数的内部格式是个好主意!标准库是否已经利用了这一点呢? - Kerrek SB
8
好主意。不过使用标准的 frexp 函数 提取指数和尾数或许能使代码更加美观。任何一个好的编译器都应该能够实现 frexp 作为位提取(加上偏移量)的方式,这样可以在不降低性能的情况下获得可移植性。 - Nemo
太酷了!我知道以前见过这个,但现在已经忘记了。这是一个非常好的技巧,可以快速近似计算log2。 - Mikola
这种优化没有考虑常数参数,因此并没有真正回答问题。这个特定的实现不是很好...除了没有使用frexp之外,通常的做法是计算前导零并移位数字,将指数移动到尾数中,然后使用零计数创建一个新的指数。在尾数内移位的位近似于使用2^x ≈ x, 1 < x < 2的结果。 - Potatoswatter
如果有一个好的链接,我可能能够优化通用pow。不过就像@potatoswatter所说,它并没有完全回答问题。+1 - Cory Nelson
啊哈!我已经走了一半的路——对尾数进行平方以改善浮点数对于对数和指数的分段逼近——但我放弃了,决定寻找更简单的方法。这里的其他答案让我意识到我的解决方案并不差,现在我不必再去计算那些神奇的数字了! - sh1

17

首先,现在大多数机器使用浮点数不会带来多少好处。实际上,双精度浮点数可能更快。你的幂次方运算,1.0/2.4,等于5/12或者1/3*(1+1/4)。即使这涉及到了一次cbrt和两次sqrt操作!它仍然比使用pow()快两倍。(优化:-O3,编译器:i686-apple-darwin10-g++-4.2.1)。

#include <math.h> // cmath does not provide cbrt; C99 does.
double xpow512 (double x) {
  double cbrtx = cbrt(x);
  return cbrtx*sqrt(sqrt(cbrtx));
}

5
在处理单个值时,float / double 的速度通常相同(除了 div / sqrt),但在使用SIMD时,float通常具有两倍的吞吐量。这是一个非常有趣的答案,我一定会尝试一下。 - Cory Nelson
对于另一种技术,请参见我的其他答案。它使用切比雪夫多项式,速度非常快,但比这个两行代码要冗长一些。 - David Hammen
2
在具有良好编写的数学库的平台上,使用float实际上比使用double快得多。 为了提供一些硬性数字,我进行了一个小实验:在OSX 10.6 / Core2 Duo上调用单精度powf几乎比您在此答案中提供的实现快两倍。 - Stephen Canon
1
@Stephen:那有点离题了。单个基本浮点运算与相应的双精度运算之间的速度差异并不显著。使powf比pow快得多的原因是,powf只追求约6-7位小数的精度,而pow则追求15-16位。如果您想进行公正的比较,您应该制作一个浮点版本的此函数,该版本调用cbrtf而不是cbrt,sqrtf而不是sqrt。 - David Hammen
4
当然可以,但这并不是一个无关的话题。提问者正在询问如何实现一个相当复杂的函数,而不仅仅是一个基本操作。即使对于基本操作,使用浮点数而不是双精度浮点数也可以将两倍的数据存入缓存中,并且在手动调整代码的情况下可能会获得两倍的SIMD并行性能。 - Stephen Canon
显示剩余3条评论

15

这可能无法回答你的问题。

2.4f1/2.4f让我很怀疑,因为这些正好是用于在sRGB和线性RGB颜色空间之间转换的幂。所以你实际上可能正在尝试优化那个,具体来说就是这个过程。我不确定,这就是为什么这可能无法回答你的问题。

如果是这种情况,请尝试使用查找表。类似于:

__attribute__((aligned(64))
static const unsigned short SRGB_TO_LINEAR[256] = { ... };
__attribute__((aligned(64))
static const unsigned short LINEAR_TO_SRGB[256] = { ... };

void apply_lut(const unsigned short lut[256], unsigned char *src, ...
如果您正在使用16位数据,请根据需要进行更改。无论如何,我建议将表格设置为16位,这样在使用8位数据时如果需要可以进行抖动。显然,如果您的数据一开始就是浮点数,这种方法就不太适用了——但是将sRGB数据存储为浮点数并不是很合理,因此最好先将其转换为16位/8位,然后再将其从线性转换为sRGB。
(sRGB不适合作为浮点数的原因是,HDR应该是线性的,而sRGB仅适用于存储在磁盘上或在屏幕上显示,但不适合操作。)

2
你抓住我啦 ;),我正在进行sRGB伽马压缩/解压。我的输入/输出需要是浮点数,否则我肯定会使用LUT。 - Cory Nelson
4
这些转换正好处于一个更大的浮点数流水线的中间。我觉得我可以将输入缩放为整数,以便作为索引进入具有所需粒度的浮点LUT中,但是大型LUT与SIMD不兼容。(另外,为了未来参考和因为很酷,我想先尝试一下不使用LUT的方法。;) - Cory Nelson
需要更多的缩放,因为色度是半大小,而且不幸的是,缩放必须在伽马压缩值上进行。这特别适用于Y'V12格式。 - Cory Nelson
1
@MSalters:对于某种“免费”的定义来说……在GPU上进行插值是否会增加内存访问? - Dietrich Epp
1
@rdb:64字节是常见的缓存行大小。 - Dietrich Epp
显示剩余10条评论

4

我将回答你实际上想要问的问题,即如何进行快速的sRGB <-> 线性RGB转换。为了精确高效地完成此任务,我们可以使用多项式逼近法。以下多项式逼近法是使用sollya生成的,并具有最坏情况下的相对误差为0.0144%。

inline double poly7(double x, double a, double b, double c, double d,
                              double e, double f, double g, double h) {
    double ab, cd, ef, gh, abcd, efgh, x2, x4;
    x2 = x*x; x4 = x2*x2;
    ab = a*x + b; cd = c*x + d;
    ef = e*x + f; gh = g*x + h;
    abcd = ab*x2 + cd; efgh = ef*x2 + gh;
    return abcd*x4 + efgh;
}

inline double srgb_to_linear(double x) {
    if (x <= 0.04045) return x / 12.92;

    // Polynomial approximation of ((x+0.055)/1.055)^2.4.
    return poly7(x, 0.15237971711927983387,
                   -0.57235993072870072762,
                    0.92097986411523535821,
                   -0.90208229831912012386,
                    0.88348956209696805075,
                    0.48110797889132134175,
                    0.03563925285274562038,
                    0.00084585397227064120);
}

inline double linear_to_srgb(double x) {
    if (x <= 0.0031308) return x * 12.92;

    // Piecewise polynomial approximation (divided by x^3)
    // of 1.055 * x^(1/2.4) - 0.055.
    if (x <= 0.0523) return poly7(x, -6681.49576364495442248881,
                                      1224.97114922729451791383,
                                      -100.23413743425112443219,
                                         6.60361150127077944916,
                                         0.06114808961060447245,
                                        -0.00022244138470139442,
                                         0.00000041231840827815,
                                        -0.00000000035133685895) / (x*x*x);

    return poly7(x, -0.18730034115395793881,
                     0.64677431008037400417,
                    -0.99032868647877825286,
                     1.20939072663263713636,
                     0.33433459165487383613,
                    -0.01345095746411287783,
                     0.00044351684288719036,
                    -0.00000664263587520855) / (x*x*x);
}

以下是用于生成多项式的Sollya输入:

suppressmessage(174);
f = ((x+0.055)/1.055)^2.4;
p0 = fpminimax(f, 7, [|D...|], [0.04045;1], relative);
p = fpminimax(f/(p0(1)+1e-18), 7, [|D...|], [0.04045;1], relative);
print("relative:", dirtyinfnorm((f-p)/f, [s;1]));
print("absolute:", dirtyinfnorm((f-p), [s;1]));
print(canonical(p));

s = 0.0523;
z = 3;
f = 1.055 * x^(1/2.4) - 0.055;

p = fpminimax(1.055 * (x^(z+1/2.4) - 0.055*x^z/1.055), 7, [|D...|], [0.0031308;s], relative)/x^z;
print("relative:", dirtyinfnorm((f-p)/f, [0.0031308;s]));
print("absolute:", dirtyinfnorm((f-p), [0.0031308;s]));
print(canonical(p));

p = fpminimax(1.055 * (x^(z+1/2.4) - 0.055*x^z/1.055), 7, [|D...|], [s;1], relative)/x^z;
print("relative:", dirtyinfnorm((f-p)/f, [s;1]));
print("absolute:", dirtyinfnorm((f-p), [s;1]));
print(canonical(p));

3

二项式级数 能处理常数指数,但只有在将输入归一化为[1,2)范围内时才能使用它。(请注意,它计算的是(1+x)^a)。你需要进行一些分析来确定你需要多少项才能达到你想要的精度。


现在这个很有趣。准确性在很大程度上取决于_x_的大小。虽然它对于较大的数字(>0.1)表现出色,但对于我有时需要处理的小数字(0.04和0.003),我需要进行许多迭代,而我在问题中链接的pow()近似方法则更快。 - Cory Nelson
你可能还想看看使用切比雪夫多项式进行函数逼近:http://en.wikipedia.org/wiki/Chebyshev_polynomials - zvrba
@Cory Nelson:二项式级数只是围绕x=1的泰勒级数的特殊情况。您还可以计算围绕x=0的泰勒级数。这在纸上看起来不太美观,有很多α项,但由于您“知道”α,所以这没什么大不了的。 - MSalters
我在我的第二个答案中确切地做到了这一点(请参见下文)。这第二个结果非常快速,始终非常准确。@Cory:切比雪夫多项式是一种非常强大的函数逼近工具,通常比泰勒级数更好。 - David Hammen

1

对于2.4的指数,你可以为所有2.4的值创建一个查找表并使用lirp或者更高阶的函数来填充中间值,如果表不够准确(基本上是一个巨大的对数表)。

或者,将值平方*值的2/5次方,这可以从函数的前半部分取得初始平方值,然后进行5次方根。对于5次方根,你可以使用牛顿法或其他快速逼近方法,但老实说,一旦到达这个点,你可能最好自己使用适当的缩写级数函数执行exp和log函数。


1
以下是一种可以与任何快速计算方法配合使用的想法。它是否有助于加快速度取决于数据的到达方式。您可以利用这样一个事实,即如果您知道 xpow(x, n),则可以使用幂的变化率来计算 pow(x + delta, n) 的合理近似值,其中 delta 很小,只需进行一次乘法和加法(或多或少)。如果您连续提供给幂函数的值足够接近,这将分摊准确计算的全部成本到多个函数调用中。请注意,您不需要额外的 pow 计算来获得导数。您可以扩展此方法以使用二阶导数,以便使用二次方程,这将增加您可以使用的 delta 并仍然获得相同的精度。

1
传统上,通过将x重写为x=2^(log2(x)),使powf(x,p) = x^p得到解决,从而将问题转化为两个近似值exp2()log2()。这种方法的优点是可以处理更大的幂p,但缺点是对于恒定的幂p和特定输入范围0 ≤ x ≤ 1,这不是最优解决方案。
当幂p > 1时,答案是一个显然的最小最大多项式,在边界0 ≤ x ≤ 1上,这就是p = 12/5 = 2.4的情况,如下所示:
float pow12_5(float x){
    float mp;
    // Minimax horner polynomials for x^(5/12), Note: choose the accurarcy required then implement with fma() [Fused Multiply Accumulates]
    // mp = 0x4.a84a38p-12 + x * (-0xd.e5648p-8 + x * (0xa.d82fep-4 + x * 0x6.062668p-4)); // 1.13705697e-3
    mp = 0x1.117542p-12 + x * (-0x5.91e6ap-8 + x * (0x8.0f50ep-4 + x * (0xa.aa231p-4 + x * (-0x2.62787p-4))));  // 2.6079002e-4
    // mp = 0x5.a522ap-16 + x * (-0x2.d997fcp-8 + x * (0x6.8f6d1p-4 + x * (0xf.21285p-4 + x * (-0x7.b5b248p-4 + x * 0x2.32b668p-4))));  // 8.61377e-5
    // mp = 0x2.4f5538p-16 + x * (-0x1.abcdecp-8 + x * (0x5.97464p-4 + x * (0x1.399edap0 + x * (-0x1.0d363ap0 + x * (0xa.a54a3p-4 + x * (-0x2.e8a77cp-4))))));  // 3.524655e-5
    return(mp);
}

然而,当 p < 1 时,在边界 0 ≤ x ≤ 1 上的最小极大逼近并不能恰当地收敛到所需的精度。一个选择 [不是很好] 是重新定义问题为 y=x^p=x^(p+m)/x^m,其中 m=1,2,3 是正整数,使新的幂逼近 p > 1,但这会引入除法,本质上更慢。

然而,还有另一种选择,即将输入的 x 分解为其浮点指数和尾数形式:

x = mx* 2^(ex) where 1 ≤ mx < 2
y = x^(5/12) = mx^(5/12) * 2^((5/12)*ex), let ey = floor(5*ex/12), k = (5*ex) % 12
  = mx^(5/12) * 2^(k/12) * 2^(ey)

现在对于 1 ≤ mx < 2 范围内的 mx^(5/12) 的极小化近似,不需要使用除法就能更快地收敛,但需要一个包含 2^(k/12) 的12点LUT。以下是代码:

float powk_12LUT[] = {0x1.0p0, 0x1.0f38fap0, 0x1.1f59acp0,  0x1.306fep0, 0x1.428a3p0, 0x1.55b81p0, 0x1.6a09e6p0, 0x1.7f910ep0, 0x1.965feap0, 0x1.ae89fap0, 0x1.c823ep0, 0x1.e3437ep0};
float pow5_12(float x){
    union{float f; uint32_t u;} v, e2;
    float poff, m, e, ei;
    int xe;

    v.f = x;
    xe = ((v.u >> 23) - 127);

    if(xe < -127) return(0.0f);

    // Calculate remainder k in 2^(k/12) to find LUT
    e = xe * (5.0f/12.0f);
    ei = floorf(e);
    poff = powk_12LUT[(int)(12.0f * (e - ei))];

    e2.u = ((int)ei + 127) << 23;   // Calculate the exponent
    v.u = (v.u & ~(0xFFuL << 23)) | (0x7FuL << 23); // Normalize exponent to zero

    // Approximate mx^(5/12) on [1,2), with appropriate degree minimax
    // m = 0x8.87592p-4 + v.f * (0x8.8f056p-4 + v.f * (-0x1.134044p-4));    // 7.6125e-4
    // m = 0x7.582138p-4 + v.f * (0xb.1666bp-4 + v.f * (-0x2.d21954p-4 + v.f * 0x6.3ea0cp-8));  // 8.4522726e-5
    m = 0x6.9465cp-4 + v.f * (0xd.43015p-4 + v.f * (-0x5.17b2a8p-4 + v.f * (0x1.6cb1f8p-4 + v.f * (-0x2.c5b76p-8))));   // 1.04091259e-5
    // m = 0x6.08242p-4 + v.f * (0xf.352bdp-4 + v.f * (-0x7.d0c1bp-4 + v.f * (0x3.4d153p-4 + v.f * (-0xc.f7a42p-8 + v.f * 0x1.5d840cp-8))));    // 1.367401e-6

    return(m * poff * e2.f);
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接