《asyncio 系列》11. asyncio 的并发原语(锁、信号量、事件、条件)

楔子

使用多线程和多进程编写应用程序时,需要考虑非原子操作时的竞态条件,因为即使是并发增加整数这样简单的操作也可能导致微妙的、难以重现的 bug。而 asyncio 是单线程的(除非与多线程和 multiprocessing 进行交互),这是否意味着我们就可以不必考虑竞争条件呢?事实证明,事情并非那么简单。

虽然 asyncio 的单线程特性消除了多线程或 multiprocessing 应用程序中可能出现的某些并发错误,但并未完全消除,所以在某些情况下我们仍然需要这些构造。asyncio 的同步原语(synchronization primitives)可以帮助我们防止单线程并发模型特有的错误。

了解单线程并发错误

在前面关于 multiprocessing 和多线程的文章中,当处理在不同进程和线程之间共享的数据时,我们不得不考虑竞态条件。这是因为数据被一个线程修改时,也可能被另一个线程读取,从而导致状态不一致,引起数据损坏。这种损坏部分是由于某些操作是非原子性的,这意味着虽然它们看起来像一个操作,但其实它们在后台包含多个独立的操作。就比如整型变量进行自增,首先读取当前值,然后将其递增,最后再重新分配给变量,而这为其他线程和进程提供了足够的机会来获取处于不一致状态的数据。

但在单线程并发模型中,避免了由非原子性操作引起的竞态条件。因为在 asyncio 的单线程模型中,只会有一个线程在给定时间执行 Python 代码,并且遇见非阻塞 IO 之前不会切换。这意味着,即使一个操作是非原子性的,也将始终运行直到完成,而不会让其他协程读取不一致的状态信息。

为证明这一点,我们创建一个计数器,并启动多个任务,同时修改计数器。

import asyncio

counter: int = 0

async def increment():
    global counter
    await asyncio.sleep(0.01)
    counter += 1

async def main():
    global counter
    for _ in range(100):
        tasks = [asyncio.create_task(increment()) for _ in range(100)]
        await asyncio.gather(*tasks)
        assert counter == 100
        # 将 counter 重置为 0,重新开始循环
        counter = 0

asyncio.run(main())

在代码中创建了一个协程函数,它负责将全局计数器加 1,然后添加 1 毫秒的延迟来模拟慢速操作。在主协程中,创建了 100 个任务来递增计数器,然后通过 gather 并发执行。之后断言计数器可得到期望的值,因为我们运行了 100 个增量任务,所以它应该总是 100。运行这段程序,没有任何报错,所以最终 counter 的值总是 100。这说明虽然递增一个整数是非原子性操作,但对于单线程的 asyncio 来说没有任何影响。而如果运行多个线程而不是协程,应该会看到断言在执行的某个时刻失败。

那么这是否意味着已经通过单线程并发模型找到了一种完全避免竞态条件的方法呢?不幸的是,情况并非如此。虽然避免了单个非原子性操作可能导致的错误竞态条件,但以错误顺序执行的多个操作是可能导致其他问题的。

import asyncio

counter: int = 0

async def increment():
    global counter
    temp_counter = counter
    temp_counter += 1
    await asyncio.sleep(0.01)
    counter = temp_counter

async def main():
    global counter
    for _ in range(100):
        tasks = [asyncio.create_task(increment()) for _ in range(100)]
        await asyncio.gather(*tasks)
        assert counter == 100
        # 将 counter 重置为 0,重新开始循环
        counter = 0

asyncio.run(main())

协程不是直接递增计数器,而是首先将其读入一个临时变量,然后将临时计数器加 1。通过 await asyncio.sleep 来模拟一个缓慢的操作,暂停协程,然后才将它重新分配回全局计数器变量。运行上述代码,你应该会立即看到执行失败,出现断言错误,并且计数器只会被设置为 1。

每个协程首先读取计数器的值 0,将其存储给临时变量,临时变量自增 1,然后进入休眠状态。因为 temp_counter 是临时变量,所以每个任务里面的 temp_counter 都是 1,然后一旦休眠完成,再将其赋值给 counter。这意味着尽管运行了 100 个协程来增加计数器,但计数器永远只会是 1。注意,如果你删除 await 表达式,那么代码就是正确的,因为内部没有出现非阻塞 IO,每个任务都会一口气执行完,然后再执行下一个任务。

诚然,上面是一个简单化且有些不切实际的例子。为了更好地了解何时会发生这种情况,让我们创建一个稍微复杂一点的竞态条件。

