Python装饰器方法带有可变数量的位置参数和可选参数

6

我正在使用SQLalchemy编写我的第一个Python(3.4)应用程序。我有几个方法都具有非常相似的模式。它们接受一个可选参数session,默认为None。如果传递了session,则函数将使用该会话,否则它将打开并使用新会话。例如,请考虑以下方法:

def _stocks(self, session=None):
    """Return a list of all stocks in database."""
    newsession = False
    if not session:
        newsession = True
        session = self.db.Session()
    stocks = [stock.ticker for stock in session.query(Stock).all()]
    if newsession:
        session.close()
    return stocks

因此,作为一个新手想要学习Python的全部功能,我认为现在是学习一些有关Python装饰器的完美时机。经过大量阅读,如这个博客系列这个出色的SO答案后,我编写了以下装饰器:

from functools import wraps

def session_manager(func):
    """
    Manage creation of session for given function.

    If a session is passed to the decorated function, it is simply
    passed through, otherwise a new session is created.  Finally after
    execution of decorated function, the new session (if created) is
    closed/
    """
    @wraps(func)
    def inner(that, session=None, *args, **kwargs):
        newsession = False
        if not session:
            newsession = True
            session = that.db.Session()
        func(that, session, *args, **kwargs)
        if newsession:
            session.close()
        return func(that, session, *args, **kwargs)
    return inner

看起来它很有效。原始方法现在被简化为:

@session_manager
def _stocks(self, session=None):
    """Return a list of all stocks in database."""
    return [stock.ticker for stock in session.query(Stock).all()]

然而,当我将装饰器应用于除可选的session参数之外还带有一些位置参数的函数时,就会出现错误。因此,尝试编写以下代码:

@session_manager
def stock_exists(self, ticker, session=None):
    """
    Check for existence of stock in database.

    Args:
        ticker (str): Ticker symbol for a given company's stock.
        session (obj, optional): Database session to use.  If not
            provided, opens, uses and closes a new session.

    Returns:
        bool: True if stock is in database, False otherwise.
    """
    return bool(session.query(Stock)
                .filter_by(ticker=ticker)
                .count()
                )

运行 print(client.manager.stock_exists('AAPL')) 报错 AttributeError,错误信息如下:

Traceback (most recent call last):
  File "C:\Code\development\Pynance\pynance.py", line 33, in <module>
    print(client.manager.stock_exists('GPX'))
  File "C:\Code\development\Pynance\pynance\decorators.py", line 24, in inner
    func(that, session, *args, **kwargs)
  File "C:\Code\development\Pynance\pynance\database\database.py", line 186, in stock_exists
    .count()
AttributeError: 'NoneType' object has no attribute 'query'
[Finished in 0.7s]

根据回溯信息,我猜想我搞错了参数的顺序,但我无法弄清楚如何正确排序。 我有一些要装饰的函数,除了session之外还可以接受0-3个参数。 请问有人能指出我方法中的错误吗?


session作为一个命名参数传递-- func(stuff, session=session)。另外,你为什么要调用两次func?最后,这似乎应该有一个上下文管理器来管理db.session - jwilner
感谢@jwilner!我两次调用“func”只是对语法的误解。我将“func”调用更改为“result = func()”,然后返回结果。是的,关于“db.session”的上下文管理器也是正确的。我试图删除一堆代码以更好地隔离我的问题。 - Christopher Pearson
2个回答

3

变更

def inner(that, session=None, *args, **kwargs):

为了

def inner(that, *args, session=None, **kwargs):

并且

return func(that, session, *args, **kwargs)

to

return func(that, *args, session=session, **kwargs)

它可以工作:

def session_manager(func):

    def inner(that, *args, session=None, **kwargs):
        if not session:
            session = object()
        return func(that, *args, session=session, **kwargs)

    return inner


class A():

    @session_manager
    def _stocks(self, session=None):
        print(session)
        return True

    @session_manager
    def stock_exists(self, ticker, session=None):
        print(ticker, session)
        return True

a = A()
a._stocks()
a.stock_exists('ticker')

输出:

$ python3 test.py
<object object at 0x7f4197810070>
ticker <object object at 0x7f4197810070>

当您使用def inner(that, session=None, *args, **kwargs)时,任何第二个位置参数(计算self)都被视为session参数。因此,当您调用manager.stock_exists('AAPL')时,session会获得值AAPL


谢谢,这真的很有帮助。我感觉只是弄混了参数的顺序。 - Christopher Pearson

1

我注意到的第一件事是你调用了装饰函数两次。

@wraps(func)
    def inner(that, session=None, *args, **kwargs):
        newsession = False
        if not session:
            newsession = True
            session = that.db.Session()
        #calling first time
        func(that, session, *args, **kwargs)
        if newsession:
            session.close()
        #calling second time
        return func(that, session, *args, **kwargs)
    return inner

在第二个调用会话已经关闭。此外,您不需要在装饰器函数中显式接受thatsession参数,它们已经在argskwargs中了。看一下这个解决方案:
@wraps(func)
def inner(*args, **kwargs):
    session = None
    if not 'session' in kwargs:
        session = that.db.Session()
        kwargs['session'] = session
    result = func(*args, **kwargs)
    if session:
        session.close()
    return result
return inner

您可能还希望将会话关闭代码放在finally块中,这样即使装饰函数抛出异常,也可以确保关闭它。


如果在inner()func()中没有显式传递that,那么会出现NameError: name 'that' is not defined的错误。 - Christopher Pearson

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接