加速循环或计算原始三元组的不同思路

8
def pythag_triples(n):
    i = 0
    start = time.time()
    for x in range(1, int(sqrt(n) + sqrt(n)) + 1, 2):
        for m in range(x+2,int(sqrt(n) + sqrt(n)) + 1, 2):
            if gcd(x, m) == 1:
                # q = x*m
                # l = (m**2 - x**2)/2
                c = (m**2 + x**2)/2
                # trips.append((q,l,c))
                if c < n:
                    i += 1
    end = time.time()
    return i, end-start
print(pythag_triples(3141592653589793))

我正在尝试使用“m”和“n”均为奇数且互质的概念来计算原始勾股三元组。我已经知道该函数在1000000以内的数中有效,但是当对更大的数进行计算时,花费的时间超过了24小时。有没有任何方法可以加速这个过程/而不是采用暴力解法。我正在尝试计算三元组的数量。

1
评论不适合进行长时间的讨论;该对话已经移至聊天室。 (https://chat.stackoverflow.com/rooms/240943/discussion-on-question-by-koder-speeding-up-the-loops-or-different-ideas-for-cou)。 - Stephen Rauch
@KellyBundy 我该如何接受一个答案? - Koder
FYI,我在我的Linux电脑上(一台Ryzen 5950x)使用单线程将其降至9.7秒。这是在整个过程中都使用了numba和我第二个答案的变体。我会在有时间的时候添加一些说明。不确定是将其添加到第二个答案中还是写一个第三个答案。你有什么想法? - Pierre D
3个回答

5

我们不需要对xm 进行双重循环,并反复检查它们是否互质,而是仅迭代 m(两者中的较大值),并直接应用欧拉函数或其自定义版本来计算与m互质的x值的数量。这为我们提供了一种更快速的方法(速度尚待更精确地量化):例如,对于n = 100_000_000,使用此方法只需要43毫秒,而使用原始代码需要30秒(加速700倍)。

x允许的最大值xmax小于m时(以满足不等式(m**2 + x**2)/2 <= n),就需要使用自定义版本。在这种情况下,不应该计算m的所有互质数,而只计算达到该上限的互质数。

def distinct_factors(n):
    # a variant of the well-known factorization, but that
    # yields only distinct factors, rather than all of them
    # (including possible repeats)
    last = None
    i = 2
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            if i != last:
                yield i
                last = i
    if n > 1 and n != last:
        yield n

def products_of(p_list, upto):
    for i, p in enumerate(p_list):
        if p > upto:
            break
        yield -p
        for q in products_of(p_list[i+1:], upto=upto // p):
            yield -p * q

def phi(n, upto=None):
    # Euler's totient or "phi" function
    if upto is not None and upto < n:
        # custom version: all co-primes of n up to the `upto` bound
        cnt = upto
        p_list = list(distinct_factors(n))
        for q in products_of(p_list, upto):
            cnt += upto // q if q > 0 else -(upto // -q)
        return cnt
    # standard formulation: all co-primes of n up to n-1
    cnt = n
    for p in distinct_factors(n):
        cnt *= (1 - 1/p)
    return int(cnt)

phi(n)是欧拉(Euler)的欧拉函数或ϕ(n)函数。

phi(n, upto=x)是一种自定义变体,仅计算给定值x以下的互质数。为了理解它,让我们使用一个例子:

>>> n = 3*3*3*5  # 135
>>> list(factors(n))
[3, 3, 3, 5]

>>> list(distinct_factors(n))
[3, 5]

# there are 72 integers between 1 and 135 that are co-primes of 135
>>> phi(n)
72

# ...but only 53 of them are no greater than 100:
# 100 - (100//3 + 100//5 - 100//(3*5)) 
>>> phi(n, upto=100)
53

在评估小于值为x的数字中与n互质的数字数量时,我们应该计算所有1到x的数字,然后减去任何n的不同因子的倍数。但是,如果只是简单地删除所有p_i的x // p_i,则会重复计算两个因子的倍数,因此我们需要“添加回来”。但是,在这样做时,会多次添加多倍数的三个因子的数字,所以我们也需要考虑这些数字。在例子n = 135中,我们删除了x // 3和x // 5,但那些既是3的因子又是5的因子(15的因子)的整数被重复计算了,因此我们需要将它们添加回来。对于更长的因子集,我们需要:
- 将x作为初始计数; - 减去每个因子p的倍数的数量; - “取消减少”(添加)任意2个因子的乘积的倍数数量; - “取消取消减少”(减去)任何3个因子的乘积的倍数数量; - 等等。
最初的答案是通过迭代所有不同因子的组合来实现的,但是这个答案中通过“products_of(p_list,upto)”生成器进行了大量优化,该生成器给出了给定p_list不同因子的所有子集的乘积,其乘积不大于upto。符号表示如何计算每个产品:正或负取决于子集大小是偶数还是奇数。
有了phi(n)和phi(n,upto),我们现在可以编写以下内容:
def pyth_m_counts(n):
    # yield tuples (m, count(x) where 0 < x < m and odd(x)
    # and odd(m) and coprime(x, m) and m**2 + x**2 <= 2*n).
    mmax = isqrt(2*n - 1)
    for m in range(3, mmax + 1, 2):
        # requirement: (m**2 + x**2) // 2 <= n
        # and both m and x are odd
        # (so (m**2 + x**2) // 2 == (m**2 + x**2) / 2)
        xmax = isqrt(2*n - m**2)
        cnt_m = phi(2*m, upto=xmax) if xmax < m else phi(2*m) // 2
        if cnt_m > 0:
            yield m, cnt_m

为什么要使用表达式phi(2*m) // 2呢?根据原帖,由于x(和m)必须都是奇数,我们需要移除所有偶数值。我们可以通过传递2*m(其中2是一个因子,将会“清除”所有x的偶数值),而无需修改phi()函数来做到这一点,然后再除以2来获得与m互质的奇数个数。对于phi(2*m, upto=xmax)也有类似但稍微微妙一些的考虑 - 我们将其留给读者作为练习...
示例运行:
>>> n = 300
>>> list(pyth_m_counts(n))
[(3, 1),
 (5, 2),
 (7, 3),
 (9, 3),
 (11, 5),
 (13, 6),
 (15, 4),
 (17, 8),
 (19, 8),
 (21, 3),
 (23, 4)]

这意味着,在OP的函数中,pythag_triples(300)将返回1个元组,其中m==3,2个元组,m==5等以此类推。事实上,我们修改该函数以验证:

def mod_pythag_triples(n):
    for x in range(1, int(sqrt(n) + sqrt(n)) + 1, 2):
        for m in range(x+2, int(sqrt(n) + sqrt(n)) + 1, 2):
            if gcd(x, m) == 1:
                c = (m**2 + x**2) // 2
                if c < n:
                    yield x, m

然后:

>>> n = 300
>>> list(pyth_m_counts(n)) == list(Counter(m for x, m in mod_pythag_triples(n)).items())
True

对于任何正值的n都是相同的。

现在来看实际计数函数:我们只需要将每个m的计数相加即可:

def pyth_triples_count(n):
    cnt = 0
    mmax = isqrt(2*n - 1)
    for m in range(3, mmax + 1, 2):
        # requirement: (m**2 + x**2) // 2 <= n
        # and both m and x are odd (so (m**2 + x**2) // 2 == (m**2 + x**2) / 2)
        xmax = isqrt(2*n - m**2)
        cnt += phi(2*m, upto=xmax) if xmax < m else phi(2*m) // 2
    return cnt

样例运行:

>>> pyth_triples_count(1_000_000)
159139

>>> pyth_triples_count(100_000_000)
15915492

>>> pyth_triples_count(1_000_000_000)
159154994

>>> big_n = 3_141_592_653_589_793
>>> pyth_triples_count(big_n)
500000000002845

速度:

%timeit pyth_triples_count(100_000_000)
42.8 ms ± 56.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit pyth_triples_count(1_000_000_000)
188 ms ± 571 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%%time
pyth_triples_count(big_n)
CPU times: user 1h 42min 33s, sys: 480 ms, total: 1h 42min 33s
Wall time: 1h 42min 33s

注意:在同一台机器上,OP问题中的代码对于n = 100_000_000需要30秒;而这个版本对于这个n快了700倍。
另请参见我的其他答案以获取更快的解决方案。

3
这个新答案将big_n的总时间缩短到4分6秒
对我的初始答案进行分析,发现以下事实:
  • 总时间:1小时42分33秒
  • 分解数字所花费的时间:几乎占据了全部时间
相比之下,从3sqrt(2*N - 1)生成所有素数仅需38.5秒(使用Atkin筛法)。
因此,我决定尝试一种版本,其中我们将所有数字m生成为已知质数乘积。也就是说,生成器产生数字本身以及涉及的不同质因数。无需因式分解。 结果仍然是500_000_000_002_841,与@Koder发现的相差4。我还不知道问题来自哪里。编辑:在修正xmax边界后(isqrt(2 * N-m ** 2)而不是isqrt(2 * N-m ** 2-1),因为我们确实希望包括直角三角形的斜边等于N),我们现在得到了正确的结果。

质数生成器的代码包含在最后。基本上,我使用了Atkin筛法,并将其适应(没有花费太多时间)Python。我相当确定它可以加速(例如使用numpy,甚至可能使用numba)。

为了从质数生成整数(我们知道这是可能的,多亏了算术基本定理),我们只需遍历所有可能的积prod(p_i**k_i),其中p_i是第i个质数,k_i是任何非负整数。
最简单的公式是递归公式:
def gen_ints_from_primes(p_list, upto):
    if p_list and upto >= p_list[0]:
        p, *p_list = p_list
        pk = 1
        p_tup = tuple()
        while pk <= upto:
            for q, p_distinct in gen_ints_from_primes(p_list, upto=upto // pk):
                yield pk * q, p_tup + p_distinct
            pk *= p
            p_tup = (p, )
    else:
        yield 1, tuple()

不幸的是,我们很快就会遇到内存限制(和递归限制)。因此,这里有一个非递归版本,除了素数列表本身之外不使用额外的内存。基本上,当前值q(正在生成的整数)和列表中的索引是我们生成下一个整数所需的所有信息。当然,这些值未排序,但这并不重要,只要它们都被覆盖即可。

def rem_p(q, p, p_distinct):
    q0 = q
    while q % p == 0:
        q //= p
    if q != q0:
        if p_distinct[-1] != p:
            raise ValueError(f'rem({q}, {p}, ...{p_distinct[-4:]}): p expected at end of p_distinct if q % p == 0')
        p_distinct = p_distinct[:-1]
    return q, p_distinct

def add_p(q, p, p_distinct):
    if len(p_distinct) == 0 or p_distinct[-1] != p:
        p_distinct += (p, )
    q *= p
    return q, p_distinct

def gen_prod_primes(p, upto=None):
    if upto is None:
        upto = p[-1]
    if upto >= p[-1]:
        p = p + [upto + 1]  # sentinel
    
    q = 1
    i = 0
    p_distinct = tuple()
    
    while True:
        while q * p[i] <= upto:
            i += 1
        while q * p[i] > upto:
            yield q, p_distinct
            if i <= 0:
                return
            q, p_distinct = rem_p(q, p[i], p_distinct)
            i -= 1
        q, p_distinct = add_p(q, p[i], p_distinct)

例子-
>>> p_list = list(primes(20))
>>> p_list
[2, 3, 5, 7, 11, 13, 17, 19]

>>> sorted(gen_prod_primes(p_list, 20))
[(1, ()),
 (2, (2,)),
 (3, (3,)),
 (4, (2,)),
 (5, (5,)),
 (6, (2, 3)),
 (7, (7,)),
 (8, (2,)),
 (9, (3,)),
 (10, (2, 5)),
 (11, (11,)),
 (12, (2, 3)),
 (13, (13,)),
 (14, (2, 7)),
 (15, (3, 5)),
 (16, (2,)),
 (17, (17,)),
 (18, (2, 3)),
 (19, (19,)),
 (20, (2, 5))]

如您所见,我们不需要分解任何数字,因为它们很方便地与涉及的不同质数一起出现。

要仅获取奇数,请从素数列表中删除2

>>> sorted(gen_prod_primes(p_list[1:]), 20)
[(1, ()),
 (3, (3,)),
 (5, (5,)),
 (7, (7,)),
 (9, (3,)),
 (11, (11,)),
 (13, (13,)),
 (15, (3, 5)),
 (17, (17,)),
 (19, (19,))]

为了利用这个数字和因子的展示方式,我们需要对原始答案中给出的函数进行一些修改:
def phi(n, upto=None, p_list=None):
    # Euler's totient or "phi" function
    if upto is None or upto > n:
        upto = n
    if p_list is None:
        p_list = list(distinct_factors(n))
    if upto < n:
        # custom version: all co-primes of n up to the `upto` bound
        cnt = upto
        for q in products_of(p_list, upto):
            cnt += upto // q if q > 0 else -(upto // -q)
        return cnt
    # standard formulation: all co-primes of n up to n-1
    cnt = n
    for p in p_list:
        cnt = cnt * (p - 1) // p
    return cnt

通过以上内容,我们现在可以重新编写计数函数:

def pt_count_m(N):
    # yield tuples (m, count(x) where 0 < x < m and odd(x)
    # and odd(m) and coprime(x, m) and m**2 + x**2 <= 2*N))
    # in this version, m is generated from primes, and the values
    # are iterated through unordered.
    mmax = isqrt(2*N - 1)
    p_list = list(primes(mmax))[1:]  # skip 2
    for m, p_distinct in gen_prod_primes(p_list, upto=mmax):
        if m < 3:
            continue
        # requirement: (m**2 + x**2) // 2 <= N
        # note, both m and x are odd (so (m**2 + x**2) // 2 == (m**2 + x**2) / 2)
        xmax = isqrt(2*N - m*m)
        cnt_m = phi(m+1, upto=xmax, p_list=(2,) + tuple(p_distinct))
        if cnt_m > 0:
            yield m, cnt_m

def pt_count(N, progress=False):
    mmax = isqrt(2*N - 1)
    it = pt_count_m(N)
    if progress:
        it = tqdm(it, total=(mmax - 3 + 1) // 2)
    return sum(cnt_m for m, cnt_m in it)

现在:

%timeit pt_count(100_000_000)
31.1 ms ± 38.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit pt_count(1_000_000_000)
104 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# the speedup is still very moderate at that stage

# however:
%%time
big_n = 3_141_592_653_589_793
N = big_n
res = pt_count(N)

CPU times: user 4min 5s, sys: 662 ms, total: 4min 6s
Wall time: 4min 6s

>>> res
500000000002845

附录:阿特金筛法

正如承诺的那样,这是我版本的阿特金筛法。它肯定可以加速。

def primes(limit):
    # Generates prime numbers between 2 and n
    # Atkin's sieve -- see http://en.wikipedia.org/wiki/Prime_number
    sqrtLimit = isqrt(limit) + 1

    # initialize the sieve
    is_prime = [False, False, True, True, False] + [False for _ in range(5, limit + 1)]

    # put in candidate primes:
    # integers which have an odd number of
    # representations by certain quadratic forms
    for x in range(1, sqrtLimit):
        x2 = x * x
        for y in range(1, sqrtLimit):
            y2 = y*y
            n = 4 * x2 + y2
            if n <= limit and (n % 12 == 1 or n % 12 == 5): is_prime[n] ^= True
            n = 3 * x2 + y2
            if n <= limit and (n % 12 == 7): is_prime[n] ^= True
            n = 3*x2-y2
            if n <= limit and x > y and n % 12 == 11: is_prime[n] ^= True

    # eliminate composites by sieving
    for n in range(5, sqrtLimit):
        if is_prime[n]:
            sqN = n**2
            # n is prime, omit multiples of its square; this is sufficient because
            # composites which managed to get on the list cannot be square-free
            for i in range(1, int(limit/sqN) + 1):
                k = i * sqN # k ∈ {n², 2n², 3n², ..., limit}
                is_prime[k] = False
    for i, truth in enumerate(is_prime):
        if truth: yield i

刚刚添加了我的新代码,使用了你的旧代码和一种不同类型的筛法。 - Koder

1

感谢Pierre,我找到了一个更快的解决方案。

以下是我的新代码与Pierre的融合,供有需要的人使用。

def sieve_factors(n):
    s = [0] * (n+1)
    s[1] = 1
    for i in range(2, n+1, 2):
        s[i] = 2
    for i in range(3, n+1, 2):
        if s[i] == 0:
            s[i] = i
            for j in range(i, n + 1, i):
                if s[j] == 0:
                    s[j] = i
    return s
Q = sieve_factors(2*(isqrt(2 * 3141592653589793) + 1))


def findfactors(n):
    global Q
    yield Q[n]
    last = Q[n]
    while n > 1:
        if Q[n] != last and Q[n] != 1:
            last = Q[n]
            yield Q[n]
        n //= Q[n]


def products_of(p_list, upto):
    for i, p in enumerate(p_list):
        if p > upto:
            break
        yield -p
        for q in products_of(p_list[i+1:], upto=upto // p):
            yield -p * q


def phi(n, upto=None):
    if upto is not None and upto < n:
        cnt = upto
        p_list = list(findfactors(n))
        for q in products_of(p_list, upto):
            cnt += upto // q if q > 0 else -(upto // -q)
        return cnt
    cnt = n
    for p in findfactors(n):
        cnt *= (1 - 1/p)
    return int(cnt)

def countprimtrips(n):
    cnt = 0
    for m in range(3, int(sqrt(2*n)) + 1, 2):
        xmax = int(sqrt(2*n - m**2))
        cnt += phi(2*m, upto=xmax) if xmax < m else phi(2*m) // 2
    return cnt

print(countprimtrips(3141592653589793))

如上面的答案所提到的,大部分时间都花在了因数分解上,所以我采用了他的代码,并添加了一个筛法,筛选出所有小于x-max的数字,每个数字都是它们最低质因数的索引。它在4分44秒内找到了答案(284.6727148秒)。感谢Pierre的帮助。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接