import asyncio

class MockSocket:

    def __init__(self):
        self.socket_closed = False

    async def send(self, msg: str):
        # 模拟向客户端缓慢发送消息
        if self.socket_closed:
            raise Exception("socket 已关闭")
        print(f"准备向客户端发送消息: {msg}")
        await asyncio.sleep(1)
        print(f"成功向客户端发送消息: {msg}")

    def close(self):
        self.socket_closed = True

usernames_to_sockets = {"satori": MockSocket(), "koishi": MockSocket(),
                         "marisa": MockSocket(), "scarlet": MockSocket()}

async def user_disconnect(username: str):
    # 断开用户连接,并将其从应用程序内存中删掉
    print(f"{username} 断开连接")
    socket = usernames_to_sockets.pop(username)
    socket.close()

async def message_all_users():
    # 同时向所有用户发送消息
    print(f"创建消息发送任务")
    messages = [socket.send(f"Hello {username}") for username, socket in usernames_to_sockets.items()]
    await asyncio.gather(*messages)

async def main():
    await asyncio.gather(message_all_users(), user_disconnect("marisa"))

asyncio.run(main())
"""
创建消息发送任务
marisa 断开连接
准备向客户端发送消息: Hello satori
准备向客户端发送消息: Hello koishi
准备向客户端发送消息: Hello scarlet
Traceback (most recent call last):
  ......
    raise Exception("socket 已关闭")
Exception: socket 已关闭
"""

我们实现了一个向连接的用户发送消息的套接字(MockSocket),每来一个用户就创建一个套接字,并用一个字典保存用户名到对应套接字的映射。当用户断开连接时,运行一个回调 user_disconnect,将用户从字典中删除并关闭套接字。

然后并发运行 message_all_users() 和 user_disconnect("marisa") 两个协程,可以理解为服务端创建了 4 个任务,准备向每个用户发消息。然后 "marisa" 用户断开连接,于是我们关闭给 "marisa" 用户发送消息的套接字,并将其从 usernames_to_sockets 字典中删除。完成后,message_all_users 恢复执行,并开始发送消息,但由于 "marisa" 的套接字已关闭,所以结果会看到一个异常,不会收到我们发送的消息。

这些是你在单线程并发模型中容易看到的错误类型,使用 await 到达一个挂起点,另一个协程运行并修改一些共享状态;当第一个协程通过意外的方式恢复时,就会发生修改冲突。多线程并发性 bug 和单线程并发性 bug 之间的关键区别在于,多线程应用程序中,在修改可变状态的任何地方都有可能出现竞态条件。而单线程并发模型中,只有遇到等待点(await point)才可能出现意料之外的结果。

注意:asyncio 属于单线程,严格意义上讲,出现的错误并不能称之为竞态条件。一个变量在被修改到一半时,又被另一个线程读取了,这导致第一个线程所做的修改可能会被覆盖掉,这叫做竞态条件。但 asyncio 中是不存在这种情况的,因为 asyncio 默认是单线程,如果没有遇见 await,那么它会将某个任务一次性全部执行完。

所以 asyncio 如果出现问题,那么一定是逻辑没有执行完,就通过 await 发生切换了。比如在上述代码的 message_all_users 中创建了 4 个套接字,准备给客户端发消息,但在发送之前发生了切换。而切换之后,在另一个任务中将套接字给关掉了,所以再切回来的时候,发送消息就会失败。

但这不属于竞态条件,这属于逻辑没按照顺序执行导致的问题。不过和竞态条件类似,都属于并发错误。

既然已经理解了单线程模型中的并发错误类型,那么让我们看看如何通过使用 asyncio 锁来避免它们的发生。

asyncio 锁的操作类似于 multiprocessing 和多线程模块中的锁,获取一个锁,在临界区内工作,完成后释放锁,让其他相关资源获取锁。主要区别在于,asyncio 锁是可等待对象,当它被阻塞时,会暂停协程的执行。这意味着如果一个协程在等待获取锁时被阻塞,其他代码可以运行。此外 asyncio 锁也是异步上下文管理器,使用它的首选方式是使用 async with 语法。

为了熟悉锁的工作原理,让我们看一个简单例子:一个锁在两个协程之间共享。

import asyncio
from asyncio import Lock

async def a(lock: Lock):
    print("协程 a 等待获取锁")
    async with lock:
        print("协程 a 成功获取了锁, 并进入临界区执行操作")
        await asyncio.sleep(2)
    print("协程 a 释放了锁")

