Convert recursive Python code to a non-recursive version

3
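The code works fine as long as I don't increase distinct, n_symbols and length. For example, on my computer with n_symbols=512, length=512, distinct=300 I get this error: RecursionError: maximum recursion depth exceeded in comparison. If I increase the lru_cache value, I get an overflow error instead.
I would like a non-recursive version of this code.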
from functools import lru_cache
@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
     - n_symbols: number of symbols in the alphabet
     - length: the number of symbols in each sequence
     - distinct: the number of distinct symbols in each sequence
    '''
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    else:
        return \
          get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used + \
          get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)

Then

get_permutations_count(n_symbols=300, length=300, distinct=270)

produces the answer after about 0.5 seconds:

2729511887951350984580070745513114266766906881300774347439917775
7093985721949669285469996223829969654724957176705978029888262889
8157939885553971500652353177628564896814078569667364402373549268
5524290993833663948683375995196081654415976659499171897405039547
1546236260377859451955180752885715923847446106509971875543496023
2494854876774756172488117802642800540206851318332940739395445903
6305051887120804168979339693187702655904071331731936748927759927
3688881301614948043182289382736687065840703041231428800720854767
0713406956719647313048146023960093662879015837313428567467555885
3564982943420444850950866922223974844727296000000000000000000000
000000000000000000000000000000000000000000000000

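What do you mean by an "overflow error" happening if you increase the lru_cache value? - Pychopath
Can you write it in closed form, calling math.factorial, math.comb and math.perm? That would be very fast. Otherwise, converting a recursive function into an iterative one means pushing and popping values on a list instead of on the call stack. - Jerry101
What is this computing? The function name contains the word "permutations", but it doesn't seem to have anything to do with permutations. - Pychopath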
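On the closed-form suggestion in the comments above: one possible sketch, not taken from the thread, uses the standard inclusion-exclusion count of surjections. The function name `count_closed_form` is made up for illustration, and the sketch assumes the quantity being computed is the number of length-`length` sequences over `n_symbols` symbols that use exactly `distinct` different symbols.

```python
from math import comb

def count_closed_form(n_symbols, length, distinct):
    """Hypothetical helper (not from the thread): inclusion-exclusion count."""
    if distinct < 0 or distinct > n_symbols or distinct > length:
        return 0
    # Surjections from `length` positions onto `distinct` chosen symbols:
    #   sum over j of (-1)^j * C(distinct, j) * (distinct - j)^length
    surjections = sum(
        (-1) ** j * comb(distinct, j) * (distinct - j) ** length
        for j in range(distinct + 1)
    )
    # Multiply by the number of ways to choose which symbols are used at all.
    return comb(n_symbols, distinct) * surjections

print(count_closed_form(3, 2, 2))  # length-2 sequences over 3 symbols with 2 distinct symbols
```

There is no recursion here, but each term of the sum is a large power, so whether it is actually faster than a table-based approach would have to be measured.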
1 Answer

4

Here's mine:

from math import comb, factorial

def get_permutations_count_improved(n_symbols, length, distinct):
    if distinct > length or distinct > n_symbols:
        return 0
    ways = [1]
    for _ in range(length):
        ways = [used * (distinct - d) + new
               for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)

Speed comparison for some parameter sets:

n_symbols length distinct   yours    mine
   300      300    270      0.62 s   0.012 s (~51 times faster)
   512      512    300        -      0.035 s
  1024     1024    600        -      0.22 s
  3000     3000   2700        -      6.0 s

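In my last line, you can see that I split the overall result into three factors:
- `comb(n_symbols, distinct)` chooses which distinct of the n_symbols symbols are actually used. This essentially eliminates the n_symbols parameter; you can also think of it as compensating for setting `n_symbols = distinct`.
- `factorial(distinct)` accounts for the order in which the symbols are first used. This makes it possible to drop the `* (n_symbols - used)` from your recursion.
- `ways[distinct]` is the number of ways to build a sequence of length length using exactly distinct different symbols, with the order of their first use fixed.
It may be easier to think of ways as a two-dimensional table, `ways[length][distinct]`. To save memory, though, I compute it row by row and keep only the latest row.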
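To make the two-dimensional view concrete, here is a sketch that keeps every row instead of only the latest one. The name `count_with_full_table` and the interpretation given in the comments are my own reading of the recurrence, not something stated in the answer. It trades memory for readability, so it is meant for understanding rather than for large inputs.

```python
from math import comb, factorial

def count_with_full_table(n_symbols, length, distinct):
    """Hypothetical full-table variant of the row-by-row computation."""
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # table[l][d]: ways to fill l remaining positions while introducing
    # exactly d new symbols, when (distinct - d) symbols are already in use
    # and the order in which symbols first appear is fixed.
    table = [[0] * (distinct + 1) for _ in range(length + 1)]
    table[0][0] = 1
    for l in range(1, length + 1):
        for d in range(distinct + 1):
            # Reuse one of the (distinct - d) symbols already in use ...
            reuse = table[l - 1][d] * (distinct - d)
            # ... or introduce the next new symbol (only one choice, since
            # the first-use order is fixed).
            new = table[l - 1][d - 1] if d > 0 else 0
            table[l][d] = reuse + new
    return table[length][distinct] * comb(n_symbols, distinct) * factorial(distinct)

print(count_with_full_table(3, 3, 2))
```

Row `l` of this table corresponds to the list `ways` after `l` iterations of the loop in the answer, padded with zeros.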

Benchmark and some correctness checks:

from timeit import timeit
from functools import lru_cache
from math import comb, factorial

@lru_cache
def get_permutations_count(n_symbols, length, distinct, used=0):
    '''
     - n_symbols: number of symbols in the alphabet
     - length: the number of symbols in each sequence
     - distinct: the number of distinct symbols in each sequence
    '''
    if distinct < 0:
        return 0
    if length == 0:
        return 1 if distinct == 0 else 0
    else:
        return \
          get_permutations_count(n_symbols, length-1, distinct-0, used+0) * used + \
          get_permutations_count(n_symbols, length-1, distinct-1, used+1) * (n_symbols - used)

def get_permutations_count_improved(n_symbols, length, distinct):
    if distinct > length or distinct > n_symbols:
        return 0
    ways = [1]
    for _ in range(length):
        ways = [used * (distinct - d) + new
               for d, used, new in zip(range(distinct+1), [*ways, 0], [0, *ways])]
    return ways[distinct] * comb(n_symbols, distinct) * factorial(distinct)

funcs = get_permutations_count, get_permutations_count_improved

# Check correctness
stop = 20
for a in range(stop):
    for b in range(stop):
        for c in range(stop):
            expect = get_permutations_count(a, b, c)
            result = get_permutations_count_improved(a, b, c)
            assert result == expect, (a, b, c, expect, result)

# Benchmark
n_symbols, length, distinct = 300, 300, 270
#n_symbols, length, distinct = 512, 512, 300
#n_symbols, length, distinct = 1024, 1024, 600
#n_symbols, length, distinct = 3000, 3000, 2700
for func in funcs[0:] * 3:
    funcs[0].cache_clear()
    t = timeit(lambda: func(n_symbols, length, distinct), number=1)
    print('%.3f seconds ' % t, func.__name__)
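The correctness check above compares the two implementations against each other. As an independent cross-check, a brute-force enumeration can be used for very small inputs. This is a sketch of my own (the name `count_brute_force` is hypothetical), assuming the intended quantity is the number of sequences with exactly `distinct` different symbols.

```python
from itertools import product

def count_brute_force(n_symbols, length, distinct):
    """Hypothetical helper: enumerate every sequence and count the matches."""
    return sum(
        len(set(seq)) == distinct  # True counts as 1, False as 0
        for seq in product(range(n_symbols), repeat=length)
    )

# Only feasible for tiny inputs: n_symbols ** length sequences are enumerated.
for args in [(3, 2, 2), (3, 3, 2), (4, 3, 1)]:
    print(args, count_brute_force(*args))
```

Comparing these values with `get_permutations_count_improved` for the same small arguments gives a check that does not depend on the original recursion.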

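Thanks. I'll put a bounty on it. - Eftekhari
@Bob, you're the one who wrote the solution in that question, right? Thanks for the contribution, partly because it made the task fairly easy to understand, more so than the earlier question you wrote it for :-) - Pychopath
Yes, that was me, @Pychopath. I even provided an iterative version there, but the idea there was to give the complete table of counts for the other methods to use. I see two differences in your implementation. You don't iterate over the zeros, so you build a triangular matrix. And you add new instead of new * d, leaving the multiplication to the end. That is good, because all the intermediate values are smaller and therefore faster to work with. - Bob
@Pychopath, could you help answer this question? >> https://dev59.com/_sPra4cB1Zd3GeqPnLsh - Eftekhari
@Eftekhari I'll try. By the way, are you doing this out of interest, or do you need it for a project? It seems like an unusual thing to want. - Pychopath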
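For contrast with the point made in the comments about deferring the multiplication, here is a sketch of a bottom-up loop that applies the `(n_symbols - used)` factor at every step, the way the question's recursion does. It is not Bob's code; the name `count_multiplying_each_step` is hypothetical and the loop is my own translation of the recursion.

```python
def count_multiplying_each_step(n_symbols, length, distinct):
    """Hypothetical variant: multiply by the symbol choices at every step."""
    if distinct < 0 or distinct > length or distinct > n_symbols:
        return 0
    # counts[u]: number of prefixes built so far that use exactly u symbols.
    counts = [1] + [0] * distinct
    for _ in range(length):
        nxt = [0] * (distinct + 1)
        for u in range(distinct + 1):
            # Repeat one of the u symbols already used.
            nxt[u] = counts[u] * u
            if u > 0:
                # Introduce a new symbol: n_symbols - (u - 1) choices.
                nxt[u] += counts[u - 1] * (n_symbols - (u - 1))
        counts = nxt
    return counts[distinct]

print(count_multiplying_each_step(3, 3, 2))
```

Here the intermediate entries already carry the factors that the answer's version applies once at the end through `comb(n_symbols, distinct) * factorial(distinct)`, which is why its intermediate integers are larger.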