将整数写成k次幂不同整数的和

3

给定一个整数 n(n≥1)和一个数字 k,返回将 n 写成 k 次方不同整数的所有可能方式。 例如,如果 n = 100 并且 k = 2:

100 = 1**2 + 3**2 + 4**2 + 5**2 + 7**2
    = 6**2 + 8**2
    = 10**2

或者如果 k = 3:

100 = 1**3 + 2**3 + 3**3 + 4**3

所以program(100,2)返回类似于[(2, [1, 3, 4, 5, 7]), (2, [6, 8]), (2, [10])]的东西, 而program(100,3)返回[(3, [1, 2, 3, 4])]

只要输入n很小或k很大(>=3),一切都正常运作。 我的方法是首先获取所有整数的列表,其k次方为<= n。

def possible_powers(n,k):
    arr,i = [],1
    while i**k <= n:
        arr.append(i)
        i += 1
    return arr

然后(这里犯了一个错误),我创建了这个列表的所有子集(作为列表):
def subsets(L):
    subsets = [[]]
    for i in L:
        subsets += [subset+[i] for subset in subsets]
    return subsets

最后,我循环遍历了所有这些子集,将每个元素提高到k次方并相加,只选择那些总和为n的子集。
def kth_power_sum(arr,k):
    return sum([i**k for i in arr])

def solutions(n,k):
    return [(k,i) for i in subsets(possible_powers(n,k)) if kth_power_sum(i,k) == n]

我知道问题出在子集创建上,但我不知道如何进行优化。比如说,如果我尝试solutions(1000,2),它会创建一个很大的集合,占用超过4GB的内存。我的猜测是筛选出一些子集,但除非我有一个非常高效的筛选方法,否则这并不能帮助太多。

非常感谢任何帮助。如果有什么不清楚的地方,或者我在发布此内容时犯了错误,请告诉我。


1
我现在没有时间深入研究这个问题,但你可以看一下这个链接:https://dev59.com/m3M_5IYBdhLWcg3wMgEF - Swifty
itertools 文档 中,你可以学习到如何懒惰地计算函数 powerset:每次只有一个子集被存储在内存中。 - Jorge Luis
感谢 @Swifty 的建议。我浏览了那个帖子,但我不认为我的 powerset 函数有问题(我错了吗?),但事实上,我创建了整个 powerset,对于 subsets(1000,2) 来说,这是一个 2**31 长度的列表(不太优化)。 - Reggie Floarde
我认为这里的解决方法是使用一些动态规划算法。 - Ohad Sharet
看看这个答案:找到所有可能的数字组合以达到给定的总和 ,将其修改为使用幂似乎非常直接。 - Jorge Luis
显示剩余4条评论
5个回答

5
如果您将其实现为递归生成器,就不需要存储大量的值(甚至不需要结果):
def powersum(n,k,b=1):
    bk = b**k
    while bk < n:
        for bases in powersum(n-bk,k,b+1):
            yield [b]+bases
        b += 1
        bk = b**k
    if bk == n :
        yield [b]
        
print(*powersum(100,2)) # [1, 3, 4, 5, 7] [6, 8] [10]
print(*powersum(100,3)) # [1, 2, 3, 4]
print(sum(1 for _ in powersum(1000,2))) # 1269 solutions
print(sum(1 for _ in powersum(2000,2))) # 27526 solutions (6 seconds)     

请注意,这仍然具有指数时间复杂度,因此对于稍大的n值,速度会慢得多。
print(sum(1 for _ in powersum(2200,2))) # 44930 solutions (12 seconds)
print(sum(1 for _ in powersum(2500,2))) # 91021 solutions (25 seconds)
print(sum(1 for _ in powersum(2800,2))) # 175625 solutions (55 seconds)
print(sum(1 for _ in powersum(3100,2))) # 325067 solutions (110 seconds)

[编辑] 供参考,这是Kelly Bundy的缓存版本,运行速度更快。在此发布以防他的演示链接失效:

from functools import cache
from time import time

@cache                  
def powersum(n,k,b=1):
    bk = b**k
    res = []
    while bk < n:
        for bases in powersum(n-bk,k,b+1):
            res.append([b]+bases)
        b += 1
        bk = b**k
    if bk == n :
        res.append([b])
    return res

