我最近想用 NTT 代替 DFFT 来实现快速乘法。但我读到的资料很混乱,到处使用不同的符号,没有简单明了的解法,而且我的有限域知识也有点生疏。今天我终于搞定了(经过两天的尝试,并与 DFT 系数做类比),因此在这里分享一下我对 NTT 的理解:
Computation

$$X_i = \sum_{j=0}^{n-1} W_n^{\,i\cdot j}\, x_j \pmod p, \qquad i = 0,\dots,n-1$$

where $X[\,]$ is the NTT of $x[\,]$ (both of size $n$) and $W_n$ is the NTT basis. All computations use integer arithmetic modulo $p$ — there are no complex numbers anywhere.
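Important values

- $W_n = r^{L} \bmod p$ is the basis for the NTT.
- $W_n^{-1} = r^{\,p-1-L} \bmod p$ (the inverse of $W_n$, since $r^{p-1} \equiv 1$) is the basis for the INTT.
- $R_n = n^{\,p-2} \bmod p$ is the multiplicative scaling constant for the INTT ($\approx 1/n$, by Fermat's little theorem).
- $p$ is a prime such that $p \bmod n = 1$ and $p > max'$.
- $max$ is the maximum value of $x[i]$ for the NTT, or of $X[i]$ for the INTT.
- $r \in [1, p)$.
- $L \in [1, p)$, and $L$ also divides $p-1$.
- $r$ and $L$ must be chosen together so that $r^{L\cdot i} \bmod p = 1$ for $i = 0$ and for $i = n$,
- and $r^{L\cdot i} \bmod p \ne 1$ for $0 < i < n$; in other words, $W_n$ must have multiplicative order exactly $n$.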
- $max'$ is the maximum value of any sub-result. It depends on $n$ and on the type of computation. For a single (I)NTT it is $max' = n \cdot max$, but for the convolution of two vectors of size $n$ it is $max' = n \cdot max^2$, and so on. See *Implementing FFT over finite fields* for more information.
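The functional combination of $r$, $L$, $p$ is different for different $n$. This is important: before each NTT layer you have to recompute the parameters or select them from a table, because $n$ is halved at every recursion level. (The fast version below gets the next layer's basis simply by squaring: if $W_n$ has order $n$, then $W_n^2$ has order $n/2$ modulo the same $p$.)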
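以下是我找到的、用于查找参数 $r$、$L$、$p$ 的 C++ 代码。

代码需要模运算函数,但这些函数未包含在内。可以用 `(a+b)%c`、`(a-b)%c`、`(a*b)%c` 等代替,但这种情况下请特别注意 `modpow` 和 `modmul` 的溢出问题。另外,当 `a<b` 时,`(a-b)%c` 在 C/C++ 中为负数,应改用 `(a-b+c)%c`(假设 `0≤a,b<c`)。

代码尚未优化,还有许多方法可以加快速度。此外,素数表相当有限:表中最大只有 1151,而示例中 n=16 时 max=81·16=1296,表中已没有可用的素数。因此请使用埃拉托斯特尼筛法(SoE)或其他算法生成上限为 $max'$ 的素数,以确保程序安全运行。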
/* Ascending list of every prime from 2 up to 1151, terminated by a 0
   sentinel: a reader walking the table stops as soon as it reads the 0.
   NOTE(review): DWORD is presumably the Windows 32-bit unsigned type
   (<windows.h>) -- supply an equivalent typedef on other platforms.
   NOTE(review): the table is short; e.g. for n=16 with max=81*n=1296 no
   entry exceeds max, so it cannot supply a modulus for that case. */
DWORD _arithmetics_primes[]=
{
2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,
179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401,409,
419,421,431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,617,619,631,641,643,647,653,659,
661,673,677,683,691,701,709,719,727,733,739,743,751,757,761,769,773,787,797,809,811,821,823,827,829,839,853,857,859,863,877,881,883,887,907,911,919,929,937,941,
947,953,967,971,977,983,991,997,1009,1013,1019,1021,1031,1033,1039,1049,1051,1061,1063,1069,1087,1091,1093,1097,1103,1109,1117,1123,1129,1151,
0};
int i,j,k,n=16;
long w,W,iW,p,r,L,l,e;
long max=81*n;
for (e=1,j=0;e;j++)
{
p=_arithmetics_primes[j];
if (!p) break;
if ((p>max)&&(p%n==1))
for (r=2;r<p;r++)
{
for (l=1;l<p;l++)
{
L=(p-1);
if (L%l!=0) continue;
L/=l;
W=modpow(r,L,p);
e=0;
for (w=1,i=0;i<=n;i++,w=modmul(w,W,p))
{
if ((i==0) &&(w!=1)) { e=1; break; }
if ((i==n) &&(w!=1)) { e=1; break; }
if ((i>0)&&(i<n)&&(w==1)) { e=1; break; }
}
if (!e) break;
}
if (!e) break;
}
}
if (e) { error; }
W=modpow(r, L,p);
iW=modpow(r,p-1-L,p);
以下是我的慢速(O(n²))NTT 和 INTT 实现(当时我还没有研究快速 NTT/INTT),两者都已在 Schönhage-Strassen 乘法中测试通过。
/* Naive O(n^2) number-theoretic transform:
     dst[j] = sum_{i=0}^{n-1} w^(i*j) * src[i]   (mod m)
   dst - output array of n values; must not overlap src, because src is
         still read after dst[0] has been written
   src - input array of n values, only read
   n   - transform size
   m   - prime modulus p
   w   - basis of order n mod m (Wn for the forward transform)
   Relies on modadd/modmul helpers defined elsewhere. */
void NTT(long *dst,long *src,long n,long m,long w)
{
    long i, j, wj, wi, a;
    for (wj = 1, j = 0; j < n; j++)
    {
        /* wj = w^j, so wi runs through (w^j)^i = w^(i*j) */
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i], m), m);
            wi = modmul(wi, wj, m);
        }
        dst[j] = a;
        wj = modmul(wj, w, m);
    }
}
/* Naive O(n^2) inverse number-theoretic transform:
     dst[j] = n^-1 * sum_{i=0}^{n-1} w^(i*j) * src[i]   (mod m)
   dst - output array of n values; must not overlap src, because src is
         still read after dst[0] has been written
   src - input array of n values, only read
   n   - transform size
   m   - prime modulus p
   w   - inverse basis iW = r^(p-1-L) mod p (the inverse of the forward Wn)
   n^-1 is computed as n^(m-2) mod m (Fermat's little theorem, m prime).
   Relies on modadd/modmul/modpow helpers defined elsewhere. */
void INTT(long *dst,long *src,long n,long m,long w)
{
    long i, j, wi, wj, a, rN;
    rN = modpow(n, m - 2, m); /* 1/n mod m */
    for (wj = 1, j = 0; j < n; j++)
    {
        /* wj = w^j, so wi runs through (w^j)^i = w^(i*j) */
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i], m), m);
            wi = modmul(wi, wj, m);
        }
        dst[j] = modmul(a, rN, m);
        wj = modmul(wj, w, m);
    }
}
参数说明:

- `dst`:目标数组
- `src`:源数组
- `n`:数组大小
- `m`:模数($p$)
- `w`:基数($W_n$;做 INTT 时应传入逆基数 $W_n^{-1}$,即上面代码中的 `iW`)
希望这对某些人有所帮助。如果我忘记了什么,请写下来...
[编辑1:快速 NTT/INTT]

最终我成功实现了快速 NTT/INTT,它比普通的 FFT 要棘手一些。注意以下几点:

- `_NFTT` 会把 `src` 当作临时缓冲区使用,调用后其内容会被破坏。
- `n` 应为 2 的幂。
- 调用 `_INFTT` 时,`w` 参数应传入逆基数 `iW`。
/* Recursive radix-2 fast number-theoretic transform (decimation in time):
     dst[j] = sum_{i=0}^{n-1} w^(i*j) * src[i]   (mod m)
   dst - output array of n values
   src - input array of n values; for even n it is ALSO used as scratch by
         the recursion, so its contents are destroyed
   n   - transform size. Even sizes are split into even/odd halves. An odd
         size > 1 (e.g. the odd factor left from a non-power-of-two n)
         cannot be halved: splitting it would read past the end of src and
         the butterfly identity w^(n/2) == -1 would not hold, so it is
         computed with the direct O(n^2) sum instead.
   m   - prime modulus p
   w   - basis of order exactly n mod m (Wn forward, iW inverse); for even n
         this gives w^(n/2) == -1 (mod m), which the butterfly relies on
   dst and src must not overlap. Relies on modadd/modsub/modmul helpers
   defined elsewhere. */
void _NFTT(long *dst,long *src,long n,long m,long w)
{
    if (n <= 1) { if (n == 1) dst[0] = src[0]; return; }
    long i, j, a0, a1, wk, n2 = n >> 1;
    if (n & 1)
    {
        /* direct transform: wj = w^j, wi = (w^j)^i = w^(i*j) */
        long wi, wj;
        for (wj = 1, j = 0; j < n; j++, wj = modmul(wj, w, m))
        {
            a0 = 0;
            for (wi = 1, i = 0; i < n; i++, wi = modmul(wi, wj, m))
                a0 = modadd(a0, modmul(wi, src[i], m), m);
            dst[j] = a0;
        }
        return;
    }
    long w2 = modmul(w, w, m); /* basis of order n/2 for the half-size transforms */
    for (i = 0, j = 0; i < n2; i++, j += 2) dst[i] = src[j]; /* even samples -> dst[0..n2) */
    for (j = 1; i < n; i++, j += 2) dst[i] = src[j];          /* odd samples  -> dst[n2..n) */
    _NFTT(src, dst, n2, m, w2);           /* E = NTT(even) -> src[0..n2) */
    _NFTT(src + n2, dst + n2, n2, m, w2); /* O = NTT(odd)  -> src[n2..n) */
    for (wk = 1, i = 0, j = n2; i < n2; i++, j++, wk = modmul(wk, w, m))
    {
        a0 = src[i];                /* E[i] */
        a1 = modmul(src[j], wk, m); /* w^i * O[i] */
        dst[i] = modadd(a0, a1, m); /* X[i]       = E[i] + w^i O[i] */
        dst[j] = modsub(a0, a1, m); /* X[i + n/2] = E[i] - w^i O[i] */
    }
}
/* Fast inverse NTT: the forward fast transform run with the inverse basis,
   then every output multiplied by n^-1 mod m.
     dst - output array of n values
     src - input array of n values; may be clobbered, because _NFTT uses it
           as scratch space
     n   - transform size
     m   - prime modulus p
     w   - inverse basis iW (r^(p-1-L) mod p)
   n^-1 is n^(m-2) mod m (Fermat's little theorem, m prime). */
void _INFTT(long *dst,long *src,long n,long m,long w)
{
    long inv_n = modpow(n, m - 2, m);
    long k;

    _NFTT(dst, src, n, m, w);
    for (k = n - 1; k >= 0; --k)
        dst[k] = modmul(dst[k], inv_n, m);
}
[编辑3]

我又优化了代码,比上面的代码快 3 倍,但仍不满意,于是提了一个新问题。在那个问题里,我把代码进一步优化到比上面的代码快 40 倍,速度几乎与相同位数的浮点 FFT 相当。链接如下:(原文链接在此处缺失)