async def b(lock: Lock):
    print("协程 b 等待获取锁")
    async with lock:
        print("协程 b 成功获取了锁, 并进入临界区执行操作")
        await asyncio.sleep(2)
    print("协程 b 释放了锁")

async def main():
    lock = Lock()
    await asyncio.gather(a(lock), b(lock))

asyncio.run(main())
"""
协程 a 等待获取锁
协程 a 成功获取了锁, 并进入临界区执行操作
协程 b 等待获取锁
协程 a 释放了锁
协程 b 成功获取了锁, 并进入临界区执行操作
协程 b 释放了锁
"""

比较简单,然后这里使用了 async with 语法。如果愿意,可以像下面的代码这样,调用锁的 acquire 和 release 方法:

await lock.acquire()
try:
    ...
finally:
    lock.release()  

也就是说,最好的做法是尽可能使用 async with 语法。

然后需要注意的一件重要事情是,锁是在主协程内部创建,然后作为参数传递的,这是正确的做法。但你也可能会将其设为全局变量,从而避免每次都对其进行传递:

import asyncio
from asyncio import Lock

async def a():
    print("协程 a 等待获取锁")
    async with lock:
        print("协程 a 成功获取了锁, 并进入临界区执行操作")
        await asyncio.sleep(2)
    print("协程 a 释放了锁")

async def b():
    print("协程 b 等待获取锁")
    async with lock:
        print("协程 b 成功获取了锁, 并进入临界区执行操作")
        await asyncio.sleep(2)
    print("协程 b 释放了锁")

lock = Lock()

async def main():
    await asyncio.gather(a(), b())

asyncio.run(main())

如果这样做,很快会看到崩溃的发生,并报告多个事件循环的错误:

RuntimeError: ..... attached to a different loop

为什么只移动了锁定义,就会发生这种情况呢?这是 asyncio 库的一个令人困惑的地方,而且这种现象也不是锁特有的,asyncio 中的大多数对象都提供一个可选的 loop 参数,允许你指定要运行的特定事件循环。当未提供此参数时,asyncio 尝试获取当前正在运行的事件循环,如果没有,则创建一个新的事件循环。在上例中,创建一个锁的同时会创建一个新的事件循环,因为当脚本第一次运行时,还没有事件循环。然后 asyncio.run(main()) 会创建第二个事件循环,试图使用锁时,这两个独立的事件循环就会混合在一起会导致崩溃。

这种行为非常棘手,以至于在 Python 3.10 中将移除 loop 参数,这种令人困惑的行为也会消失。但在 3.10 之前,在使用全局 asyncio 变量时需要认真考虑这些情况。

现在让我们修复之前的套接字问题,当我们试图向过早关闭套接字的用户发送消息时会出现异常。解决这个问题的思路是在两个地方使用锁:当用户断开连接时,以及当我们向用户发送消息时。这样,如果在发送消息时连接断开,我们将等到所有消息都完成后才最终关闭套接字。

import asyncio

class MockSocket:

    def __init__(self):
        self.socket_closed = False

    async def send(self, msg: str):
        # 模拟向客户端缓慢发送消息
        if self.socket_closed:
            raise Exception("socket 已关闭")
        print(f"准备向客户端发送消息: {msg}")
        await asyncio.sleep(1)
        print(f"成功向客户端发送消息: {msg}")

    def close(self):
        self.socket_closed = True

usernames_to_sockets = {"satori": MockSocket(), "koishi": MockSocket(),
                         "marisa": MockSocket(), "scarlet": MockSocket()}

async def user_disconnect(username: str, lock: asyncio.Lock):
    print(f"{username} 断开连接")
    async with lock:
        socket = usernames_to_sockets.pop(username)
        socket.close()
    print(f"{username} 已断开, 并从字典中移除")

async def message_all_users(lock: asyncio.Lock):
    # 同时向所有用户发送消息
    print(f"创建消息发送任务")
    async with lock:
        messages = [socket.send(f"Hello {username}")
                    for username, socket in usernames_to_sockets.items()]
        await asyncio.gather(*messages)

async def main():
    lock = asyncio.Lock()
    await asyncio.gather(message_all_users(lock), user_disconnect("marisa", lock))

asyncio.run(main())
"""
创建消息发送任务
marisa 断开连接
准备向客户端发送消息: Hello satori
准备向客户端发送消息: Hello koishi
准备向客户端发送消息: Hello marisa
准备向客户端发送消息: Hello scarlet
成功向客户端发送消息: Hello satori
成功向客户端发送消息: Hello koishi
成功向客户端发送消息: Hello marisa
成功向客户端发送消息: Hello scarlet
marisa 已断开, 并从字典中移除
"""

