我们如何将一个Python上下文管理器与出现在其块中的变量"关联"起来?

11
据我所知,上下文管理器用于在Python中为对象定义初始化和终止代码片段(__enter____exit__)。然而,在PyMC3的教程中,他们展示了以下上下文管理器示例:
basic_model = pm.Model()

with basic_model:

    # Priors for unknown model parameters
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=10, shape=2)
    sigma = pm.HalfNormal('sigma', sd=1)

    # Expected value of outcome
    mu = alpha + beta[0]*X1 + beta[1]*X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

并提到这样做的目的是将变量alphabetasigmamuY_obs与基本模型basic_model关联起来。
我想了解这种机制是如何工作的。在我找到的上下文管理器的解释,我没有看到任何建议在上下文块中定义的变量或对象如何与上下文管理器“关联”。似乎该库(PyMC3)以某种方式可以访问“当前”上下文管理器,以便在幕后将每个新创建的语句与其关联起来。但是该库如何访问上下文管理器呢?

2
这可以通过实现 __enter__ 将信息推送到线程本地堆栈来完成。 - donkopotamus
@donkopotamus 是的 - dhke
3个回答

10

PyMC3通过在Context类中维护一个线程本地变量作为类变量来实现此功能。 ModelContext继承。

每次调用模型上的with时,当前模型都会被推送到线程特定的上下文堆栈上。因此,堆栈顶部始终引用最内层(最近)的用作上下文管理器的模型。

Context(因此Model)具有一个类方法{{link2:.get_context()}}来获取上下文堆栈的顶部。

Distribution在创建时调用Model.get_context()来将自己与最内层的模型关联起来。

简而言之:

  1. with modelmodel推入上下文堆栈。这意味着在with块内,type(model).contextsModel.contextsContext.contexts现在将model作为其最后(顶部)元素。
  2. Distribution.__init__()调用Model.get_context()(注意大写的M),返回上下文堆栈的顶部。在我们的例子中,这是model。上下文堆栈是线程本地的(每个线程一个),但它不是实例特定的。如果只有一个线程,则也只有一个上下文堆栈,而不管模型数量。
  3. 退出上下文管理器时,model从上下文堆栈中弹出。

非常感谢。但是有一个明显的循环,我不太明白:您从上下文堆栈中获取模型/上下文,该堆栈是从模型/上下文(get_context)获取的,而您又从上下文堆栈中获取了它,这是从模型/上下文获取的...那么分布式如何首先访问模型/上下文或上下文堆栈呢? - user118967
1
我可能需要强调一下get_context()是一个类方法,而上下文堆栈是一个线程本地的类变量。get_context()不是在模型实例上调用的,而是在Model类上调用的。 - dhke

4

我不知道在这种特定情况下它是如何工作的,但通常你会使用一些“幕后魔法”:

class Parent:
    def __init__(self):
        self.active_child = None

    def ContextManager(self):
        return Child(self)

    def Attribute(self):
        return self.active_child.Attribute()

class Child:
    def __init__(self,parent):
        self.parent = parent

    def __enter__(self):
        self.parent.active_child = self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.parent.active_child = None

    def Attribute(self):
        print("Called Attribute of child")

使用以下代码:
p = Parent()
with p.ContextManager():
    attr = p.Attribute()

将产生以下输出:
Called Attribute of child

2
谢谢,但我还是没明白。我的想法是要像这样做: with contextmanager : foo() 然后让foo()以某种方式访问contextmanager。在你的例子中,似乎是通过保留p并将其用作桥梁来实现这一点,这似乎不能达到相同的目标。 - user118967
1
@user118967 但这是发生在你的代码片段中。pm被用作桥梁。(你调用pm.Model()pm.Normal(...),然后将ModelNormal的返回值连接起来。) - MegaIng

1

在进入和退出上下文管理器块时,还可以检查堆栈中的locals()变量,并确定哪些变量已更改。

class VariablePostProcessor(object):
    """Context manager that applies a function to all newly defined variables in the context manager.

    with VariablePostProcessor(print):
        a = 1
        b = 3

    It uses the (name, id(obj)) of the variable & object to detect if a variable has been added.
    If a name is already binded before the block to an object, it will detect the assignment to this name
    in the context manager block only if the id of the object has changed.

    a = 1
    b = 2
    with VariablePostProcessor(print):
        a = 1
        b = 3
    # will only detect 'b' has newly defined variable/object. 'a' will not be detected as it points to the
    # same object 1
    """

    @staticmethod
    def variables():
        # get the locals 2 stack above
        # (0 is this function, 1 is the __init__/__exit__ level, 2 is the context manager level)
        return {(k, id(v)): v for k, v in inspect.stack()[2].frame.f_locals.items()}

    def __init__(self, post_process):
        self.post_process = post_process
        # save the current stack
        self.dct = self.variables()

    def __enter__(self):
        return

    def __exit__(self, type, value, traceback):
        # compare variables defined at __exist__ with variables defined at __enter__
        dct_exit, dct_enter = self.variables(), self.dct
        for (name, id_) in set(dct_exit).difference(dct_enter):
            self.post_process(name, dct_exit[(name, id_)])

典型用途可以是:

# let us define a Variable object that has a 'name' attribute that can be defined at initialisation time or later
class Variable:
    def __init__(self, name=None):
        self.name = name

# the following code
x = Variable('x')
y = Variable('y')
print(x.name, y.name)

# can be replaced by
with VariablePostProcessor(lambda name, obj: setattr(obj, "name", name)):
    x = Variable()
    y = Variable()
print(x.name, y.name)

# in such case, you can also define as a convenience
import functools
AutoRenamer = functools.partial(VariablePostProcessor, post_process=lambda name, obj: setattr(obj, "name", name))

# and rewrite the above code as
with AutoRenamer():
    x = Variable()
    y = Variable()
print(x.name, y.name)  # => x y

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接