你能否仅修补一个带闭包的嵌套函数,还是必须重复整个外部函数?

68

我们使用的第三方库中包含一个相当长的函数,其中嵌套了一个函数。我们使用该库触发了该函数中的一个错误,我们非常希望解决这个问题。

不幸的是,库维护者修复问题的速度有些缓慢,但我们不想分叉该库。我们也不能等到他们修复问题再发布我们的产品。

我们更愿意使用 monkey-patching 来解决这个问题,因为这比修补源代码更容易跟踪。然而,如果只是替换内部函数就足够了,那么重复一个非常大的函数似乎有些过度,这使得其他人更难以看出我们究竟改变了什么。难道我们只能使用静态修补程序来修复这个库吗?

内部函数依赖于对变量的封闭;一个虚构的例子如下:

def outerfunction(*args):
    def innerfunction(val):
        return someformat.format(val)

    someformat = 'Foo: {}'
    for arg in args:
        yield innerfunction(arg)

我们希望只替换innerfunction()的实现,而不是外层函数。实际上,外层函数要长得多。当然,我们将重用闭合变量并保持函数签名。

4个回答

69
是的,即使内部函数使用了闭包,您也可以替换它。但是您需要跳过一些障碍。请注意以下几点:
1. 您需要将替换函数创建为嵌套函数,以确保Python创建相同的闭包。如果原始函数对名称foo和bar进行闭包,则需要定义具有相同名称的闭合的替换函数。更重要的是,您需要按照相同的顺序使用这些名称;闭包按索引引用。
2. 猴子补丁始终很脆弱,并且可能会因实现更改而中断。这也不例外。每当更改已修补程序库的版本时,请重新测试猴子补丁。
代码对象
为了理解这将如何工作,我首先将解释Python如何处理嵌套函数。Python使用代码对象根据需要生成函数对象。每个代码对象都有一个关联的常量序列,并且嵌套函数的代码对象存储在该序列中:
>>> def outerfunction(*args):
...     def innerfunction(val):
...         return someformat.format(val)
...     someformat = 'Foo: {}'
...     for arg in args:
...         yield innerfunction(arg)
... 
>>> outerfunction.__code__
<code object outerfunction at 0x105b27ab0, file "<stdin>", line 1>
>>> outerfunction.__code__.co_consts
(None, <code object innerfunction at 0x10f136ed0, file "<stdin>", line 2>, 'outerfunction.<locals>.innerfunction', 'Foo: {}')
co_consts序列是一个不可变对象,即一个元组,因此我们不能仅仅交换内部代码对象。稍后我将展示如何生成一个新的函数对象,只替换那个代码对象。

如何处理闭包

接下来,我们需要涉及闭包。在编译时,Python 确定了以下内容:
a) someformat 不是 innerfunction 中的局部名称,而
b) 它与 outerfunction 中的相同名称有关。
Python 不仅生成字节码以产生正确的名称查找,而且嵌套和外部函数的代码对象都被注释为记录要关闭 someformat
>>> outerfunction.__code__.co_cellvars
('someformat',)
>>> outerfunction.__code__.co_consts[1].co_freevars
('someformat',)

你需要确保替换的内部代码对象只列出相同名称的自由变量,并按照相同顺序进行列出。
闭包是在运行时创建的;生成它们的字节码是外部函数的一部分:
>>> import dis
>>> dis.dis(outerfunction)
  2           0 LOAD_CLOSURE             0 (someformat)
              2 BUILD_TUPLE              1
              4 LOAD_CONST               1 (<code object innerfunction at 0x10f136ed0, file "<stdin>", line 2>)
              6 LOAD_CONST               2 ('outerfunction.<locals>.innerfunction')
              8 MAKE_FUNCTION            8 (closure)
             10 STORE_FAST               1 (innerfunction)

# ... rest of disassembly omitted ...

LOAD_CLOSURE 字节码在这里为 someformat 变量创建了一个闭包;Python 会按照内部函数中首次使用的顺序创建所有使用的闭包。这是一个重要的事实,以后需要记住。函数本身通过位置查找这些闭包:

>>> dis.dis(outerfunction.__code__.co_consts[1])
  3           0 LOAD_DEREF               0 (someformat)
              2 LOAD_METHOD              0 (format)
              4 LOAD_FAST                0 (val)
              6 CALL_METHOD              1
              8 RETURN_VALUE
LOAD_DEREF 操作码在这里选择了位置为 0 的闭包,以便访问 someformat 闭包。
理论上,您可以在内部函数中使用完全不同的闭包名称,但为了调试方便,最好使用相同的名称。如果使用相同的名称,则验证替换函数是否正确插入变得更加容易,因为您可以比较 co_freevars 元组。

replace_inner_function()

现在是交换技巧时间了。函数与 Python 中的任何其他对象一样,是特定类型的实例。该类型通常不公开,但是 type() 调用仍然会返回它。代码对象也适用于此,并且两种类型都有文档说明。

>>> type(outerfunction)
<type 'function'>
>>> print(type(outerfunction).__doc__)
Create a function object.

  code
    a code object
  globals
    the globals dictionary
  name
    a string that overrides the name from the code object
  argdefs
    a tuple that specifies the default argument values
  closure
    a tuple that supplies the bindings for free variables
>>> type(outerfunction.__code__)
<type 'code'>
>>> print(type(outerfunction.__code__).__doc__)
code(argcount, posonlyargcount, kwonlyargcount, nlocals, stacksize,
      flags, codestring, constants, names, varnames, filename, name,
      firstlineno, lnotab[, freevars[, cellvars]])

Create a code object.  Not for the faint of heart.

(Python的确切参数计数和文档字符串因各个版本而异; Python 3.0添加了kwonlyargcount参数,截至Python 3.8,posonlyargcount已被添加。)
我们将使用这些类型对象来生成一个新的code对象并更新常量,然后使用更新后的code对象生成一个新的函数对象;以下功能与Python版本2.7到3.8兼容。
def replace_inner_function(outer, new_inner):
    """Replace a nested function code object used by outer with new_inner

    The replacement new_inner must use the same name and must at most use the
    same closures as the original.

    """
    if hasattr(new_inner, '__code__'):
        # support both functions and code objects
        new_inner = new_inner.__code__

    # find original code object so we can validate the closures match
    ocode = outer.__code__
    function, code = type(outer), type(ocode)
    iname = new_inner.co_name
    orig_inner = next(
        const for const in ocode.co_consts
        if isinstance(const, code) and const.co_name == iname)

    # you can ignore later closures, but since they are matched by position
    # the new sequence must match the start of the old.
    assert (orig_inner.co_freevars[:len(new_inner.co_freevars)] ==
            new_inner.co_freevars), 'New closures must match originals'

    # replace the code object for the inner function
    new_consts = tuple(
        new_inner if const is orig_inner else const
        for const in outer.__code__.co_consts)

    # create a new code object with the new constants
    try:
        # Python 3.8 added code.replace(), so much more convenient!
        ncode = ocode.replace(co_consts=new_consts)
    except AttributeError:
        # older Python versions, argument counts vary so we need to check
        # for specifics.
        args = [
            ocode.co_argcount, ocode.co_nlocals, ocode.co_stacksize,
            ocode.co_flags, ocode.co_code,
            new_consts,  # replacing the constants
            ocode.co_names, ocode.co_varnames, ocode.co_filename,
            ocode.co_name, ocode.co_firstlineno, ocode.co_lnotab,
            ocode.co_freevars, ocode.co_cellvars,
        ]
        if hasattr(ocode, 'co_kwonlyargcount'):
            # Python 3+, insert after co_argcount
            args.insert(1, ocode.co_kwonlyargcount)
        # Python 3.8 adds co_posonlyargcount, but also has code.replace(), used above
        ncode = code(*args)

    # and a new function object using the updated code object
    return function(
        ncode, outer.__globals__, outer.__name__,
        outer.__defaults__, outer.__closure__
    )