你可能不需要经常在 asyncio 代码中使用锁,因为它的单线程特性避免了许多并发问题。即使发生竞态条件,有时也可重构代码(如使用不可变对象),以防止在协程挂起时修改状态。当你不能以这种方式重构时,可以强制修改锁,使其按所需的同步顺序发生。既然已经理解了避免锁的并发性错误的概念,让我们看看如何在 asyncio 应用程序中使用同步来实现新功能。

使用信号量限制并发性

应用程序需要使用的资源通常是有限的,比如数据库并发的连接数可能有限,CPU 核数也是有限的。我们不想使其超负荷运行,或者根据 API 当前的订阅策略,我们使用的 API 只允许少量的并发请求。因此我们可能会考虑设定多大的负载来访问该 API,从而测试该 API 对分布式拒绝服务攻击的应变能力。

信号量是一种可在这些情况下帮助我们完成任务的结构,它的作用很像锁,但可以被多次获取。信号量内部有一个可以设定的初始值,每次获取(acquire)时,内部的值会减 1,释放(release)时则加 1。如果这个值为 0,那么再次获取信号量时会发生阻塞。所以如果和上面的锁做对比的话,可以把锁看作是内部初始值为 1 的信号量。

import asyncio
from asyncio import Semaphore

async def operation(sem: Semaphore):
    print("等待获取信号量")
    async with sem:
        print("信号量已获取")
        await asyncio.sleep(2)
    print("信号量已释放")

async def main():
    sem = Semaphore(2)
    await asyncio.gather(*[operation(sem) for _ in range(4)])

asyncio.run(main())
"""
等待获取信号量
信号量已获取
等待获取信号量
信号量已获取
等待获取信号量
等待获取信号量
信号量已释放
信号量已释放
信号量已获取
信号量已获取
信号量已释放
信号量已释放
"""

由于信号量在阻塞之前只允许被获取 2 次,所以前两个任务可以成功获取锁,而其他两个任务需要等待前两个任务释放信号量。一旦前两个任务中的工作完成,并释放了信号量,其他两个任务就可以获取信号量并开始运行。

让我们采用这种模式,并将其应用于现实世界的案例。假设你正在为一家充满斗志但资金拮据的初创公司工作,而你刚与第三方 REST API 供应商合作。他们的合同对于无限制地查询来说特别昂贵,但提供了一个只允许 10 个并发请求的收费计划,这将更加经济实惠。如果你同时发出超过 10 个请求,API 将返回状态码 429(请求过多)。如果收到状态码 429,你可发送一组请求并重试,但这样的效率很低,并会给供应商的服务器带来额外负载。他们的管理员可能会发现这种行为,并发出警告。更好的方法是创建一个限制为 10 的信号量,然后在你发出 API 请求时先获取信号量,这样就能确保在任何给定时间都最多只有 10 个正在运行的请求。

让我们看看如何使用 aiohttp 库来实现这一点,将向示例 API 发出 1000 个请求,但使用信号量将并发请求总数限制为 10 个。注意,aiohttp 也有我们可以调整的连接限制参数,默认情况下它一次只允许 100 个连接。通过调整此限制参数,可实现与以下代码相同的效果。

import asyncio
from asyncio import Semaphore
from aiohttp import ClientSession

async def get_status(url: str,
                     session: ClientSession,
                     sem: Semaphore):
    print(f"等待获取信号量")
    async with sem:
        print("信号量已获取, 正在发送请求")
        response = await session.get(url)
        print("请求已完成")
        return response.status

async def main():
    sem = Semaphore(10)
    async with ClientSession() as session:
        tasks = [get_status("http://www.baidu.com", session, sem) for _ in range(1000)]
        await asyncio.gather(*tasks)

asyncio.run(main())

每次请求完成时,信号量就会被释放,意味着一个被阻塞且正在等待信号量的任务可以开始了。这表示在给定时间内最多只能运行 10 个请求,当一个请求完成时,可以开始一个新请求。

这解决了并发运行的请求过多的问题,但上面的代码是突发进行的,这意味着它可能同时突发 10 个请求,从而造成潜在的流量峰值。如果担心正在调用的 API 出现负载峰值,上面的方法可能不是最佳选择。如果你只需要在某个时间单位内突发一定数量的请求,则需要将其与流量重塑(trafc-shaping,也称为流量整形)算法的实现一起使用,例如"漏桶(leaky bucket)算法" 或 "令牌桶(token bucket)算法"。