注意:尽管缓存会消耗一些内存,但它远远不会达到所有可能组合的完整幂集的大小。


不错的回答。在6秒内完成2000.2是非常令人印象深刻的。 - Amit
你可以通过返回列表而不是迭代器并使用 functools.cache 进行装饰来加速它。 - Kelly Bundy
我曾尝试使用lru_cache和列表输出,但好像并没有产生实质性的差异(可能是我没有正确使用,或者参数模式重用不够)。无论如何,对于这个特定的答案,我的目标是尽可能地减少内存使用。 - Alain T.
你用了什么限制?使用@cache,我在0.4秒内得到2000,2和3秒内得到3100,2。演示 - Kelly Bundy
我正在使用Python 3.7,忘记在@lru_cache()中指定一个数字,因此我没有看到任何改进并得出了错误的结论。将它设置为一百万(而不是默认的128)确实提供了预期的性能提升。 - Alain T.
好的 :-). 或者使用 lru_cache(None)。一些统计数据:对于 2000,2,未缓存的函数被调用了 8,944,901 次。当缓存时,它只被调用了 31,840 次,并且 powersum.cache_info() 报告了 CacheInfo(hits=182585, misses=31840, maxsize=None, currsize=31840) - Kelly Bundy

1

@Reggie-florade和@Ohad-sharet,你们可以让代码运行得更快。请看下面的代码。我的解决方案与@Ohad-sharet非常相似,但运行速度明显更快。我的函数名是"kth_power_and_n"。

from itertools import combinations
import time
import math
import pandas as pd

def my_solutions_one_linner(n,k):
    subsets = []
    kth_root = int(n**(1/k))# kth_root of n is the upprt bound
    relevant_list = range(1,kth_root+1)
    ans =  [[list(comb) for comb in combinations(relevant_list, comb_size) if sum([x**k for x in comb]) == n] for comb_size in range(1,kth_root+1)]
    return [lst for lst in ans if lst!=[]]  

def kth_power_and_n(n, k):
    UpperLimit = int(math.floor(n**(1/k)))
    Matrix = list(map(lambda x: x**k,range(1,UpperLimit+1)))
    Final_Solution_List = []
    for i in range(1,len(Matrix)+1):
        Generated_Combinations = combinations(Matrix,i)
        for eachCombination in Generated_Combinations:
            if sum(eachCombination) == n:
                Final_Solution_List.append(sorted(list(map(lambda x: int(round(x**(1/k))), eachCombination))))
    return Final_Solution_List

Solution_Dict = {"n":[],"k":[],"time taken 1":[],"time taken 2":[],"Solution Found 1":[],"Solution Found 2":[]}

                
for n in [50,100,250,500]:
    for k in [2,3,4]:
        ## Previous Solution time taken
        Start = time.time()
        Solution = my_solutions_one_linner(n,k)
        Time_Taken = time.time()-Start
        Solution_Dict["n"].append(n)
        Solution_Dict["k"].append(k)
        Solution_Dict["time taken 1"].append(Time_Taken)        
        if Solution:
            Solution_Dict["Solution Found 1"].append(True)
        else:
            Solution_Dict["Solution Found 1"].append(False)

        ## Significantly faster code
        Start = time.time()
        Solution = kth_power_and_n(n,k)
        Time_Taken = time.time()-Start
        Solution_Dict["time taken 2"].append(Time_Taken)        
        if Solution:
            Solution_Dict["Solution Found 2"].append(True)
        else:
            Solution_Dict["Solution Found 2"].append(False)


DF_One_Liner = pd.DataFrame(data=Solution_Dict)
DF_One_Liner["Time Gain Ratio"] = DF_One_Liner["time taken 1"]/DF_One_Liner["time taken 2"]
print(DF_One_Liner)

以下是两种方法的时间比较。
      n  k  time taken 1  time taken 2  Solution Found 1  Solution Found 2  Time Gain Ratio
0    50  2      0.000170      0.000047              True              True         3.627551
1    50  3      0.000012      0.000007             False             False         1.758621
2    50  4      0.000007      0.000005             False             False         1.333333
3   100  2      0.001503      0.000185              True              True         8.142119
4   100  3      0.000019      0.000011              True              True         1.723404
5   100  4      0.000010      0.000006             False             False         1.833333
6   250  2      0.067290      0.006028              True              True        11.163476
7   250  3      0.000073      0.000018             False             False         4.135135
8   250  4      0.000010      0.000006             False             False         1.720000
9   500  2     11.933419      0.870651              True              True        13.706314
10  500  3      0.000185      0.000036             False             False         5.132450
11  500  4      0.000019      0.000009             False             False         2.135135

