在有限域上实现FFT

9
我希望使用NTT实现多项式乘法。我遵循了Number-theoretic transform (integer DFT),看起来它可以工作。
现在我想要在有限域Z_p[x]上实现多项式乘法,其中p是任意质数。
与前面无界情况相比,系数现在受到p的限制,这会改变什么吗?
特别地,原始NTT要求找到素数N作为工作模数,它大于(输入向量中最大元素的幅度)^2 * (输入向量的长度) + 1,以便结果永远不会溢出。如果结果将被那个p素数限制,那么模数可以多小?注意,p-1不必形如(某些正整数)*(输入向量的长度)
编辑:我复制粘贴了上面链接中的源代码以说明问题:
# 
# Number-theoretic transform library (Python 2, 3)
# 
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#

import itertools, numbers

def find_params_and_transform(invec, minmod):
    check_int(minmod)
    mod = find_modulus(len(invec), minmod)
    root = find_primitive_root(len(invec), mod - 1, mod)
    return (transform(invec, root, mod), root, mod)

def check_int(n):
    if not isinstance(n, numbers.Integral):
        raise TypeError()

def find_modulus(veclen, minimum):
    check_int(veclen)
    check_int(minimum)
    if veclen < 1 or minimum < 1:
        raise ValueError()
    start = (minimum - 1 + veclen - 1) // veclen
    for i in itertools.count(max(start, 1)):
        n = i * veclen + 1
        assert n >= minimum
        if is_prime(n):
            return n

def is_prime(n):
    check_int(n)
    if n <= 1:
        raise ValueError()
    return all((n % i != 0) for i in range(2, sqrt(n) + 1))

def sqrt(n):
    check_int(n)
    if n < 0:
        raise ValueError()
    i = 1
    while i * i <= n:
        i *= 2
    result = 0
    while i > 0:
        if (result + i)**2 <= n:
            result += i
        i //= 2
    return result

def find_primitive_root(degree, totient, mod):
    check_int(degree)
    check_int(totient)
    check_int(mod)
    if not (1 <= degree <= totient < mod):
        raise ValueError()
    if totient % degree != 0:
        raise ValueError()
    gen = find_generator(totient, mod)
    root = pow(gen, totient // degree, mod)
    assert 0 <= root < mod
    return root

def find_generator(totient, mod):
    check_int(totient)
    check_int(mod)
    if not (1 <= totient < mod):
        raise ValueError()
    for i in range(1, mod):
        if is_generator(i, totient, mod):
            return i
    raise ValueError("No generator exists")

def is_generator(val, totient, mod):
    check_int(val)
    check_int(totient)
    check_int(mod)
    if not (0 <= val < mod):
        raise ValueError()
    if not (1 <= totient < mod):
        raise ValueError()
    pf = unique_prime_factors(totient)
    return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)

def unique_prime_factors(n):
    check_int(n)
    if n < 1:
        raise ValueError()
    result = []
    i = 2
    end = sqrt(n)
    while i <= end:
        if n % i == 0:
            n //= i
            result.append(i)
            while n % i == 0:
                n //= i
            end = sqrt(n)
        i += 1
    if n > 1:
        result.append(n)
    return result

def transform(invec, root, mod):
    check_int(root)
    check_int(mod)
    if len(invec) >= mod:
        raise ValueError()
    if not all((0 <= val < mod) for val in invec):
        raise ValueError()
    if not (1 <= root < mod):
        raise ValueError()

    outvec = []
    for i in range(len(invec)):
        temp = 0
        for (j, val) in enumerate(invec):
            temp += val * pow(root, i * j, mod)
            temp %= mod
        outvec.append(temp)
    return outvec

def inverse_transform(invec, root, mod):
    outvec = transform(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

def reciprocal(n, mod):
    check_int(n)
    check_int(mod)
    if not (0 <= n < mod):
        raise ValueError()
    x, y = mod, n
    a, b = 0, 1
    while y != 0:
        a, b = b, a - x // y * b
        x, y = y, x % y
    if x == 1:
        return a % mod
    else:
        raise ValueError("Reciprocal does not exist")

def circular_convolve(vec0, vec1):
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()
    maxval = max(val for val in itertools.chain(vec0, vec1))
    minmod = maxval**2 * len(vec0) + 1
    temp0, root, mod = find_params_and_transform(vec0, minmod)
    temp1 = transform(vec1, root, mod)
    temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
    return inverse_transform(temp2, root, mod)

vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]

print(circular_convolve(vec0, vec1))

def modulo(vec, prime):
    return [x % prime for x in vec]

print(modulo(circular_convolve(vec0, vec1), 31))

输出:

[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]

然而,当我将 minmod = maxval**2 * len(vec0) + 1 改为 minmod = maxval + 1 时,它停止工作:
[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]