有界信号量

如果总是在 async with 块中使用信号量,由于每个 acquire 都会自动与一个 release 配对,那么调用 release 的次数和调用 acquire 的次数一定是相等的。然而如果需要进行更细粒度的控制,我们可能会手动调用 acquire 和 release,但如果调用 release 的次数多于 acquire 会怎么样?

import asyncio
from asyncio import Semaphore

async def delay(sem: Semaphore):
    async with sem:
        await asyncio.sleep(1)

async def main():
    sem = Semaphore(2)
    loop = asyncio.get_running_loop()
    start = loop.time()
    await asyncio.gather(*[delay(sem) for _ in range(4)])
    end = loop.time()
    print(f"耗时: {end - start}")

    # 调用两次 release,此时信号量内部的值会增加 2
    sem.release()
    sem.release()
    start = loop.time()
    await asyncio.gather(*[delay(sem) for _ in range(4)])
    end = loop.time()
    print(f"耗时: {end - start}")

asyncio.run(main())
"""
耗时: 2.0027941250000003
耗时: 1.001541209
"""

信号量内部有一个初始值,调用 acquire 时,该值减 1,调用 release 时,该值加 1。当值为 0 时,说明并发达到限制,获取信号量时会发生阻塞,直到某个协程调用 release 方法让值大于 0。但这个调用是不受限制的,只要调用 release,信号量内部的值就会增加,因此在某些时候会产生意想不到的结果。

对于动态修改并发量的需求来说,到时很适合。

为处理这种情况,asyncio 提供了一个 BoundedSemaphore,即有界信号量。它的行为与我们一直在使用的信号量完全一样,主要区别在于有界信号量调用 release 会抛出一个 ValueError: BoundedSemaphore release too many times 异常,从而改变可用的信号量限制数。

import asyncio
from asyncio import BoundedSemaphore

async def delay(sem: BoundedSemaphore):
    async with sem:
        await asyncio.sleep(1)

async def main():
    sem = BoundedSemaphore()
    # 像信号量、锁等,调用 acquire 时需要使用 await 表达式
    # 因为在 acquire 内部会创建一个 future 并添加到双端队列 _waiters 中,然后 await future
    # 当其它任务执行完之后退出时,会调用 _waiters 里的第一个 future 的 set_result
    # 从而唤醒某个任务
    await sem.acquire()
    # 而调用 release 则不需要 await,它就是一个普通的函数
    sem.release()
    sem.release()

asyncio.run(main())
"""
    raise ValueError('BoundedSemaphore released too many times')
ValueError: BoundedSemaphore released too many times
"""

运行代码时,对 release 的第二次调用将抛出一个 ValueError,告诉我们 release 方法调用太多次。那么有界信号量的内部是怎么做的呢?很简单,看一下源码就知道了。

以上就是有界信号量的全部源码,可以看到它继承自信号量,并重写了 release 方法。首先信号量内部有一个 self._value 属性,表示信号量的初始值(不指定默认为 1,此时等价于 Lock),调用 acquire 减 1,调用 release 加 1。而有界信号量里面还有一个 self._bound_value 也表示初始值,如果调用 release 时发现 self._value 大于等于 self._bound_value,那么就知道如果继续加 1 就会超过最初设定的值,于是报错。

非常简单,有界信号量就是在 release 方法中多做了一层检测罢了,而具体操作还是调用的信号量的方法。

现在我们已经了解了如何使用信号量来限制并发性,这对需要在应用程序中限制并发性的情况很有帮助。但 asyncio 同步原语不仅允许限制并发性,还允许在发生某些事情时通知任务,接下来让我们看看如何使用 Event 同步原语来实现这一点。

使用事件来通知任务

有时,我们可能需要等待一些外部事件发生才能继续运行程序,比如等待缓冲区填满,等待设备连接到应用程序,或者等待一些初始化完成。而事件对象提供了一种机制,可帮助我们在空闲时等待特定事件的发生。

在后台,Event 类跟踪一个标志,该标志指示事件是否已经发生。可通过两个方法(set 和 clear)来控制这个标志,set 方法将这个内部标志设置为 True,并通知所有等待事件发生的任务。clear 方法将这个内部标志设置为 False,等待该事件的任何任务都将被阻塞。