非常令人印象深刻,你认为你能解释一下为什么你的代码运行得更快吗? - Ohad Sharet

1

你的问题在于在subsets中保留了太多无关的组合,我在这里写了一些代码,只保存相关的列表,这样应该就可以解决问题了, (我将其分成多行以便更容易理解)

def my_solutions(n,k):
    subsets = []
    kth_root = int(n**(1/k))# kth_root of n is the upper bound
    relevant_list = range(1,kth_root+1)
    for comb_size in range(1,kth_root+1):
        for comb in combinations(relevant_list, comb_size):
            sum_comb = sum([x**k for x in comb])
            if (sum_comb == n):
                subsets.append(list(comb))
    return subsets

这里是相同的代码,但是用一行写成以最大化效率。

def my_solutions_one_linner(n,k):
    subsets = []
    kth_root = int(n**(1/k))# kth_root of n is the upprt bound
    relevant_list = range(1,kth_root+1)
    ans =  [[list(comb) for comb in combinations(relevant_list, comb_size) if sum([x**k for x in comb]) == n] for comb_size in range(1,kth_root+1)]
    return [lst for lst in ans if lst!=[]]

你认为最优解是使用map吗?还是你有其他的想法?编辑:我刚刚尝试了使用map,结果稍微差了一些,所以如果你知道更好的解决方法,我想知道 :) - Ohad Sharet
谢谢你的提示,那个很不错 :) - Ohad Sharet
1
我刚试了一下,但不知为何它也稍微慢了一些(虽然我本以为会更快),无论如何,感谢你的提示,我正在尝试编写更高效的Python代码,非常感谢任何帮助 :) - Ohad Sharet
我发布了一个更快的代码。 - Amit

1

我的解决方案基于这个答案,该答案是针对问题查找所有可能的数字组合以达到给定总和。链接问题中的目标是通过将列表numbers中的元素相加来达到值target。在我们的情况下,numbers将是从1target**(1/k)的整数,然后再将其提高到k次方。

这是另一个答案中的函数。 唯一的修改是我们知道numbers是按递增顺序排列的。 这意味着我们可以更早地停止迭代。 如果numbers[0]太大,那么接下来的幂也会很大。

def subset_sum(numbers: list[int], target: int, partial: list[int]=[], partial_sum: int=0):
    if partial_sum == target:
        yield partial
        return
    if numbers and (partial_sum + numbers[0]) > target:
        return
    for i, n in enumerate(numbers):
        remaining = numbers[i + 1:]
        yield from subset_sum(remaining, target, partial + [n], partial_sum + n)

现在我们来正确地调用这个函数。这里的重点是subset_sum需要幂,但我们的解决方案将使用未进行幂运算的数字。为了达到这个目的,让我们定义一个字典,在这个字典中,键是幂,而值是我们答案中的数字。然后,我们只需要给subset_sum一个包含我们字典键的列表就可以了。
def solutions(n, k):
    kth_root = next(i for i in itertools.count() if i**k>n) - 1
    powers_dict = {i**k: i for i in range(1, kth_root+1)}
    powers = list(powers_dict.keys())
    return [[powers_dict[i] for i in answer] for answer in subset_sum(powers, n)]

在我的电脑上,这段代码完成需要9秒钟 solutions(2000, 2)


0

好的,只是为了好玩,我尝试编写了一个递归函数来完成这个任务;以下是结果(power_sum(2000,2)在我的机器上大约10秒钟后得出答案,所以虽然它可能不是最优的,但仍然不错)。

from math import inf
def power_sums(total, p, max_allowed=inf):
    nmax = min(int(total**(1/p)), max_allowed)
    for i in range(nmax, 0, -1):
        if (ip := i**p) == total:
            yield [i]
        else:
            for x in power_sums(total - ip, p, i-1):
                yield [i] + x

哦,刚才我发现了Alain T.的答案...我应该学会阅读 :) - Swifty

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接