在Python Asyncio中限制异步函数的速率

18
我有一个"awaitables"列表,想将其传递给asyncio.AbstractEventLoop,但需要对第三方API的请求进行限制。我希望避免等待将"future"传递给循环,因为在此期间我会阻塞我的循环。我有哪些选择?Semaphores 和 ThreadPools 只会限制并发运行的数量,但这不是我的问题。我需要将请求限制为每秒100个,但完成请求所需的时间并不重要。
这是使用标准库的非常简洁的(非)工作示例,展示了问题。这应该以100 /秒的速度限制,但实际上限制为116.651 /秒。什么是在asyncio中调度异步请求的最佳方法?
工作代码:
import asyncio
from threading import Lock

class PTBNL:
    """Request/response harness that paces outgoing requests with a token bucket.

    Each request gets an integer id and an ``asyncio.Future`` stored in
    ``self._futures``; the future is resolved by :meth:`end_req` once the
    request is reported as finished.  The bucket is configured for 100
    tokens per second in ``__init__``.
    """

    def __init__(self):
        self._req_id_seq = 0   # next request id to hand out
        self._futures = {}     # request id -> pending future
        self._results = {}     # request id -> accumulated result (optional)
        self.token_bucket = TokenBucket()
        self.token_bucket.set_rate(100)

    def run(self, *awaitables):
        """Drive the current event loop.

        With no arguments the loop runs forever; with one awaitable its result
        is returned; with several, the list of their results (in order) is
        returned.
        """
        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
            return None
        if len(awaitables) == 1:
            return loop.run_until_complete(awaitables[0])
        return loop.run_until_complete(asyncio.gather(*awaitables))

    def sleep(self, secs) -> bool:
        """Run the event loop for ``secs`` seconds and return ``True``."""
        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:
        """Return the next unused request id (ids start at 0)."""
        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):
        """Create, register and return the future tracking request ``key``."""
        future = asyncio.get_event_loop().create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):
        """Resolve and unregister the future for ``key``, if one is pending.

        When ``result`` is ``None`` the value stored in ``self._results`` for
        ``key`` is used, falling back to an empty list.
        """
        future = self._futures.pop(key, None)
        if future is None:
            return
        if result is None:
            result = self._results.pop(key, [])
        if not future.done():
            future.set_result(result)

    def req_data(self, req_id, obj):
        """Perform the request for ``obj``; this stub completes immediately."""
        # Do Some Work Here
        self.req_data_end(req_id)

    def req_data_end(self, req_id):
        """Report that request ``req_id`` finished and resolve its future."""
        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_async(self, obj):
        """Issue a single (unthrottled) request and return its result."""
        req_id = self.get_req_id()
        future = self.start_req(req_id)

        self.req_data(req_id, obj)

        return await future

    async def req_data_batch_async(self, contracts):
        """Issue one throttled request per item of ``contracts``.

        Each request is scheduled with ``call_later`` using the delay returned
        by the token bucket, so this coroutine never blocks the loop while
        pacing.

        Returns:
            ``(futures, rate)`` where ``futures`` is the list of (completed)
            request futures and ``rate`` is the observed requests per second.
            For an empty batch the result is ``([], 0.0)``.
        """
        loop = asyncio.get_event_loop()
        futures = []

        # Take the start time up front so it is always bound, even when
        # `contracts` is empty (the original raised UnboundLocalError then).
        start = loop.time()

        for contract in contracts:
            req_id = self.get_req_id()
            futures.append(self.start_req(req_id))

            nap = self.token_bucket.consume(1)
            loop.call_later(nap, self.req_data, req_id, contract)

        await asyncio.gather(*futures)
        elapsed = loop.time() - start

        # Guard against a zero interval (coarse clocks / empty batches).
        if elapsed > 0:
            rate = len(futures) / elapsed
        else:
            rate = float("inf") if futures else 0.0

        return futures, rate

