y = x^2
的速度。明确一点:DWORD x[n+1] = { LSW, ......, MSW };
- 其中n+1是使用的DWORD数目
- 所以数字x的值为
x = x[0]+x[1]<<32 + ... x[N]<<32*(n)
问题是:如何在不丢失精度的情况下尽快计算y = x^2
?
- 使用C++和可用的整数算术(32位带进位)。
我目前的方法是应用乘法y = x*x
并避免多次乘法。
例如:
x = x[0] + x[1]<<32 + ... x[n]<<32*(n)
为了简单起见,让我重新写一下:
x = x0+ x1 + x2 + ... + xn
索引代表数组内的地址,因此:
y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)
y0 = x0*x0
y1 = x1*x0 + x0*x1
y2 = x2*x0 + x1*x1 + x0*x2
y3 = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2)
y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1)
y(2n-1) = xn(n )*x(n )
经过仔细观察,很明显几乎所有的
xi*xj
都出现了两次(不包括第一个和最后一个),这意味着N*N
次乘法可以用(N+1)*(N/2)
次乘法来代替。附注:32bit*32bit = 64bit
,所以每次mul+add
操作的结果都以64+1 bit
处理。有没有更快的计算方法?我在搜索中找到的都是平方根算法,而不是平方...
快速平方
!!! 注意,我的代码中所有的数字都是以MSW(Most Significant Word)为首位,而不是上面测试中的LSW(Least Significant Word),这样做是为了简化方程,否则会造成索引混乱。
当前功能性fsqr实现
void arbnum::sqr(const arbnum &x)
{
// O((N+1)*N/2)
arbnum c;
DWORD h, l;
int N, nx, nc, i, i0, i1, k;
c._alloc(x.siz + x.siz + 1);
nx = x.siz - 1;
nc = c.siz - 1;
N = nx + nx;
for (i=0; i<=nc; i++)
c.dat[i]=0;
for (i=1; i<N; i++)
for (i0=0; (i0<=nx) && (i0<=i); i0++)
{
i1 = i - i0;
if (i0 >= i1)
break;
if (i1 > nx)
continue;
h = x.dat[nx-i0];
if (!h)
continue;
l = x.dat[nx-i1];
if (!l)
continue;
alu.mul(h, l, h, l);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k], l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k],h);
k--;
for (; (alu.cy) && (k>=0); k--)
alu.inc(c.dat[k]);
}
c.shl(1);
for (i = 0; i <= N; i += 2)
{
i0 = i>>1;
h = x.dat[nx-i0];
if (!h)
continue;
alu.mul(h, l, h, h);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k],l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k], h);
k--;
for (; (alu.cy) && (k >= 0); k--)
alu.inc(c.dat[k]);
}
c.bits = c.siz<<5;
c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
c.sig = sig;
*this = c;
}
使用Karatsuba乘法
(感谢Calpis)
我实现了Karatsuba乘法,但结果比简单的O(N^2)
乘法慢得多,可能是因为那可怕的递归,我找不到任何避免的方法。它的权衡必须在非常大的数字(超过几百位数)上...但即使在那种情况下,仍然有很多内存传输。有没有办法避免递归调用(非递归变体,几乎所有递归算法都可以这样做)。不过,我会尝试调整一些东西,看看会发生什么(避免规范化等等,也可能是代码中的一些愚蠢错误)。无论如何,在解决Karatsuba的情况下,对于x*x
,性能提升并不大。
优化的Karatsuba乘法
对于y = x^2循环1000次,0.9 < x < 1 ~ 32*98位
的性能测试:
x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication
x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]
x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]
经过对Karatsuba的优化,代码比以前快得多。不过,对于较小的数字,它的速度略低于我O(N^2)乘法的一半。对于较大的数字,它的速度比Booth乘法的复杂度给出的比例更快。乘法的阈值约为3298位,平方的阈值约为32389位,因此,如果输入位的总和超过这个阈值,就会使用Karatsuba乘法来加速乘法,对于平方也是类似的情况。
顺便说一下,优化包括:
- 通过避免使用过大的递归参数来减少堆垃圾 - 避免使用任何大数算术(+,-),而是使用带进位的32位ALU - 忽略0*y、x*0或0*0的情况 - 将输入的x、y数字大小重新格式化为2的幂,以避免重新分配 - 实现模乘法以最小化递归的z1 = (x0 + x1)*(y0 + y1)
修改后的Schönhage-Strassen乘法用于平方实现。
我已经测试了使用FFT和NTT变换来加速平方根计算。结果如下:
1. FFT
失去了精度,因此需要高精度的复数。这实际上会显著减慢速度,因此没有加速效果。结果不准确(可能会被错误舍入),所以FFT目前无法使用。
2. NTT
NTT是有限域DFT,因此不会丢失精度。它需要在无符号整数上进行模运算:modpow、modmul、modadd和modsub。
我使用DWORD(32位无符号整数)。由于溢出问题,NTT的输入/输出向量大小受到限制!对于32位模运算,N受限于(2^32)/(max(input[])^2),因此bigint必须分割成较小的块(我使用BYTES,因此bigint的最大处理大小为)
(2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)
使用的是1xNTT + 1xINTT
而不是2xNTT + 1xINTT
来进行乘法,但是NTT
的使用速度太慢,而且阈值数值太大,不适合在我的实现中实际使用(对于mul
和sqr
)。
可能会超过溢出限制,因此应该使用64位模运算,这可能会进一步减慢速度。所以对于我的目的来说,NTT
也无法使用。
一些测量结果:
a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul
我的实现:
void arbnum::sqr_NTT(const arbnum &x)
{
// O(N*log(N)*(log(log(N)))) - 1x NTT
// Schönhage-Strassen sqr
// To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
int i, j, k, n;
int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
i = x.siz;
for (n = 1; n < i; n<<=1)
;
if (n + n > 0x3000) {
_error(_arbnum_error_TooBigNumber);
zero();
return;
}
n <<= 3;
DWORD *xx, *yy, q, qq;
xx = new DWORD[n+n];
#ifdef _mmap_h
if (xx)
mmap_new(xx, (n+n) << 2);
#endif
if (xx==NULL) {
_error(_arbnum_error_NotEnoughMemory);
zero();
return;
}
yy = xx + n;
// Zero padding (and split DWORDs to BYTEs)
for (i--, k=0; i >= 0; i--)
{
q = x.dat[i];
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++;
}
for (;k<n;k++)
xx[k] = 0;
//NTT
fourier_NTT ntt;
ntt.NTT(yy,xx,n); // init NTT for n
// Convolution
for (i=0; i<n; i++)
yy[i] = modmul(yy[i], yy[i], ntt.p);
//INTT
ntt.INTT(xx, yy);
//suma
q=0;
for (i = 0, j = 0; i<n; i++) {
qq = xx[i];
q += qq&0xFF;
yy[n-i-1] = q&0xFF;
q>>=8;
qq>>=8;
q+=qq;
}
// Merge WORDs to DWORDs and copy them to result
_alloc(n>>2);
for (i = 0, j = 0; i<siz; i++)
{
q =(yy[j]<<24)&0xFF000000; j++;
q |=(yy[j]<<16)&0x00FF0000; j++;
q |=(yy[j]<< 8)&0x0000FF00; j++;
q |=(yy[j] )&0x000000FF; j++;
dat[i] = q;
}
#ifdef _mmap_h
if (xx)
mmap_del(xx);
#endif
delete xx;
bits = siz<<5;
sig = s;
exp = exp0 + (siz<<5) - 1;
// _normalize();
}
结论
对于较小的数字,使用我的快速sqr
方法是最佳选择,在阈值之后,使用Karatsuba乘法更好。但我仍然认为可能有一些我们忽视的琐事。有其他人有什么想法吗?
NTT优化
经过大规模的优化(主要是NTT):Stack Overflow问题模运算和NTT(有限域DFT)优化。
一些值已经发生了变化:
a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul
所以现在,在大约1500*32位阈值之后,NTT乘法终于比Karatsuba更快了。
一些测量和错误被发现。
a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[ 58.656 ms ] fast sqr
sqr2[ 13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[ 28.916 ms ] Karatsuba mul Error
mul3[ 19.470 ms ] NTT mul
我发现我的Karatsuba在大数的每个DWORD段的LSB上溢出/下溢。当我研究完后,我会更新代码...
此外,经过进一步的NTT优化,阈值发生了变化,所以对于NTT sqr,操作数的位数是310*32位=9920位,对于NTT mul,结果的位数是1396*32位=44672位(操作数的位数之和)。
Karatsuba代码已经修复,感谢@greybeard。
//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
// Recursion for Karatsuba
// z[2n] = x[n]*y[n];
// n=2^m
int i;
for (i=0; i<n; i++)
if (x[i]) {
i=-1;
break;
} // x==0 ?
if (i < 0)
for (i = 0; i<n; i++)
if (y[i]) {
i = -1;
break;
} // y==0 ?
if (i >= 0) {
for (i = 0; i < n + n; i++)
z[i]=0;
return;
} // 0.? = 0
if (n == 1) {
alu.mul(z[0], z[1], x[0], y[0]);
return;
}
if (n< 1)
return;
int n2 = n>>1;
_mul_karatsuba(z+n, x+n2, y+n2, n2); // z0 = x0.y0
_mul_karatsuba(z , x , y , n2); // z2 = x1.y1
DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
BYTE cx,cy;
if (q == NULL) {
_error(_arbnum_error_NotEnoughMemory);
return;
}
#define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
#define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
qq = q;
q0 = x + n2;
q1 = x;
i = n2 - 1;
_add;
cx = alu.cy; // =x0+x1
qq = q + n2;
q0 = y + n2;
q1 = y;
i = n2 - 1;
_add;
cy = alu.cy; // =y0+y1
_mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1)
if (cx) {
qq = q + n;
q0 = qq;
q1 = q + n2;
i = n2 - 1;
_add;
cx = alu.cy;
}// += cx*(y0 + y1) << n2
if (cy) {
qq = q + n;
q0 = qq;
q1 = q;
i = n2 -1;
_add;
cy = alu.cy;
}// +=cy*(x0+x1)<<n2
qq = q + n; q0 = qq; q1 = z + n; i = n - 1; _sub; // -=z0
qq = q + n; q0 = qq; q1 = z; i = n - 1; _sub; // -=z2
qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add; // z1=(x0+x1)(y0+y1)-z0-z2
DWORD ccc=0;
if (alu.cy)
ccc++; // Handle carry from last operation
if (cx || cy)
ccc++; // Handle carry from before last operation
if (ccc)
{
i = n2 - 1;
alu.add(z[i], z[i], ccc);
for (i--; i>=0; i--)
if (alu.cy)
alu.inc(z[i]);
else
break;
}
delete[] q;
#undef _add
#undef _sub
}
//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
// O(3*(N)^log2(3)) ~ O(3*(N^1.585))
// Karatsuba multiplication
//
int s = x.sig*y.sig;
arbnum a, b;
a = x;
b = y;
a.sig = +1;
b.sig = +1;
int i, n;
for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
;
a._realloc(n);
b._realloc(n);
_alloc(n + n);
for (i=0; i < siz; i++)
dat[i]=0;
_mul_karatsuba(dat, a.dat, b.dat, n);
bits = siz << 5;
sig = s;
exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
// _normalize();
}
//---------------------------------------------------------------------------
我的
arbnum
数字表示方式:// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
dat[siz]
是尾数。LSDW 表示最低有效 DWORD。exp
是dat[0]
的最高有效位的指数。尾数中存在第一个非零位!!!
// |-----|---------------------------|---------------|------| // | sig | MSB mantisa LSB | exponent | bits | // |-----|---------------------------|---------------|------| // | +1 | 0.(0 ... 0) | 2^0 | 0 | +零 // | -1 | 0.(0 ... 0) | 2^0 | 0 | -零 // |-----|---------------------------|---------------|------| // | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +数字 // | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -数字 // |-----|---------------------------|---------------|------| // | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +无穷大 // | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -无穷大 // |-----|---------------------------|---------------|------|