上述函数验证了新创建的内部函数(可以被传入作为代码对象或者函数)确实会使用和原始函数相同的闭包。它接着会创建新的代码和函数对象,以匹配旧的outer函数对象,但是将嵌套的函数(通过名称定位)替换为您的猴子补丁。
让我们试一下吧
为了证明上述内容可行,我们将innerfunction替换为一个使每个格式化值增加2的函数:
>>> def create_inner():
...     someformat = None  # the actual value doesn't matter
...     def innerfunction(val):
...         return someformat.format(val + 2)
...     return innerfunction
... 
>>> new_inner = create_inner()

新的内部函数也被创建为嵌套函数;这很重要,因为它确保Python将使用正确的字节码来查找someformat闭包。我使用了一个return语句来提取函数对象,但你也可以查看create_inner.__code__.co_consts来获取代码对象。
现在我们可以修补原始外部函数,仅交换内部函数。
>>> new_outer = replace_inner_function(outerfunction, new_inner)
>>> list(outerfunction(6, 7, 8))
['Foo: 6', 'Foo: 7', 'Foo: 8']
>>> list(new_outer(6, 7, 8))
['Foo: 8', 'Foo: 9', 'Foo: 10']

原始函数回显了原始值,而新返回的值增加了2。

您甚至可以创建使用更少闭包的新替换内部函数:

>>> def demo_outer():
...     closure1 = 'foo'
...     closure2 = 'bar'
...     def demo_inner():
...         print(closure1, closure2)
...     demo_inner()
...
>>> def create_demo_inner():
...     closure1 = None
...     def demo_inner():
...         print(closure1)
...
>>> replace_inner_function(demo_outer, create_demo_inner.__code__.co_consts[1])()
foo

简而言之

所以,为了完整地说明:

  1. 创建一个猴子补丁内部函数,作为一个嵌套函数,闭包的顺序与原来相同。
  2. 使用上述的replace_inner_function()函数来生成一个新的外部函数。
  3. 对原始的外部函数进行猴子补丁,使用步骤2中生成的新外部函数。

23

Martijn的回答很好,但有一个缺点是最好消除:

您要确保替换的内部代码对象仅列出与自由变量相同的名称,并以相同的顺序列出。

对于正常情况,这不是特别困难的约束条件,但依赖于未定义的行为(如名称排序)并且当事情出错时可能会产生非常严重的错误和可能甚至导致硬崩溃,这并不令人愉快。

我的方法也有其缺点,但在大多数情况下,我认为上述缺点会激励使用它。据我所知,它应该更可移植。

基本方法是使用inspect.getsource加载源代码,修改然后评估它。这是在AST级别完成的,以保持顺序。

以下是代码:

import ast
import inspect
import sys

class AstReplaceInner(ast.NodeTransformer):
    def __init__(self, replacement):
        self.replacement = replacement

    def visit_FunctionDef(self, node):
        if node.name == self.replacement.name:
            # Prevent the replacement AST from messing
            # with the outer AST's line numbers
            return ast.copy_location(self.replacement, node)

        self.generic_visit(node)
        return node

def ast_replace_inner(outer, inner, name=None):
    if name is None:
        name = inner.__name__

    outer_ast = ast.parse(inspect.getsource(outer))
    inner_ast = ast.parse(inspect.getsource(inner))

    # Fix the source lines for the outer AST
    outer_ast = ast.increment_lineno(outer_ast, inspect.getsourcelines(outer)[1] - 1)

    # outer_ast should be a module so it can be evaluated;
    # inner_ast should be a function so we strip the module node
    inner_ast = inner_ast.body[0]

    # Replace the function
    inner_ast.name = name
    modified_ast = AstReplaceInner(inner_ast).visit(outer_ast)

    # Evaluate the modified AST in the original module's scope
    compiled = compile(modified_ast, inspect.getsourcefile(outer), "exec")
    outer_globals = outer.__globals__ if sys.version_info >= (3,) else outer.func_globals
    exec_scope = {}

    exec(compiled, outer_globals, exec_scope)
    return exec_scope.popitem()[1]

一个快速的概述。 AstReplaceInner 是一个继承自 ast.NodeTransformer 的类,它允许你通过将某些节点映射到其他节点来修改 ASTs。在这种情况下,它接受一个 replacement 节点,并在名称匹配时用其替换 ast.FunctionDef 节点。