class TokenBucket:
    """Token bucket whose ``consume`` reports how long the caller should wait.

    The bucket holds at most ``rate`` tokens and refills at ``rate`` tokens
    per second, timed with the current event loop's clock.  A rate of 0
    disables throttling.
    """

    def __init__(self):
        self.lock = Lock()
        self.last = asyncio.get_event_loop().time()  # time of last refill
        self.rate = 0
        self.tokens = 0

    def set_rate(self, rate):
        """Set the refill rate (tokens per second) and fill the bucket."""
        with self.lock:
            self.rate = rate
            self.tokens = rate

    def consume(self, tokens):
        """Take ``tokens`` from the bucket.

        Returns 0 when enough tokens were available (or throttling is
        disabled); otherwise the number of seconds until the resulting
        deficit has been refilled.  The balance is allowed to go negative.
        """
        with self.lock:
            if not self.rate:
                return 0

            now = asyncio.get_event_loop().time()
            elapsed, self.last = now - self.last, now

            # Refill for the elapsed time, cap at capacity, then withdraw.
            refilled = self.tokens + elapsed * self.rate
            self.tokens = min(refilled, self.rate) - tokens

            return 0 if self.tokens >= 0 else -self.tokens / self.rate


if __name__ == '__main__':

    # Create and install the loop explicitly: relying on
    # asyncio.get_event_loop() to create one implicitly is deprecated and
    # fails on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_debug(True)

    try:
        app = PTBNL()

        objs = list(range(500))

        # req_data_batch_async returns (futures, observed requests/second).
        futures, rate = app.run(app.req_data_batch_async(objs))

        print(futures)
        print(rate)
    finally:
        loop.close()

编辑:我在这里添加了一个使用信号量的ThrottleTestApp简单示例,但仍无法对执行进行限制:

import asyncio
import time


class ThrottleTestApp:
    """Semaphore-based throttle: at most ``rate`` requests start per second.

    A background coroutine (:meth:`allow_requests`) tops the semaphore up to
    ``rate`` permits once per second; every request consumes one permit
    before doing its work and never gives it back.

    Known limitation: permits are refilled at fixed one-second ticks, so a
    *rolling* one-second window that straddles a tick can contain more than
    ``rate`` request starts.
    """

    def __init__(self, rate: int = 100):
        self._req_id_seq = 0   # next request id to hand out
        self._futures = {}     # request id -> pending future
        self._results = {}     # request id -> accumulated result (optional)
        self.rate = rate       # permits made available per second
        self.sem = asyncio.Semaphore()

    async def allow_requests(self, sem):
        """Top ``sem`` up to ``self.rate`` permits once per second, forever.

        Runs until cancelled.  Start it with ``asyncio.ensure_future`` (or
        ``loop.create_task``) and cancel the returned task to stop it;
        :meth:`req_data_batch_async` does both automatically.
        """
        while True:
            # NOTE: asyncio.Semaphore has no public accessor for the current
            # value, hence the private ``_value``; a small wrapper exposing
            # it would be cleaner.
            while sem._value < self.rate:
                sem.release()
            await asyncio.sleep(1)  # Or spread more evenly with a shorter
                                    # sleep and a smaller top-up.

    async def do_request(self, req_id, obj):
        """Wait for a permit, then perform the request ``req_id``.

        If the request raises, the exception is forwarded to the request's
        future so that whoever awaits it sees the failure instead of hanging.
        """
        await self.sem.acquire()

        try:
            # this is the work for the request
            self.req_data(req_id, obj)
        except Exception as exc:
            future = self._futures.pop(req_id, None)
            if future is not None and not future.done():
                future.set_exception(exc)

    def run(self, *awaitables):
        """Drive the current event loop.

        With no arguments the loop runs forever; with one awaitable its result
        is returned; with several, the list of their results is returned.
        """
        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
            return None
        if len(awaitables) == 1:
            return loop.run_until_complete(awaitables[0])
        return loop.run_until_complete(asyncio.gather(*awaitables))

    def sleep(self, secs: float = 0.02) -> bool:
        """Run the event loop for ``secs`` seconds and return ``True``."""
        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:
        """Return the next unused request id (ids start at 0)."""
        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):
        """Create, register and return the future tracking request ``key``."""
        future = asyncio.get_event_loop().create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):
        """Resolve and unregister the future for ``key``, if one is pending.

        When ``result`` is ``None`` the value stored in ``self._results`` for
        ``key`` is used, falling back to an empty list.
        """
        future = self._futures.pop(key, None)
        if future is None:
            return
        if result is None:
            result = self._results.pop(key, [])
        if not future.done():
            future.set_result(result)

    def req_data(self, req_id, obj):
        """Perform the request for ``obj``; this stub completes immediately."""
        # This is the method that "does" something
        self.req_data_end(req_id)

    def req_data_end(self, req_id):
        """Report that request ``req_id`` finished and resolve its future."""
        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_batch_async(self, objs):
        """Issue one throttled request per item of ``objs``.

        Starts the permit-refill task, schedules every request as its own
        task (the original called ``do_request`` without awaiting or
        scheduling it, so nothing ever ran), waits for all request futures
        and finally stops the background tasks.

        Returns:
            The list of completed request futures, in submission order.
        """
        futures = []
        workers = []

        refill = asyncio.ensure_future(self.allow_requests(self.sem))
        start = time.monotonic()

        try:
            for obj in objs:
                req_id = self.get_req_id()
                futures.append(self.start_req(req_id))
                workers.append(
                    asyncio.ensure_future(self.do_request(req_id, obj))
                )

            await asyncio.gather(*futures)
        finally:
            # Stop the refill loop (and any worker still waiting for a
            # permit) and let the cancellations complete before returning.
            leftovers = [refill] + [w for w in workers if not w.done()]
            for task in leftovers:
                task.cancel()
            await asyncio.gather(*leftovers, return_exceptions=True)

        elapsed = time.monotonic() - start
        if elapsed > 0:
            per_second = len(futures) / elapsed
        else:
            per_second = float("inf") if futures else 0.0
        print("Roughly %s per second" % per_second)

        return futures