使用这两种方法,可管理内部状态,但是我们如何阻塞直到事件发生呢?Event 类有一个名为 wait 的协程方法,当等待这个协程时,将发生阻塞,直到有人调用事件对象上的 set 方法。一旦发生这种情况,任何额外的等待调用都不会阻塞并且立即返回。如果在调用 set 后调用 clear,则调用 wait 将再次开始阻塞,直到我们再次调用 set。

import asyncio
from asyncio import Event

def trigger_event(event: Event):
    event.set()

async def do_work_on_event(event: Event):
    print("等待事件发生")
    await event.wait()  # 如果标志位不为 True,那么此处会阻塞
    # 一旦事件发生,wait 将不再阻塞,我们可以继续运行程序
    print("event 标志位被设置为 True,开始执行逻辑")
    await asyncio.sleep(1)
    print("执行结束")
    # 重置事件,后续 await event.wait() 会再次阻塞
    event.clear()

async def main():
    # Event 实例化之后,标志位默认为 False
    event = asyncio.Event()
    # 5 秒后调用 trigger_event,在里面会执行 event.set()
    asyncio.get_running_loop().call_later(5, trigger_event, event)
    await asyncio.gather(do_work_on_event(event), do_work_on_event(event))

asyncio.run(main())
"""
等待事件发生
等待事件发生
event 标志位被设置为 True,开始执行逻辑
event 标志位被设置为 True,开始执行逻辑
执行结束
执行结束
"""

在代码中,创建了一个协程方法 do_work_on_event,这个协程接收一个事件,并首先调用它的 wait 方法。这将一直阻塞,直到有人调用事件的 set 方法来指示事件已经发生。还创建了一个简单的方法 trigger_event,用于设置给定的事件。在主协程中,创建了一个事件对象,并使用 call_later 在 5 秒后触发事件,然后创建两个并发任务。因此我们会看到两个 do_work_on_event 任务闲置 5 秒,直到事件触发。

这向我们展示了基础用法,等待一个事件将阻塞一个或多个协程,直到触发一个事件,之后它们可以继续运行。接下来,让我们看一个更真实的示例。假设正在构建一个 API 来接收来自客户端的上传文件,由于网络延迟和缓冲,文件上传可能需要些时间才能完成。有了这个约束,我们希望 API 有一个可以阻塞的协程,直到文件完全上传完成。然后,这个协程的调用者可以等待所有数据载入,并对数据做想做的任何操作。

import asyncio
from asyncio import StreamReader, StreamWriter

class FileUpload:

    def __init__(self, reader: StreamReader,
                 writer: StreamWriter):
        self._reader = reader
        self._writer = writer
        self._finished_event = asyncio.Event()
        self._buffer = b""
        self._upload_task = None

    def listen_for_upload(self):
        # 创建一个任务来监听上传,并将其附加到缓冲区。
        self._upload_task = asyncio.create_task(self._accept_upload())

    async def _accept_upload(self):
        while data := await self._reader.read(1024):
            self._buffer += data
        self._finished_event.set()
        self._writer.close()
        await self._writer.wait_closed()

    async def get_contents(self):
        # 阻塞,直到完成的事件被设置,然后返回缓冲区的内容
        await self._finished_event.wait()
        return self._buffer

当客户端连接时,将创建一个 FileUpload 对象,并调用 listen_for_uploads,会不断接收客户端上传的文件数据。然后通过 get_contents 方法获取结果即可,调用该方法时会阻塞在 wait 处,当客户端将文件上传完毕,将标志位设置为 True(事件发生)时,再解除阻塞,并返回文件内容。

class FileServer:

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port

    async def start_server(self):
        server = await asyncio.start_server(self._client_connect,
                                            self.host, self.port)
        await server.serve_forever()

    async def dump_contents_on_complete(self, upload: FileUpload):
        file_contents = await upload.get_contents()
        print(file_contents)

    def _client_connect(self, reader: StreamReader, writer: StreamWriter):
        upload = FileUpload(reader, writer)
        upload.listen_for_upload()
        asyncio.create_task(self.dump_contents_on_complete(upload))

async def main():
    server = FileServer("localhost", 9999)
    await server.start_server()

asyncio.run(main())

在代码中我们创建了一个 FileServer,每次客户端连接到服务器时,都会执行 self._client_connect 方法。在里面创建一个 FileUpload 类的实例,调用它的 listen_for_upload 方法监听已连接客户端的上传动作。同时为 self.dump_contents_on_complete 协程方法创建一个任务,在里面调用 upload.get_contents 协程获取客户端上传的文件内容(仅在上传完成后返回),并打印在控制台上。

