在 IEEE 754 黑客风格中,这里有另一种更快、不那么"神奇"的解决方案。它在大约十几个时钟周期内实现了 0.08% 的误差率(对于 p=2.4,在 Intel Merom CPU 上)。
浮点数最初是作为对数的近似值而发明的,因此您可以使用整数值作为 log2 的近似值。通过将整数转换指令应用于浮点值,可以在某种程度上实现这一点,以获得另一个浮点值。
要完成 pow 计算,您可以乘以一个常数因子,并使用转换为整数指令将对数转换回去。在 SSE 上,相关指令是 cvtdq2ps 和 cvtps2dq。
然而,情况并不是那么简单。IEEE 754 中的指数字段是带符号的,偏置值为 127 表示指数为零。在乘以对数之前必须删除此偏差,并在指数化之前重新添加。此外,在零上进行偏差调整的减法将无法工作。幸运的是,可以通过事先乘以常数因子来实现两种调整。
x^p
= exp2( p * log2( x ) )
= exp2( p * ( log2( x ) + 127 - 127 ) - 127 + 127 )
= cvtps2dq( p * ( log2( x ) + 127 - 127 - 127 / p ) )
= cvtps2dq( p * ( log2( x ) + 127 - log2( exp2( 127 - 127 / p ) ) )
= cvtps2dq( p * ( log2( x * exp2( 127 / p - 127 ) ) + 127 ) )
= cvtps2dq( p * ( cvtdq2ps( x * exp2( 127 / p - 127 ) ) ) )
exp2( 127 / p - 127 )
是常量因子。这个函数相当专业:它不能处理小的分数指数,因为常量因子随着指数的倒数呈指数增长,会溢出。它也不能处理负指数。大的指数会导致高误差,因为乘法将尾数位和指数位混合在一起。
但是,它只有4条快速指令。预先乘以一个常数系数,转换为“整数”(对数),幂乘,转换为“整数”(从对数)。在这个SSE实现中,转换非常快。我们还可以把额外的常数系数挤入第一次乘法中。
template< unsigned expnum, unsigned expden, unsigned coeffnum, unsigned coeffden >
__m128 fastpow( __m128 arg ) {
__m128 ret = arg;
ret = _mm_mul_ps( ret, _mm_set1_ps( exp2( 127. * expden / expnum - 127. )
* pow( 1. * coeffnum / coeffden, 1. * expden / expnum ) ) );
asm ( "cvtdq2ps %1, %0" : "=x" (ret) : "x" (ret) );
ret = _mm_mul_ps( ret, _mm_set1_ps( 1. * expnum / expden ) );
asm ( "cvtps2dq %1, %0" : "=x" (ret) : "x" (ret) );
return ret;
}
使用指数为2.4进行几次试验,结果表明这一方法始终高估约5%。(该例程始终保证高估。)您可以简单地乘以0.95,但是再增加几个指令将使我们获得大约4位小数的精度,这应该足够用于图形处理。
关键是要将高估与低估匹配,并取平均值。
- 计算x^0.8:四个指令,误差约为+3%。
- 计算x^-0.4:一个
rsqrtps
。(这已经足够准确,但牺牲了与零一起工作的能力。)
- 计算x^0.4:一个
mulps
。
- 计算x^-0.2:一个
rsqrtps
。
- 计算x^2:一个
mulps
。
- 计算x^3:一个
mulps
。
- x^2.4 = x^2 * x^0.4:一个
mulps
。这是高估。
- x^2.4 = x^3 * x^-0.4 * x^-0.2:两个
mulps
。这是低估。
- 对上述进行平均:一个
addps
,一个mulps
。
指令总数为14个,包括两个延迟为5的转换和两个吞吐量为4的倒数平方根估计。
为了正确取平均值,我们希望按其预期误差对估计值进行加权。与0.4相比,低估将误差提高到0.6的幂次,因此我们预计它会出现1.5倍的误差。加权不会增加任何指令;可以在前因子中完成。将系数称为a:a^0.5 = 1.5 a^-0.75,因此a = 1.38316186。
最终误差约为0.015%,比初始的
fastpow
结果好两个数量级。运行时间约为一打循环周期,具有
volatile
源和目标变量...虽然它重叠迭代,但实际使用也会看到指令级并行性。考虑到SIMD,这是每3个周期的吞吐量为一个标量结果!
int main() {
__m128 const x0 = _mm_set_ps( 0.01, 1, 5, 1234.567 );
std::printf( "Input: %,vg\n", x0 );
__m128 x1 = fastpow< 24, 10, 1, 1 >( x0 );
std::printf( "Direct x^2.4: %,vg\n", x1 );
__m128 xf = fastpow< 8, 10, int( 1.38316186 * 1e9 ), int( 1e9 ) >( x0 );
std::printf( "1.38 x^0.8: %,vg\n", xf );
__m128 xfm4 = _mm_rsqrt_ps( xf );
__m128 xf4 = _mm_mul_ps( xf, xfm4 );
__m128 x2 = _mm_mul_ps( x0, x0 );
__m128 x3 = _mm_mul_ps( x2, x0 );
x2 = _mm_mul_ps( x2, xf4 );
__m128 xfm2 = _mm_rsqrt_ps( xf4 );
x3 = _mm_mul_ps( x3, xfm4 );
x3 = _mm_mul_ps( x3, xfm2 );
std::printf( "x^2 * x^0.4: %,vg\n", x2 );
std::printf( "x^3 / x^0.6: %,vg\n", x3 );
x2 = _mm_mul_ps( _mm_add_ps( x2, x3 ), _mm_set1_ps( 1/ 1.960131704207789 ) );
std::printf( "average = %,vg\n", x2 );
}
抱歉我不能早些发布这篇文章。将其扩展到x^1/2.4留作练习;v)。
统计数据更新
我实现了一个小型测试工具,并添加了两个对应上述情况的x^(5/12)案例。
#include <cstdio>
#include <xmmintrin.h>
#include <cmath>
#include <cfloat>
#include <algorithm>
using namespace std;
template< unsigned expnum, unsigned expden, unsigned coeffnum, unsigned coeffden >
__m128 fastpow( __m128 arg ) {
__m128 ret = arg;
ret = _mm_mul_ps( ret, _mm_set1_ps( exp2( 127. * expden / expnum - 127. )
* pow( 1. * coeffnum / coeffden, 1. * expden / expnum ) ) );
asm ( "cvtdq2ps %1, %0" : "=x" (ret) : "x" (ret) );
ret = _mm_mul_ps( ret, _mm_set1_ps( 1. * expnum / expden ) );
asm ( "cvtps2dq %1, %0" : "=x" (ret) : "x" (ret) );
return ret;
}
__m128 pow125_4( __m128 arg ) {
__m128 xf = fastpow< 4, 5, int( 1.38316186 * 1e9 ), int( 1e9 ) >( arg );
__m128 xfm4 = _mm_rsqrt_ps( xf );
__m128 xf4 = _mm_mul_ps( xf, xfm4 );
__m128 x2 = _mm_mul_ps( arg, arg );
__m128 x3 = _mm_mul_ps( x2, arg );
x2 = _mm_mul_ps( x2, xf4 );
__m128 xfm2 = _mm_rsqrt_ps( xf4 );
x3 = _mm_mul_ps( x3, xfm4 );
x3 = _mm_mul_ps( x3, xfm2 );
return _mm_mul_ps( _mm_add_ps( x2, x3 ), _mm_set1_ps( 1/ 1.960131704207789 * 0.9999 ) );
}
__m128 pow512_2( __m128 arg ) {
__m128 x = fastpow< 5, 6, int( 0.992245 * 1e9 ), int( 1e9 ) >( arg );
return _mm_mul_ps( _mm_rsqrt_ps( x ), x );
}
__m128 pow512_4( __m128 arg ) {
__m128 xf = fastpow< 2, 3, int( 0.629960524947437 * 1e9 ), int( 1e9 ) >( arg );
__m128 xover = _mm_mul_ps( arg, xf );
__m128 xfm1 = _mm_rsqrt_ps( xf );
__m128 x2 = _mm_mul_ps( arg, arg );
__m128 xunder = _mm_mul_ps( x2, xfm1 );
__m128 xavg = _mm_mul_ps( _mm_set1_ps( 1/( 3 * 0.629960524947437 ) * 0.999852 ),
_mm_add_ps( xover, xunder ) );
xavg = _mm_mul_ps( xavg, _mm_rsqrt_ps( xavg ) );
xavg = _mm_mul_ps( xavg, _mm_rsqrt_ps( xavg ) );
return xavg;
}
__m128 mm_succ_ps( __m128 arg ) {
return (__m128) _mm_add_epi32( (__m128i) arg, _mm_set1_epi32( 4 ) );
}
void test_pow( double p, __m128 (*f)( __m128 ) ) {
__m128 arg;
for ( arg = _mm_set1_ps( FLT_MIN / FLT_EPSILON );
! isfinite( _mm_cvtss_f32( f( arg ) ) );
arg = mm_succ_ps( arg ) ) ;
for ( ; _mm_cvtss_f32( f( arg ) ) == 0;
arg = mm_succ_ps( arg ) ) ;
std::printf( "Domain from %g\n", _mm_cvtss_f32( arg ) );
int n;
int const bucket_size = 1 << 25;
do {
float max_error = 0;
double total_error = 0, cum_error = 0;
for ( n = 0; n != bucket_size; ++ n ) {
float result = _mm_cvtss_f32( f( arg ) );
if ( ! isfinite( result ) ) break;
float actual = ::powf( _mm_cvtss_f32( arg ), p );
float error = ( result - actual ) / actual;
cum_error += error;
error = std::abs( error );
max_error = std::max( max_error, error );
total_error += error;
arg = mm_succ_ps( arg );
}
std::printf( "error max = %8g\t" "avg = %8g\t" "|avg| = %8g\t" "to %8g\n",
max_error, cum_error / n, total_error / n, _mm_cvtss_f32( arg ) );
} while ( n == bucket_size );
}
int main() {
std::printf( "4 insn x^12/5:\n" );
test_pow( 12./5, & fastpow< 12, 5, 1059, 1000 > );
std::printf( "14 insn x^12/5:\n" );
test_pow( 12./5, & pow125_4 );
std::printf( "6 insn x^5/12:\n" );
test_pow( 5./12, & pow512_2 );
std::printf( "14 insn x^5/12:\n" );
test_pow( 5./12, & pow512_4 );
}
输出:
4 insn x^12/5:
Domain from 1.36909e-23
error max = inf avg = inf |avg| = inf to 8.97249e-19
error max = 2267.14 avg = 139.175 |avg| = 139.193 to 5.88021e-14
error max = 0.123606 avg = -0.000102963 |avg| = 0.0371122 to 3.85365e-09
error max = 0.123607 avg = -0.000108978 |avg| = 0.0368548 to 0.000252553
error max = 0.12361 avg = 7.28909e-05 |avg| = 0.037507 to 16.5513
error max = 0.123612 avg = -0.000258619 |avg| = 0.0365618 to 1.08471e+06
error max = 0.123611 avg = 8.70966e-05 |avg| = 0.0374369 to 7.10874e+10
error max = 0.12361 avg = -0.000103047 |avg| = 0.0371122 to 4.65878e+15
error max = 0.123609 avg = nan |avg| = nan to 1.16469e+16
14 insn x^12/5:
Domain from 1.42795e-19
error max = inf avg = nan |avg| = nan to 9.35823e-15
error max = 0.000936462 avg = 2.0202e-05 |avg| = 0.000133764 to 6.13301e-10
error max = 0.000792752 avg = 1.45717e-05 |avg| = 0.000129936 to 4.01933e-05
error max = 0.000791785 avg = 7.0132e-06 |avg| = 0.000129923 to 2.63411
error max = 0.000787589 avg = 1.20745e-05 |avg| = 0.000129347 to 172629
error max = 0.000786553 avg = 1.62351e-05 |avg| = 0.000132397 to 1.13134e+10
error max = 0.000785586 avg = 8.25205e-06 |avg| = 0.00013037 to 6.98147e+12
6 insn x^5/12:
Domain from 9.86076e-32
error max = 0.0284339 avg = 0.000441158 |avg| = 0.00967327 to 6.46235e-27
error max = 0.0284342 avg = -5.79938e-06 |avg| = 0.00897913 to 4.23516e-22
error max = 0.0284341 avg = -0.000140706 |avg| = 0.00897084 to 2.77556e-17
error max = 0.028434 avg = 0.000440504 |avg| = 0.00967325 to 1.81899e-12
error max = 0.0284339 avg = -6.11153e-06 |avg| = 0.00897915 to 1.19209e-07
error max = 0.0284298 avg = -0.000140597 |avg| = 0.00897084 to 0.0078125
error max = 0.0284371 avg = 0.000439748 |avg| = 0.00967319 to 512
error max = 0.028437 avg = -7.74294e-06 |avg| = 0.00897924 to 3.35544e+07
error max = 0.0284369 avg = -0.000142036 |avg| = 0.00897089 to 2.19902e+12
error max = 0.0284368 avg = 0.000439183 |avg| = 0.0096732 to 1.44115e+17
error max = 0.0284367 avg = -7.41244e-06 |avg| = 0.00897923 to 9.44473e+21
error max = 0.0284366 avg = -0.000141706 |avg| = 0.00897088 to 6.1897e+26
error max = 0.485129 avg = -0.0401671 |avg| = 0.048422 to 4.05648e+31
error max = 0.994932 avg = -0.891494 |avg| = 0.891494 to 2.65846e+36
error max = 0.999329 avg = nan |avg| = nan to -0
14 insn x^5/12:
Domain from 2.64698e-23
error max = 0.13556 avg = 0.00125936 |avg| = 0.00354677 to 1.73472e-18
error max = 0.000564988 avg = 2.51458e-06 |avg| = 0.000113709 to 1.13687e-13
error max = 0.000565065 avg = -1.49258e-06 |avg| = 0.000112553 to 7.45058e-09
error max = 0.000565143 avg = 1.5293e-06 |avg| = 0.000112864 to 0.000488281
error max = 0.000565298 avg = 2.76457e-06 |avg| = 0.000113713 to 32
error max = 0.000565453 avg = -1.61276e-06 |avg| = 0.000112561 to 2.09715e+06
error max = 0.000565531 avg = 1.42628e-06 |avg| = 0.000112866 to 1.37439e+11
error max = 0.000565686 avg = 2.71505e-06 |avg| = 0.000113715 to 9.0072e+15
error max = 0.000565763 avg = -1.56586e-06 |avg| = 0.000112415 to 1.84467e+19
我怀疑更准确的5/12的精度受到了rsqrt
操作的限制。
pow
是本地CPU指令吗?否则我只能想象pow(a,b)
计算exp(b * log(a))
,因此你的指数并不那么恒定。好了,看到答案会很棒! - Kerrek SBpow
通常不是CPU的本地指令,而是一个库函数,由exp
和log
构建,并具有大量特殊情况,例如x^2
和x^-1
等。 - Dietrich Epp