if __name__ == '__main__':

    # Create and install the loop explicitly: relying on
    # asyncio.get_event_loop() to create one implicitly is deprecated and
    # fails on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_debug(True)

    try:
        app = ThrottleTestApp()

        objs = list(range(10000))

        app.run(app.req_data_batch_async(objs))
    finally:
        loop.close()

你是想限制每秒正在进行的请求数量,还是限制在特定一秒内启动的请求数量呢?例如,如果你开始了100个需要3秒钟才能完成的请求,那么在接下来的2秒钟内,你是否可以再启动200个请求呢? - Aaron Schif
@AaronSchif 重要的不是它们何时启动,而是在任何1秒滚动窗口内,启动的数量不超过100个。 - Jared
3个回答

64
您可以通过实现“漏桶算法”来达到这个目的:

import asyncio
import contextlib
import collections
import time

from types import TracebackType
from typing import Dict, Optional, Type

try:  # Python 3.7+
    base = contextlib.AbstractAsyncContextManager
    _current_task = asyncio.current_task
except AttributeError:
    # Older Python: no async context manager ABC, and current_task lives
    # on the Task class.
    base = object  # type: ignore
    _current_task = asyncio.Task.current_task  # type: ignore


class AsyncLeakyBucket(base):
    """A leaky bucket rate limiter.

    Allows up to max_rate / time_period acquisitions before blocking.

    time_period is measured in seconds; the default is 60.

    The bucket "leaks" lazily: no background task is used, the level is
    lowered for the elapsed time whenever capacity is checked.

    """
    def __init__(
        self,
        max_rate: float,
        time_period: float = 60,
        loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        self._loop = loop
        self._max_level = max_rate
        self._rate_per_sec = max_rate / time_period
        self._level = 0.0
        self._last_check = 0.0
        # queue of waiting futures to signal capacity to
        self._waiters: Dict[asyncio.Task, asyncio.Future] = collections.OrderedDict()

    def _leak(self) -> None:
        """Drip out capacity from the bucket."""
        # A monotonic clock keeps the computation immune to wall-clock
        # adjustments; reading it once avoids losing the time that passes
        # between two separate readings.
        now = time.monotonic()
        if self._level:
            # drip out enough level for the elapsed time since
            # we last checked
            elapsed = now - self._last_check
            decrement = elapsed * self._rate_per_sec
            self._level = max(self._level - decrement, 0)
        self._last_check = now

    def has_capacity(self, amount: float = 1) -> bool:
        """Check if there is enough space remaining in the bucket.

        Returns True when ``amount`` more units fit without exceeding the
        maximum level.  As a side effect the bucket is leaked first and, if
        there is room, the first still-waiting task is woken up early.
        """
        self._leak()
        requested = self._level + amount
        # if there are tasks waiting for capacity, signal to the first
        # that there may be some now (they won't wake up until this task
        # yields with an await)
        if requested < self._max_level:
            for fut in self._waiters.values():
                if not fut.done():
                    fut.set_result(True)
                    break
        return requested <= self._max_level

    async def acquire(self, amount: float = 1) -> None:
        """Acquire space in the bucket.

        If the bucket is full, block until there is space.

        Raises:
            ValueError: if ``amount`` exceeds the bucket capacity and could
                therefore never be acquired.

        """
        if amount > self._max_level:
            raise ValueError("Can't acquire more than the bucket capacity")

        loop = self._loop or asyncio.get_event_loop()
        task = _current_task(loop)
        assert task is not None
        while not self.has_capacity(amount):
            # wait for the next drip to have left the bucket
            # add a future to the _waiters map to be notified
            # 'early' if capacity has come up
            fut = loop.create_future()
            self._waiters[task] = fut
            try:
                # NOTE: no ``loop=`` argument here; asyncio.wait_for() lost
                # that parameter in Python 3.10 and passing it raises
                # TypeError.  The shield keeps the timeout from cancelling
                # the waiter future itself.
                await asyncio.wait_for(
                    asyncio.shield(fut),
                    1 / self._rate_per_sec * amount,
                )
            except asyncio.TimeoutError:
                pass
            fut.cancel()
        self._waiters.pop(task, None)

        self._level += amount

        return None

    async def __aenter__(self) -> None:
        """Acquire one unit of capacity on entering ``async with``."""
        await self.acquire()
        return None

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType]
    ) -> None:
        """Nothing to release: capacity drains by leaking over time."""
        return None