ast_replace_inner 是我们真正关心的函数,它接受两个函数和可选的名称参数。该名称参数用于允许使用不同名称的另一个函数替换内部函数。

ASTs 被解析:

    outer_ast = ast.parse(inspect.getsource(outer))
    inner_ast = ast.parse(inspect.getsource(inner))

转换已完成:

    modified_ast = AstReplaceInner(inner_ast).visit(outer_ast)

代码被评估并提取函数:

    exec(compiled, outer_globals, exec_scope)
    return exec_scope.popitem()[1]

这里有一个使用示例。假设这段旧代码在 buggy.py 文件中:

def outerfunction():
    numerator = 10.0

    def innerfunction(denominator):
        return denominator / numerator

    return innerfunction
你想用新的函数替换 innerfunction
def innerfunction(denominator):
    return numerator / denominator
您写道:
import buggy

def innerfunction(denominator):
    return numerator / denominator

buggy.outerfunction = ast_replace_inner(buggy.outerfunction, innerfunction)

或者,您可以写成:

def divide(denominator):
    return numerator / denominator

buggy.outerfunction = ast_replace_inner(buggy.outerfunction, divide, "innerfunction")
该技术的主要缺点是需要inspect.getsource能够在目标和替换代码上均起作用。如果目标是“内置”(用C编写)或在分发前编译为字节码,则无法实现此功能。请注意,如果目标是内置的,则Martijn的技术也无法工作。
另一个主要的缺点是内部函数的行号完全混乱。如果内部函数很小,则这不是一个大问题,但如果您有一个大的内部函数,则应该考虑这一点。
其他缺点来自于如果函数对象未以相同方式指定。例如,您无法修补以下内容:
def outerfunction():
    numerator = 10.0

    innerfunction = lambda denominator: denominator / numerator

    return innerfunction

需要相同的方式;对于不同的AST转换,需要进行调整。

您应该决定哪种权衡对于您特定的情况最有意义。


这会如何与回溯交互?由于源代码已重新生成,且AST未跟踪文件名,我怀疑修补版本中的源代码行将不正确,特别是当行数发生变化时。目前没有测试的机会,但肯定是要考虑的一个方面。 - Martijn Pieters
现在我可以测试这个功能了;内部和外部函数的源代码所有偏移都被重置为从1开始。这意味着所有的回溯将包括完全错误的源行。此外,内部函数的偏移量是针对原始文件使用的。使用调试器也是同样的情况。外部函数的偏移量可以被修复,但是内部函数的代码需要重新构建以首先指向正确的文件名。 - Martijn Pieters
3
@MartijnPieters 我已经修复了外部函数的行号,这很简单,但正如你所说,没有明显的方法来修复内部函数的行号。我将编辑答案以提到这个缺点。 - Veedrac

4

我需要这个功能,但是需要在一个类和Python2/3中使用。所以我对@MartijnPieters的解决方案进行了一些扩展。

import types, inspect, six

def replace_inner_function(outer, new_inner, class_class=None):
    """Replace a nested function code object used by outer with new_inner

    The replacement new_inner must use the same name and must at most use the
    same closures as the original.

    """
    if hasattr(new_inner, '__code__'):
        # support both functions and code objects
        new_inner = new_inner.__code__

    # find original code object so we can validate the closures match
    ocode = outer.__code__

    iname = new_inner.co_name
    orig_inner = next(
        const for const in ocode.co_consts
        if isinstance(const, types.CodeType) and const.co_name == iname)
    # you can ignore later closures, but since they are matched by position
    # the new sequence must match the start of the old.
    assert (orig_inner.co_freevars[:len(new_inner.co_freevars)] ==
            new_inner.co_freevars), 'New closures must match originals'
    # replace the code object for the inner function
    new_consts = tuple(
        new_inner if const is orig_inner else const
        for const in outer.__code__.co_consts)

    if six.PY3:
        new_code = types.CodeType(ocode.co_argcount, ocode.co_kwonlyargcount, ocode.co_nlocals, ocode.co_stacksize,
             ocode.co_flags, ocode.co_code, new_consts, ocode.co_names,
             ocode.co_varnames, ocode.co_filename, ocode.co_name,
             ocode.co_firstlineno, ocode.co_lnotab, ocode.co_freevars,
             ocode.co_cellvars)
    else:
    # create a new function object with the new constants
        new_code = types.CodeType(ocode.co_argcount, ocode.co_nlocals, ocode.co_stacksize,
             ocode.co_flags, ocode.co_code, new_consts, ocode.co_names,
             ocode.co_varnames, ocode.co_filename, ocode.co_name,
             ocode.co_firstlineno, ocode.co_lnotab, ocode.co_freevars,
             ocode.co_cellvars)

    new_function= types.FunctionType(new_code, outer.__globals__, 
                                     outer.__name__, outer.__defaults__,
                                     outer.__closure__)

    if hasattr(outer, '__self__'):
        if outer.__self__ is None:
            if six.PY3:
                return types.MethodType(new_function, outer.__self__, class_class)
            else:
                return types.MethodType(new_function, outer.__self__, outer.im_class)
        else:
            return types.MethodType(new_function, outer.__self__, outer.__self__.__class__)

    return new_function

