我希望使用NTT实现多项式乘法。我遵循了Number-theoretic transform (integer DFT),看起来它可以工作。
现在我想要在有限域
与前面无界情况相比,系数现在受到
特别地,原始NTT要求找到素数
编辑:我复制粘贴了上面链接中的源代码以说明问题:
然而,当我将
“为了使其正常工作,最小的minmod(上面链接中的N)应该是多少?”
现在我想要在有限域
Z_p[x]
上实现多项式乘法,其中p
是任意质数。与前面无界情况相比,系数现在受到
p
的限制,这会改变什么吗?特别地,原始NTT要求找到素数
N
作为工作模数,它大于(输入向量中最大元素的幅度)^2 * (输入向量的长度) + 1
,以便结果永远不会溢出。如果结果将被那个p
素数限制,那么模数可以多小?注意,p-1
不必形如(某些正整数)*(输入向量的长度)
。编辑:我复制粘贴了上面链接中的源代码以说明问题:
#
# Number-theoretic transform library (Python 2, 3)
#
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#
import itertools, numbers
def find_params_and_transform(invec, minmod):
check_int(minmod)
mod = find_modulus(len(invec), minmod)
root = find_primitive_root(len(invec), mod - 1, mod)
return (transform(invec, root, mod), root, mod)
def check_int(n):
if not isinstance(n, numbers.Integral):
raise TypeError()
def find_modulus(veclen, minimum):
check_int(veclen)
check_int(minimum)
if veclen < 1 or minimum < 1:
raise ValueError()
start = (minimum - 1 + veclen - 1) // veclen
for i in itertools.count(max(start, 1)):
n = i * veclen + 1
assert n >= minimum
if is_prime(n):
return n
def is_prime(n):
check_int(n)
if n <= 1:
raise ValueError()
return all((n % i != 0) for i in range(2, sqrt(n) + 1))
def sqrt(n):
check_int(n)
if n < 0:
raise ValueError()
i = 1
while i * i <= n:
i *= 2
result = 0
while i > 0:
if (result + i)**2 <= n:
result += i
i //= 2
return result
def find_primitive_root(degree, totient, mod):
check_int(degree)
check_int(totient)
check_int(mod)
if not (1 <= degree <= totient < mod):
raise ValueError()
if totient % degree != 0:
raise ValueError()
gen = find_generator(totient, mod)
root = pow(gen, totient // degree, mod)
assert 0 <= root < mod
return root
def find_generator(totient, mod):
check_int(totient)
check_int(mod)
if not (1 <= totient < mod):
raise ValueError()
for i in range(1, mod):
if is_generator(i, totient, mod):
return i
raise ValueError("No generator exists")
def is_generator(val, totient, mod):
check_int(val)
check_int(totient)
check_int(mod)
if not (0 <= val < mod):
raise ValueError()
if not (1 <= totient < mod):
raise ValueError()
pf = unique_prime_factors(totient)
return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)
def unique_prime_factors(n):
check_int(n)
if n < 1:
raise ValueError()
result = []
i = 2
end = sqrt(n)
while i <= end:
if n % i == 0:
n //= i
result.append(i)
while n % i == 0:
n //= i
end = sqrt(n)
i += 1
if n > 1:
result.append(n)
return result
def transform(invec, root, mod):
check_int(root)
check_int(mod)
if len(invec) >= mod:
raise ValueError()
if not all((0 <= val < mod) for val in invec):
raise ValueError()
if not (1 <= root < mod):
raise ValueError()
outvec = []
for i in range(len(invec)):
temp = 0
for (j, val) in enumerate(invec):
temp += val * pow(root, i * j, mod)
temp %= mod
outvec.append(temp)
return outvec
def inverse_transform(invec, root, mod):
outvec = transform(invec, reciprocal(root, mod), mod)
scaler = reciprocal(len(invec), mod)
return [(val * scaler % mod) for val in outvec]
def reciprocal(n, mod):
check_int(n)
check_int(mod)
if not (0 <= n < mod):
raise ValueError()
x, y = mod, n
a, b = 0, 1
while y != 0:
a, b = b, a - x // y * b
x, y = y, x % y
if x == 1:
return a % mod
else:
raise ValueError("Reciprocal does not exist")
def circular_convolve(vec0, vec1):
if not (0 < len(vec0) == len(vec1)):
raise ValueError()
if any((val < 0) for val in itertools.chain(vec0, vec1)):
raise ValueError()
maxval = max(val for val in itertools.chain(vec0, vec1))
minmod = maxval**2 * len(vec0) + 1
temp0, root, mod = find_params_and_transform(vec0, minmod)
temp1 = transform(vec1, root, mod)
temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
return inverse_transform(temp2, root, mod)
vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]
print(circular_convolve(vec0, vec1))
def modulo(vec, prime):
return [x % prime for x in vec]
print(modulo(circular_convolve(vec0, vec1), 31))
输出:
[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]
然而,当我将
minmod = maxval**2 * len(vec0) + 1
改为 minmod = maxval + 1
时,它停止工作:[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]
“为了使其正常工作,最小的minmod(上面链接中的N)应该是多少?”
p
中可以安全地扮演链接所称的M
的角色。如果我理解正确,他们花时间指定M
的唯一原因是他们不真正希望答案被模M
,因此他们希望M
足够大,以便最终输出可以被解释为普通整数。你想要输出模p
,因此没有必要避免找到“安全”的M。 - John Colemankn+1
的形式,其中n
是项目数。大于p
的最小素数N
应该没问题。 - John Colemanp>maxval**2 * len(vec0) + 1
处中断时,它打印出 [14, 16, 13, 20, 25, 15, 20, 0],而不是 [3, 21, 4, 17, 25, 8, 29, 0]。因此,一方面,“p 是质数,使得 p mod n == 1 且 p>max”似乎不足以作为停止准则,另一方面,我看到的所有文本都将其提及为停止准则。 - minmaxp
总会让你陷入麻烦(除非使用任意精度)。 - Spektre