请注意,我们是“伺机”(opportunistically)从桶中泄漏容量的:没有必要运行单独的异步任务来降低水位;相反,在检查剩余容量是否足够时,会顺带把这段时间内应泄漏的容量泄漏出去。
请注意,等待容量的任务被保存在有序字典中,当可能再次有剩余容量时,第一个仍在等待的任务会提前唤醒。
您可以将其用作上下文管理器;当存储桶已满时尝试获取它会阻塞,直到再次释放足够的容量为止:
bucket = AsyncLeakyBucket(100)

# ...

async with bucket:
    # only reached once the bucket is no longer full

或者您可以直接调用acquire()

await bucket.acquire()  # blocks until there is space in the bucket

或者您可以先测试是否还有空间:
if not bucket.has_capacity():
    # reject a request due to rate limiting
    ...

请注意,您可以通过增加或减少将请求“滴入”桶中的数量来将某些请求视为“更重”或“更轻”:
await bucket.acquire(10)
if bucket.has_capacity(0.5):

然而需要小心的是:当混合使用大滴和小滴时,在达到或接近最大速率的情况下,小滴往往会先于大滴被执行,因为桶中先腾出足以容纳小滴的空闲容量的可能性,比腾出足以容纳大滴的空间的可能性更大。

演示:

>>> import asyncio, time
>>> bucket = AsyncLeakyBucket(5, 10)
>>> async def task(id):
...     await asyncio.sleep(id * 0.01)
...     async with bucket:
...         print(f'{id:>2d}: Drip! {time.time() - ref:>5.2f}')
...
>>> ref = time.time()
>>> tasks = [task(i) for i in range(15)]
>>> result = asyncio.run(asyncio.wait(tasks))
 0: Drip!  0.00
 1: Drip!  0.02
 2: Drip!  0.02
 3: Drip!  0.03
 4: Drip!  0.04
 5: Drip!  2.05
 6: Drip!  4.06
 7: Drip!  6.06
 8: Drip!  8.06
 9: Drip! 10.07