现在,这应该适用于函数、绑定类方法和未绑定的类方法。 (class_class参数仅对python3中的未绑定方法需要) 感谢@MartijnPieters完成了大部分工作!我从来没想过会做到这一点 ;)


不应该修补限定方法。方法对象具有 __func__ 属性,在传递之前正确地解包方法即可,也无需担心类(否则可以从 type(method.__self__) 中获取)。您需要最终得到一个函数对象以添加到类中,仅修补该函数一次,而不是反复修补方法。 - Martijn Pieters
如果您还需要支持classmethod对象,请使用ClassObject.__dict__['classmethod_name']获取未绑定的classmethod对象,访问__func__属性,对其进行修补,然后在将结果重新包装为classmethod()之前放回类中。 - Martijn Pieters
此外,当__self__None时,您也不需要传入类型,因此在Python 3中使用MethodType(new_function, None, None)。实际上,我不确定在Py3中何时会出现outer.__self__ is None为真的情况(classmethod对象在绑定时将__self__设置为类,而函数对象不再像在Python 2中那样绑定到类)。 - Martijn Pieters

0

我喜欢其他答案的想法。这里是另一种重新编译源代码的方法。

这种方法的好处是相对容易编码,而且对于查看代码的人来说,很容易知道哪些代码行发生了变化。缺点是:它非常低级,并且如果源代码即使有微小的更改,比如添加一个空行,它就会变得非常脆弱。

所以,如果buggy.py是有问题的源代码:

def outerfunction(*args):
    def innerfunction(val):
        return someformat.format(val)

    someformat = 'Foo: {}'
    for arg in args:
        yield innerfunction(arg)

我们想要将内部函数实现替换为return someformat.format(val * 2) - 即将发送的值加倍。我们可以这样做:

from _pytest._code import Code
from buggy import outerfunction


def change_function_source_code(func):
    context = getattr(func, "__globals__", {})
    code = Code.from_function(func)
    source = code.source()

    # this part can be made more generic
    # this is where you put the code changes you want
    # for simplicity, I just replaced the third line
    new_body = source.lines[0:2] + \
               ["        return someformat.format(val * 2)"] + \
               source.lines[4:]

    compiled = compile("\n".join(new_body), str(code.path), "exec")
    exec(compiled, context)

    return context[func.__name__]


def test_buggy_function_before_after():
    # before the change
    assert list(outerfunction(6, 7, 8)) == ['Foo: 6', 'Foo: 7', 'Foo: 8']

    # do the source code change in our local module only
    new_outerfunction = change_function_source_code(outerfunction)

    # to change the behaviour of the old function for everyone:
    #   import buggy
    #   change_function_source_code(buggy.outerfunction)

    # after the change - new function changed behaviour, while the old did not
    assert list(new_outerfunction(6, 7, 8)) == ['Foo: 12', 'Foo: 14', 'Foo: 16']
    assert list(outerfunction(6, 7, 8)) == ['Foo: 6', 'Foo: 7', 'Foo: 8']

这要求源代码是可用的。 - Martijn Pieters

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接