在线程中,我们有一些称为“线程上下文”的内容,在其中可以保存一些数据(状态),以便在特定线程中访问。 在asyncio中,我需要在当前执行路径中保存一些状态,以便所有后续的协程都可以访问它。 解决方案是什么?注意:我知道每个协程函数在asyncio中都会实例化为一个执行路径,但由于某种原因我无法将状态保存在函数属性中。(虽然这种方法也不是很好)
从Python 3.7开始,您可以使用contextvars.ContextVar。
在下面的示例中,我声明了request_id并在some_outer_coroutine中设置了值,然后在some_inner_coroutine中访问它。
import asyncio
import contextvars
# declare context var
request_id = contextvars.ContextVar('Id of request.')
async def some_inner_coroutine():
# get value
print('Processed inner coroutine of request: {}'.format(request_id.get()))
async def some_outer_coroutine(req_id):
# set value
request_id.set(req_id)
await some_inner_coroutine()
# get value
print('Processed outer coroutine of request: {}'.format(request_id.get()))
async def main():
tasks = []
for req_id in range(1, 5):
tasks.append(asyncio.create_task(some_outer_coroutine(req_id)))
await asyncio.gather(*tasks)
if __name__ == '__main__':
asyncio.run(main())
输出:
Processed inner coroutine of request: 1
Processed outer coroutine of request: 1
Processed inner coroutine of request: 2
Processed outer coroutine of request: 2
Processed inner coroutine of request: 3
Processed outer coroutine of request: 3
Processed inner coroutine of request: 4
Processed outer coroutine of request: 4
还有https://github.com/azazel75/metapensiero.asyncio.tasklocal,但您必须注意,任务通常由库内部创建,并且也由asyncio使用ensure_future(a_coroutine)
创建,没有实际的方法来跟踪这些新任务并初始化它们的本地变量(也许是从创建它们的任务的本地变量继承)。 (“hack”的方法是设置一个loop.set_task_factory()
函数,其中包含执行此操作的内容,希望所有代码都使用loop.create_task()
来创建任务,但这并不总是正确的...)
另一个问题是,如果您的某些代码在Future回调中执行,则Task.current_task()
函数将始终返回None
,该函数被用于选择正确的本地变量副本以提供服务。
我个人认为contextvars API太底层了。Google已经在https://github.com/google/etils中开发了一些小包装器,以获得更好的API:
tl;dr; 使用edc.ContextVar[T]
注释数据类字段,使字段成为上下文相关。它支持所有dataclasses.field
功能(例如default_factory
),因此每个线程/asyncio任务都有自己的版本:
from etils import edc
@edc.dataclass
@dataclasses.dataclass
class Context:
thread_id: edc.ContextVar[int] = dataclasses.field(default_factory=threading.get_native_id)
# Local stack: each thread will use its own instance of the stack
stack: edc.ContextVar[list[str]] = dataclasses.field(default_factory=list)
# Global context object
context = Context(thread_id=0)
使用示例:
def worker():
# Inside each thread, the worker use its own context
assert context.thread_id != 0
context.stack.append(1)
time.sleep(1)
assert len(context.stack) == 1 # Other workers do not modify the local stack
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
for _ in range(10):
executor.submit(worker)
这适用于asyncio和线程。
请参见文档。
curio
似乎已经拥有了,可以在https://github.com/dabeaz/curio/pull/85找到。 - Dima Tisnek