10: Drip! 12.07
11: Drip! 14.08
12: Drip! 16.08
13: Drip! 18.08
14: Drip! 20.09

开始时桶会迅速填满,导致剩余的任务更加均匀地分散;每2秒钟就会释放足够的容量来处理另一个任务。

最大突发大小等于最大速率值,在上述演示中设置为5。如果不想允许突发,将最大速率设置为1,时间段设置为滴水之间的最小时间:

>>> bucket = AsyncLeakyBucket(1, 1.5)  # no bursts, drip every 1.5 seconds
>>> async def task():
...     async with bucket:
...         print(f'Drip! {time.time() - ref:>5.2f}')
...
>>> ref = time.time()
>>> tasks = [task() for _ in range(5)]
>>> result = asyncio.run(asyncio.wait(tasks))
Drip!  0.00
Drip!  1.50
Drip!  3.01
Drip!  4.51
Drip!  6.02

我已经将这个打包成Python项目:https://github.com/mjpieters/aiolimiter

(注:本段话为原文,无需翻译)

5

另一种解决方案——使用有界信号量——由我的同事、导师和朋友提出,如下所示:

import asyncio


class AsyncLeakyBucket(object):
    """Rate limiter built on a bounded semaphore plus a background task.

    The semaphore starts with ``max_tasks`` permits.  Entering the
    ``async with`` block takes one permit; leaving it does NOT give the
    permit back.  Instead a background task returns one permit every
    ``time_period / max_tasks`` seconds.
    """

    def __init__(self, max_tasks: float, time_period: float = 60,
                 loop: asyncio.AbstractEventLoop = None):
        self._delay_time = time_period / max_tasks
        self._sem = asyncio.BoundedSemaphore(max_tasks)
        self._loop = loop or asyncio.get_event_loop()
        # Keep a strong reference to the task: the event loop only holds
        # weak references, so an unreferenced task may be garbage-collected
        # while still running.  It also lets close() stop the task.
        self._leak_task = self._loop.create_task(self._leak_sem())

    async def _leak_sem(self):
        """
        Background task that leaks semaphore releases based on the desired rate of tasks per time_period
        """
        while True:
            await asyncio.sleep(self._delay_time)
            try:
                self._sem.release()
            except ValueError:
                # BoundedSemaphore refuses to exceed its initial value; that
                # just means the bucket is already full, so ignoring it is
                # deliberate.
                pass

    def close(self) -> None:
        """Stop the background leak task; the bucket no longer refills."""
        self._leak_task.cancel()

    async def __aenter__(self) -> None:
        """Take one permit, waiting for the next leak if none is available."""
        await self._sem.acquire()

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Nothing to release: permits are returned by the leak task only."""
        pass

仍然可以使用与@Martijn的回答中相同的async with bucket代码


1
请注意,这仍然允许突发,最多可同时处理 max_tasks 个任务。这是因为在低于设定速率时,信号量会让任务自由地获取它。如果各任务持有信号量的时间一致(所有任务持有信号量锁的时间相等),那么您可能会得到一连串的突发。 - Martijn Pieters

0
一个简单的解决方案来管理每秒最大请求数和最大同时连接到API的数量,这是我用于Interactive Brokers API的。
import asyncio
import datetime as dt
import random


async def send_request(num):
    """Simulate one API call: log the request, wait 0.1 or 0.2 s, log the response."""
    print(f"Request  {num:>2} at {dt.datetime.now()}")
    # Randomly chosen fake latency standing in for the remote round trip.
    latency = random.choice([0.1, 0.2])
    await asyncio.sleep(latency)
    print(f"Response {num:>2} at {dt.datetime.now()}")


def requests_per_second(request_datetimes):
    """Return the instantaneous request rate implied by the latest request.

    The rate is ``1 / (seconds since the last entry of request_datetimes)``.
    Only the most recent timestamp is used; this is not an average over a
    window.

    Returns:
        0 when no request has been made yet; ``float("inf")`` when no time
        has elapsed since the last request (or its timestamp lies in the
        future), so callers comparing against a limit keep throttling
        instead of hitting a ZeroDivisionError or a meaningless negative
        rate; otherwise the rate as a float.
    """
    if not request_datetimes:
        return 0

    elapsed = (dt.datetime.now() - request_datetimes[-1]).total_seconds()
    if elapsed <= 0:
        return float("inf")
    return 1 / elapsed


async def rate_limited_gather(*args, rate_limit=50, max_connections=10):
    """Manage max requests per second and max open connections for an API.

    Starts each awaitable in ``args`` as a task, but only once the current
    request rate is at most ``rate_limit`` and fewer than ``max_connections``
    previously started tasks are still running.

    Returns:
        The results of all awaitables, in the order given.  Exceptions are
        not raised but returned in place of the corresponding result
        (``return_exceptions=True``), so callers can inspect failures.  The
        original discarded this list, silently losing results and errors.
    """
    tasks = []
    request_datetimes = []

    def open_connections():
        # Tasks that have been started and are not finished yet.
        return sum(not task.done() for task in tasks)

    for arg in args:
        while (
            requests_per_second(request_datetimes) > rate_limit
            or open_connections() >= max_connections
        ):
            await asyncio.sleep(1 / rate_limit)
        print(f"Requests per second: {requests_per_second(request_datetimes)}")
        request_datetimes.append(dt.datetime.now())
        # ensure_future accepts coroutines as well as futures/tasks.
        tasks.append(asyncio.ensure_future(arg))

    return await asyncio.gather(*tasks, return_exceptions=True)


if __name__ == "__main__":
    # asyncio.run creates, runs and closes a fresh event loop; calling
    # asyncio.get_event_loop() with no running loop is deprecated.
    asyncio.run(rate_limited_gather(*[send_request(x) for x in range(10)]))

示例输出:

Requests per second: 0
Request   0 at 2023-03-11 10:34:49.348671
Requests per second: 49.696849219759464
Request   1 at 2023-03-11 10:34:49.368800
Requests per second: 49.69931911932807
Request   2 at 2023-03-11 10:34:49.388930
Requests per second: 49.69931911932807
Request   3 at 2023-03-11 10:34:49.409057
Requests per second: 49.72403162448411
Request   4 at 2023-03-11 10:34:49.429170
Response  0 at 2023-03-11 10:34:49.449260
Requests per second: 49.691910157026435
Request   5 at 2023-03-11 10:34:49.449298
Response  1 at 2023-03-11 10:34:49.469389
Requests per second: 49.68450340338848
Request   6 at 2023-03-11 10:34:49.469436
Response  2 at 2023-03-11 10:34:49.489529
Requests per second: 49.67956679417755
Request   7 at 2023-03-11 10:34:49.489566
Requests per second: 49.73392350922564
Request   8 at 2023-03-11 10:34:49.509682
Requests per second: 49.7116723006562
Request   9 at 2023-03-11 10:34:49.529858
Response  6 at 2023-03-11 10:34:49.569973
Response  7 at 2023-03-11 10:34:49.590072
Response  3 at 2023-03-11 10:34:49.609170
Response  4 at 2023-03-11 10:34:49.629267
Response  5 at 2023-03-11 10:34:49.650361
Response  8 at 2023-03-11 10:34:49.710456
Response  9 at 2023-03-11 10:34:49.730560

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接