为什么使用functools.lru_cache会破坏这个函数?

11
考虑下面的函数,它返回一个元素集合的所有唯一排列组合:

Consider the following function, which returns all the unique permutations of a set of elements:

def get_permutations(elements):
    if len(elements) == 0:
        yield ()
    else:
        unique_elements = set(elements)
        for first_element in unique_elements:
            remaining_elements = list(elements)
            remaining_elements.remove(first_element)
            for subpermutation in get_permutations(tuple(remaining_elements)):
                yield (first_element,) + subpermutation

for permutation in get_permutations((1, 1, 2)):
    print(permutation)

这打印出来

(1, 1, 2)
(1, 2, 1)
(2, 1, 1)

正如预期的那样。 但是,当我添加lru_cache修饰符来缓存该函数时:

import functools

@functools.lru_cache(maxsize=None)
def get_permutations(elements):
    if len(elements) == 0:
        yield ()
    else:
        unique_elements = set(elements)
        for first_element in unique_elements:
            remaining_elements = list(elements)
            remaining_elements.remove(first_element)
            for subpermutation in get_permutations(tuple(remaining_elements)):
                yield (first_element,) + subpermutation

for permutation in get_permutations((1, 1, 2)):
    print(permutation)
它会打印出以下内容:
(1, 1, 2)
为什么只打印第一个排列?
1个回答

24

lru.cache 函数会缓存你的函数的返回值。你的函数返回了一个生成器。生成器有状态并且可以被耗尽(也就是说,到达生成器的末尾并且没有更多的项目被 yield)。与未装饰版本的函数不同,每次使用给定的一组参数调用该函数时,LRU 缓存会为你提供完全相同的生成器对象。这样做是很好的,因为这就是它的作用!

但是,你要缓存的一些生成器可能被使用多次,并在第二次和后续使用时部分或完全被耗尽。(它们甚至可能同时出现多次)

为了解释你得到的结果,考虑当elements的长度为 0 并且您第一次 yield ()... 的时候会发生什么... 下一次调用此生成器时,它已经处于末尾,并且根本不会 yield 任何内容。因此,你的子排列循环什么也不做,并且从它那里不再产生任何东西。由于这是递归的“最基本情况”,因此对程序产生期望值具有关键作用,丢失它会破坏程序产生期望值的能力。

对于 (1,) 的生成器也被使用了两次,这在它达到 () 之前就破坏了第三个结果。

为了了解正在发生什么,将 print(elements) 添加为函数的第一行(并在主 for 循环中添加某种标记以区分它们)。 然后比较 memoized 版本与原始版本的输出。

看起来你可能希望有一种方法来缓存生成器的结果。在这种情况下,你需要将其编写成一个返回所有项目列表的函数(而不是逐个 yield 一个项目),然后对其进行 memoize。


4
易懂好记,通俗易懂。 - Denny Weinberg

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接