“为了使其正常工作,最小的minmod(上面链接中的N)应该是多少?”

@JohnColeman 有两个质数需要处理:一个是用于多项式的,另一个是用于NTT的。由于第一个不必符合NTT所需的形式,我的问题是,第二个可以有多小? - minmax
1
我认为在你的p中可以安全地扮演链接所称的M的角色。如果我理解正确,他们花时间指定M的唯一原因是他们真正希望答案被模M,因此他们希望M足够大,以便最终输出可以被解释为普通整数。你想要输出模p,因此没有必要避免找到“安全”的M。 - John Coleman
1
“N”的形式似乎很重要。它似乎需要像kn+1的形式,其中n是项目数。大于p的最小素数N应该没问题。 - John Coleman
@Spektre 当我在 p>maxval**2 * len(vec0) + 1 处中断时,它打印出 [14, 16, 13, 20, 25, 15, 20, 0],而不是 [3, 21, 4, 17, 25, 8, 29, 0]。因此,一方面,“p 是质数,使得 p mod n == 1 且 p>max”似乎不足以作为停止准则,另一方面,我看到的所有文本都将其提及为停止准则。 - minmax
@minmax 我认为你在倒退或者过度思考了。为什么不使用适合于你的数据字的最大质数呢?这样可以消除模数(用单个减法替换它),而且一些这样的质数在二进制中可以有很多分组零,以进行额外的位运算优化。此外,由于最大可能值用于这样的情况,如果溢出是可能的,那么无论如何都无法避免(除非使用更大的数据字),因此没有必要因此而把头搞砸。使用最小的 p 总会让你陷入麻烦(除非使用任意精度)。 - Spektre
显示剩余6条评论
1个回答

1
如果您输入的n个整数与某个素数q绑定(任何mod q而不仅仅是素数都将相同),则可以将其用作max value +1,但要注意不能将其用作NTT的质数p,因为NTT质数p具有特殊属性。所有这些属性都在此处列出: 因此,我们每个输入的最大值为q-1,但在您的任务计算(对2个NTT结果进行卷积)期间,第一层结果的幅度可能会上升到n.(q-1),但由于我们正在对它们进行卷积,因此最终iNTT的输入幅度将上升到:
m = n.((q-1)^2)

如果您对NTT执行不同的操作,则m方程可能会发生变化。
现在让我们回到p,简而言之,您可以使用任何符合以下条件的质数p:
p mod n == 1
p > m

并且存在 1 <= r,L < p,使得:

p mod (L-1) = 0
r^(L*i) mod p == 1 // i = { 0,n }
r^(L*i) mod p != 1 // i = { 1,2,3, ... n-1 }

如果所有这些条件都满足,那么p就是一个“单位根”,可以用于进行NTT。要找到这样的质数以及r,L,请查看上面的链接(其中有一个查找此类内容的C++代码)。
例如,在字符串乘法中,我们取2个字符串,在它们上面进行NTT,然后卷积结果并iNTT回结果(即两个输入大小的总和)。例如:
                                99999999999999999999999999999999
                               *99999999999999999999999999999999
----------------------------------------------------------------
9999999999999999999999999999999800000000000000000000000000000001

q=10且两个操作数均为9^32时,n=32,因此m=9*9*32=2592,找到的素数是p=2689。正如您所看到的结果匹配,因此不会发生溢出。然而,如果我使用任何更小的素数仍满足所有其他条件,则结果将不匹配。我特意使用这个来尽可能地拉伸NTT值(所有值都为q-1,大小相等于2的幂)

如果你的NTT很快,而n不是2的幂,则需要为每个NTT零填充到最接近高于或等于2的幂的大小。但这不应影响m值,因为零填充不应增加值的数量级。我的测试证明它对于卷积来说是正确的:

m = (n1+n2).((q-1)^2)/2

其中n1,n2是零填充之前的原始输入大小。

有关实现NTT的更多信息,您可以查看我的C ++(经过广泛优化):

  • {{link1:模算术和NTT(有限场DFT)优化}}

因此,回答您的问题:

  1. 是的,您可以利用输入为mod q的事实,但不能将q用作p!!!

  2. 您只能在单个NTT(或第一层NTT)中使用minmod = n *(maxval + 1),但是由于您在NTT使用过程中将它们与卷积链接,因此无法将其用于最终INTT阶段!!!

但是,正如我在评论中提到的那样,最简单的方法是使用适合您使用的数据类型并且可用于支持所有2的幂输入大小的最大可能p

基本上这使得你的问题不相关了。我只能想到一种情况,那就是在没有“最大”限制的任意精度数字上。变量p存在许多性能问题,因为查找p非常缓慢(甚至可能比NTT本身还要慢),而且变量p禁用了所需的模算术的许多性能优化,从而使NTT变得非常缓慢。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接