我们测试一下,当前有一个 words.txt 文件,里面内容如下:

Hello World
Hello Cruel World
Hello Beautiful World

我们将该文件上传给服务端:

import socket

client = socket.socket()
client.connect(("localhost", 9999))

with open("words.txt", "rb") as f:
    client.sendfile(f)

执行之后,查看服务端的输出:

结果一切正常,并且只有在上传完毕后才会打印。但需要注意,事件的一个缺点是它们触发的频率可能比协程响应的频率要高。假设我们在一种 producer、consumer 工作流中使用单个事件来唤醒多个任务,如果任务需要运行很长时间,那么可能永远看不到事件的变化。

如果我们只是想在特定事件发生时发出警报,那么可以使用事件来完成这样的工作。但如果我们需要将等待事件与对共享资源(如数据库连接)的独占访问结合起来,该怎么办呢?条件可以帮助我们解决这些问题。

条件

事件对于简单的通知很有用,但如果对于更复杂的用例呢?想象一下,需要访问一个共享资源,需要对某个事件进行锁定,或者需要等待一组更复杂的事实,然后才能继续或只唤醒特定数量的任务(而不是所有任务)。那么"条件"在这些情况下将提供很大帮助,另外条件是到目前为止我们遇到的最复杂的同步原语,因此,你可能不需要经常使用它。

"条件"将锁和事件的各个方面结合到一个同步原语中,有效地包装了两者的行为。我们首先获取条件锁,让协程独占访问共享资源,从而能够安全地更改需要的任何状态。然后使用 wait 或 wait_for 协程等待特定事件发生,那么这些协程释放锁并阻塞,直到事件发生。一旦事件发生,它就会重新获得锁,从而提供独占访问权限。

import asyncio
from asyncio import Condition

async def do_work(cond: Condition):
    print("等待条件锁")
    async with cond:
        print("已获得锁,然后立即释放")
        # 等待事件触发,一旦成功,重新获取条件锁
        await cond.wait()
        print("再次获得条件锁,继续执行逻辑")
        await asyncio.sleep(1)
    # 退出 async with 语句块后,释放条件锁。
    print("工作结束,释放锁")

async def fire_event(cond: Condition):
    await asyncio.sleep(5)
    print("准备获取条件锁")
    async with cond:
        # 通知所有任务:事件已经发生
        cond.notify_all()
    print("通知完毕,释放锁")

async def main():
    cond = Condition()
    asyncio.create_task(fire_event(cond))
    await asyncio.gather(do_work(cond), do_work(cond))

asyncio.run(main())
"""
等待条件锁
已获得锁,然后立即释放
等待条件锁
已获得锁,然后立即释放
准备获取条件锁
通知完毕,释放锁
再次获得条件锁,继续执行逻辑
工作结束,释放锁
再次获得条件锁,继续执行逻辑
工作结束,释放锁
"""

解释一下,当调用 await cond.acquire() 时,会获取锁,其它任务再次调用时会阻塞,调用 cond.release() 的时候会释放锁,其它获取锁的任务解除阻塞。这两个方法可以组合起来,用 async with cond 实现,比较简单,和 Lock 没有什么区别。

然后 cond 有一个 wait 协程方法,它必须在获取锁之后才能调用,调用之后会立即将锁释放掉,并陷入阻塞。所以上面代码打印的前 4 行输出是:

"""
等待条件锁
已获得锁,然后立即释放
等待条件锁
已获得锁,然后立即释放
"""

调用 await cond.wait() 的任务什么时候才能解除阻塞呢?答案是通过 cond.notify_all() 唤醒,调用 notify_all() 之前仍然需要先获取锁。所以 5 秒后,fire_event 写成里面会打印两行输出:

"""
准备获取条件锁
通知完毕,释放锁
"""

await cond.wait() 同样会在内部创建一个 future 并添加到双端队列里面,然后 await future,当调用 cond.notify_all() 的时候,会将里面的所有 future 都调用 set_result 方法,让阻塞的任务解除阻塞。

但需要注意:cond.wait() 解除阻塞之后,会再次获取锁,也就是外层的 finally 逻辑。之所以这么做的原因有两个:

  • 在调用 cond.wait() 的时候已经将锁释放了,而 async with cond: 结束之后会再释放一次,这就造成二次释放了。所以当内部的 future 解除阻塞之后,返回之前要再获取锁。当然这个原因很牵强,真正的原因是第二个;
  • 如果所有阻塞的任务同时解除阻塞,那么先执行哪一个呢?显然这有可能造成并发错误,所以要再度获取锁。因此虽然所有任务都被唤醒了,但只是从 cond.wait() 里面的 await fut 这一步唤醒了,之后它们会去 await self.acquire(),但只能有一个任务能拿到锁,没有拿到的任务会继续阻塞,对外表现就是仍然阻塞在 await cond.wait() 这里;

至于 notify_all 的源码也很简单,我们说过,这些同步原语实现等待过程就是在内部创建一个 future,然后 await future。而唤醒的过程的就是往 future 里面塞一个值(比如 None、False),由于可能有很多任务在同时等待,所以 future 会有很多个,这些 future 在创建之后会被扔到一个双端队列里面。所谓唤醒,就是从队列里面获取 future,调用它的 set_result。

唤醒任务,会从双端队列中的第一个 future 对应的任务开始唤醒

以上我们就聊了聊一些源码相关的细节,比较简单,这里这补充一个协程方法 wait_for。该方法接收一个返回布尔值的无参函数,会一直阻塞到函数返回 True 为止,因此当有一个共享资源与一些依赖于某些状态的协程变为 True 时,使用它将是一个很好的选择。

import asyncio
from enum import Enum

class ConnectionState(Enum):
    WAIT_INIT = 0  # 等待初始化
    INITIALIZING = 1  # 正在初始化
    INITIALIZED = 2  # 初始化已完成

class Connection:

    def __init__(self):
        self._state = ConnectionState.WAIT_INIT
        self._condition = asyncio.Condition()

    async def initialize(self):
        print("连接正在开始初始化")
        await self._change_state(ConnectionState.INITIALIZING)
        await asyncio.sleep(3)
        print("连接初始化已完成")
        await self._change_state(ConnectionState.INITIALIZED)

    async def _change_state(self, state: ConnectionState):
        async with self._condition:
            print(f"连接状态由 {self._state} 变为 {state}")
            self._state = state
            self._condition.notify_all()

    def _is_initialized(self):
        if self._state is not ConnectionState.INITIALIZED:
            print("无法获取连接: 因为它尚未初始化完成")
            return False
        print("可以获取连接: 因为已初始化完成")
        return True
    
    async def execute(self, query: str):
        async with self._condition:
            print(f"等待连接初始化完成后执行查询")
            # 当 self._is_initialized 方法返回 True 的时候解除阻塞
            await self._condition.wait_for(self._is_initialized)
            print(f"执行查询: {query}")
            await asyncio.sleep(3)
    
async def main():
    connection = Connection()
    query_one = asyncio.create_task(connection.execute("SELECT * FROM t1"))
    query_two = asyncio.create_task(connection.execute("SELECT * FROM t2"))
    asyncio.create_task(connection.initialize())
    await query_one
    await query_two

asyncio.run(main())
"""
等待连接初始化完成后执行查询
无法获取连接: 因为它尚未初始化完成
等待连接初始化完成后执行查询
无法获取连接: 因为它尚未初始化完成
连接正在开始初始化
连接状态由 ConnectionState.WAIT_INIT 变为 ConnectionState.INITIALIZING
无法获取连接: 因为它尚未初始化完成
无法获取连接: 因为它尚未初始化完成
连接初始化已完成
连接状态由 ConnectionState.INITIALIZING 变为 ConnectionState.INITIALIZED
可以获取连接: 因为已初始化完成
执行查询: SELECT * FROM t1
可以获取连接: 因为已初始化完成
执行查询: SELECT * FROM t2
"""

应该不难理解,区别就是 cond.wait() 和 cond.wait_for(func),当调用 cond.notify_all() 的时候,wait() 会立即解除阻塞,而 wait_for(func) 则是会调用 func,当返回 True 时才解除阻塞。

小结

在本篇文章中,我们学习了以下内容:

  • 了解了单线程并发错误,以及它们与多线程和 multiprocessing 中的并发错误有何不同。
  • 使用 asyncio 锁来防止并发错误,并实现同步协程。由于 asyncio 的单线程特性,这种情况较少发生,在等待期间共享状态会发生变化时,可能需要它们。
  • 使用信号量来控制对有限资源的访问,并限制并发性,这在流量整形中很有帮助。
  • 在某些事情发生时使用事件来触发动作,例如初始化或唤醒 worker 任务。
  • 使用"条件"来等待操作,从而获得对共享资源的访问。
posted @ 2023-05-12 15:55  古明地盆  阅读(833)  评论(0编辑  收藏  举报