sync.WaitGroup:协同等待,编排一组 goroutine

楔子

WaitGroup 是用来做任务编排的一个并发原语,它解决的是 "并发 - 等待" 的问题。比如我们要完成一个大的任务,需要使用并行的 goroutine 执行三个小任务,只有这三个小任务都完成,我们才能去执行后面的任务。

这个时候使用使用 WaitGroup 就非常合适了,而且名字也很形象:等待组,它可以对一组 goroutine 进行编排,保证这一组 goroutine 都执行完毕之后程序再往下执行。那么 WaitGroup 是怎么做到的呢?以及用法如何呢?我们下面来看一看。

WaitGroup 的基本用法

首先 WaitGroup 内部有一个计数器,然后围绕着计数器提供了三个方法:

func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()

我们分别看下这三个方法:

  • Add:给 WaitGroup 内部计数器的值增加 delta
  • Done:给 WaitGroup 内部计数器的值减一
  • Wait:如果计数器的值不为 0,那么此方法会阻塞,如果为 0,程序往下执行

想必你此时已经猜到 WaitGroup 该怎么用了,进入协程时 Add(1),执行结束时 Done() 一下即可。

package main

import (
	"sync"
	"time"
)

type Counter struct {
	sync.Mutex
	count int64
}

func (c *Counter) Incr() {
	c.Lock()
	defer c.Unlock()
	c.count++
}

func (c *Counter) Count() int64 {
	c.Lock()
	defer c.Unlock()
	return c.count
}

func main() {

	var wg sync.WaitGroup
	var count Counter

	for i := 0; i < 10; i++ {
		// 开启 10 个协程,但是在开始的位置 Add(1) 会使得计数器的值加 1
		// 结束的位置 Done() 一下会使得计数器的值减 1
		go func() {
			wg.Add(1)
			count.Incr()
			time.Sleep(time.Second)
			wg.Done()
		}()
	}
	// 如果计数器的值不为 0,那么会一直阻塞
	// 因此我们一定要保证协程执行完毕之后计数器的值为 0,否则程序就会卡死在这里了
	wg.Wait()
	// 如果我们能够知道开启的协程的数量,那么也可以不把 Add(1) 写在协程中
	// 比如我们上面开启了 10 个协程,那么可以在进入 for 循环之前写上 wg.Add(10),此时计数器的值为 10
	// 然后在协程内部只调用 wg.Done() 即可,当所有协程都执行完之后,计数器的值会减去 10,最终变成 0
}

因此使用起来还是比较简单的,核心是一定要确保任务执行完之后计数器的值为 0,否则程序就会卡死。很常见的错误就是 Add 之后忘记 Done,尤其是逻辑比较长的时候,写到最后很容易把 wg.Done() 给忘记了,因此建议通过 defer 来保证。

以上就是我们使用 WaitGroup 编排这类任务的常用方式,而 "这类任务" 指的就是需要启动多个 goroutine 执行子任务,并且主 goroutine 需要等待指定的子 goroutine 都完成后才继续执行。

了解完基本用法之后,下面我们来看看底层实现,因为单从语法的使用层面上讲,确实没有什么难度,毕竟 Go 本身用起来就很简单。所以我们还需要剖析底层实现,当然这就不一定简单了。

WaitGroup 的实现

首先来看一下 WaitGroup 的数据结构,结构体只包含了两个成员:

  • noCopy:辅助字段,一个空结构体(struct {})的别名,主要就是辅助 vet 工具检测该 WaitGroup 实例是否发生了值拷贝,和 Mutex、RWMutex 一样,WaitGroup 在传递时也必须传递指针
  • state1:一个具有复合意义的字段,包含 WaitGroup 的计数器的值、调用 Wait 方法阻塞时的 waiter 数和信号量
type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

然后我们重点说一下这个 state1 成员,它包含了 WaitGroup 的计数器的值、调用 Wait 方法阻塞时的 waiter 数和信号量。但是有一点需要注意,对于不同的处理器该字段的值也会有区别。

  • 如果是 64 位机器,那么 state1[0] 表示调用 Wait 方法阻塞时的 waiter 数,可以理解为有几个 goroutine 调用 wg.Wait() 阻塞了,那么 waiter 数就是几;state1[1] 表示 WaitGroup 计数器的值;state1[2] 表示信号量
  • 如果是 32 位机器,那么 state1[0] 表示信号量;state1[1] 表示调用 Wait 方法阻塞时的 waiter 数;state1[2] 表示 WaitGroup 计数器的值

WaitGroup 有一个方法 state,专门用来获取上面的信息,我们来看一下。

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    // 如果地址(转成十进制整数)模上 8 等于 0,那么说明地址是 8 字节对齐,也就是 64 位机器
    // 否则是 32 位机器,因为 Go 只能运行在 64 和 32 位机器上
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        /* 但是两个 return 可能会有点难理解,首先它们都返回了两个元素,并且第二个值是信号量(指针)
         * 但第一个值是什么?我们注意到 state1 数组的元素是 uint32 类型,而这里转成了 *uint64
         * 很明显,"调用 Wait 方法阻塞的 waiter 数" 和 "WaitGroup 计数器的值" 这两个 uint32 整数
         * 被组合成了一个 uint64 整数,然后返回其指针
         */
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

关于上面的 return,我们实际演示一下,这里只以 64 位机器为例。首先我们来假设一下,如果 state1 数组为 [1, 2, 3],那么你认为 state 方法会返回什么?不用想,肯定是 return 两个值,第二个值是 3,但第一个值是什么呢。

uint32 整数 1 转成二进制:00000000 00000000 00000000 00000001
uint32 整数 2 转成二进制:00000000 00000000 00000000 00000010
拼接成 uint64:00000000 00000000 00000000 00000010 00000000 00000000 00000000 00000001

注意:state1[1] 占高位、state1[0] 占低位,因为数组从左往右地址是增大的,而 Go 里面高位存储在高地址中。

package main

import "fmt"

func main() {
    fmt.Println(0b00000000_00000000_00000000_00000010_00000000_00000000_00000000_00000001)
    /* 8589934593 */
}

我们看到结果是 8589934593,那么实际情况是不是这样呢?

package main

import (
    "fmt"
    "unsafe"
)

func main() {
    var state1 = [3]uint32{1, 2, 3}
    fmt.Println(*(*uint64)(unsafe.Pointer(&state1)))
    /* 8589934593  */
}

我们看到结果是一样的,至于根据结果进行逆运算也是非常容易的。

  • state1[1] 放在高位,那么计算的时候直接右移 32 个位即可
  • state1[0] 放在低位,那么计算的时候只需要和后 16 个位为 1(其它位为 0)的整数进行与运算即可
package main

import (
    "fmt"
)

func main() {
    var num = 8589934593
    fmt.Printf("state[1]:%d,state[0]:%d", num>>32, num&0xFFFF)
    /* state[1]:2,state[0]:1 */
}

以上我们就解释了 state 方法里面的 return 到底是怎么回事,当然有点跑题了,不过也还好。

回归正题,我们来看一下 Add、Done、Wait 方法的实现,删除掉了一些用于 race 检查和异常检查的部分,这里我们重点关注 Add、Done、Wait 这些方法本身的实现。

Add 方法

Add 方法,我们知道主要负责操作计数器的值,也就是 state1[1]。我们可以给计数器增加一个 delta,内部会通过原子操作将这个值加上去。此外需要注意的是,delta 也可以是一个负数,Done 方法内部就是通过 Add(-1) 实现的。

func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    // 给计数器增加 delta,所以 delta 左移 32 位加上去即可
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 获取计数器(state[1])的值,我们说它位于高位,所以右移 32 位即可
    v := int32(state >> 32)
    // 获取调用 Wait 方法阻塞的 waiter 数(state[0]),我们之前是和 0xFFFF 进行与运算
    // 但是显然 Go 内部的做法更简单,直接转成 uint32 即可,超过 32 位的部分会截断
    w := uint32(state)
    // 计数器的值一定大于等于 0
    // 如果计数器大于 0,或者没有阻塞的 goroutine,那么直接返回
    if v > 0 || w == 0 {
        return
    }
    // 如果计数器的值 v 是 0 并且 waiter 的数量 w 不是 0,那么 state 的值就是 waiter 的数量
    // 此时应该将 waiter 的数量设置为 0,因为计数器的值为 0 了,所以应该唤醒所有的 waiter
    // 这个 waiter 就是调用 wg.Wait() 阻塞的 goroutine,由于 v 是 0,w 也要设置为 0,所以将 *statep 直接设置为 0 即可
    *statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

整个逻辑还是很简单的,就是将计数器的值增加 delta。然后判断计数器的值 v 是否等于 0,如果 v 是 0、并且 w 不是 0,那么唤醒所有的 waiter 即可。

Done 方法

Done 的逻辑很简单,我们说它内部是通过 Add 实现的,可以看一下代码。

// Done 方法实际就是计数器减 1
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

Wait 方法

Wait 方法的逻辑是:不断检查 state 的值,如果其中的计数器的值变为了 0,那么说明所有的任务都已完成,调用者不必再等待,直接返回。如果计数值大于 0,说明此时还有任务没完成,那么调用者就变成了等待者,需要加入 waiter 队列,并且阻塞住自己。

func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()

    for {
        state := atomic.LoadUint64(statep)
        // 当前计数器的值
        v := int32(state >> 32)
        // waiter 的数量,就是调用 wg.Done() 而阻塞的 goroutine 的数量
        w := uint32(state)
        if v == 0 {
            // 如果计数值为 0,调用这个方法的 goroutine 不必再等待,直接返回即可
            return
        }
        // 否则把 waiter 数量加 1
        // 期间可能有并发调用 Wait 的情况,所以最外层使用了一个 for 循环
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            // 阻塞休眠等待,而唤醒则通过 wg.Add,当计数器更新之后值变为 0,那么会唤醒 waiter
            runtime_Semacquire(semap)
            return
        }
    }
}

以上就是这三个方法的底层实现,其实也挺简单的。

使用 WaitGroup 时的常见错误

使用 WaitGroup 也是会有很多坑的,下面来看一下。

计数器的值为负数

WaitGroup 的计数器的值必须大于等于 0,我们在更改这个计数值的时候,WaitGroup 会检查计数器的值,如果小于 0,则引发 panic。

一般情况下,有两种情况会导致计数器设置为负数。

第一种情况:调用 Add 的时候传递一个负数,当然,如果你能保证当前的计数器的值加上这个负数后还是大于等于 0 的话,也没有问题,否则就会导致 panic;

func main() {
    var wg sync.WaitGroup
    wg.Add(10)
    wg.Add(-10) // 将 -10 作为参数调用 Add,计数器的值被设置为0
    wg.Add(-1)  // 将 -1 作为参数调用 Add,此时会变成 -1,会引发 panic
}

第二种情况:调用 Done 方法的次数过多,超过了 WaitGroup 的计数值;

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    wg.Done()
    wg.Done()  // panic
}

使用 WaitGroup 的正确姿势是,预先确定好 WaitGroup 的计数值,然后调用相同次数的 Done 完成相应的任务。比如在 WaitGroup 变量声明之后,就立即设置它的计数,但这要求你必须事先知道组里面到底有多个 goroutine。或者在 goroutine 启动之前增加 1,然后在 goroutine 退出之后调用 Done。

如果你没有遵循这些规则,就很可能会导致 Done 方法调用的次数和计数值不一致,进而造成死锁(Done 调用次数比计数值少)或者 panic(Done 调用次数比计数值多)。

Add 之前先 Wait

在使用 WaitGroup 的时候,要遵循一个原则:等所有的 Add 方法调用之后再调用 Wait,否则就可能导致 panic 或者不期望的结果。

package main

import (
    "fmt"
    "sync"
    "time"
)

func doSomething(wg *sync.WaitGroup) {
    time.Sleep(1000)  // 故意 sleep 一下
    wg.Add(1)
    fmt.Println("do something")
    wg.Done()
}
func main() {
    var wg sync.WaitGroup
    go doSomething(&wg)
    go doSomething(&wg)
    go doSomething(&wg)
    wg.Wait()
}

程序执行之后会发现什么也没有打印,原因就在于子协程执行 wg.Add(1) 的时候,主协程就已经执行了 wg.Wait()。而计数器初始是为 0 的,所以主协程不会阻塞,然后程序退出了。因此我们要确保子协程中的 Add 一定要在 Wait 之前执行。

Add 和 Wait 同时调用

首先 WaitGroup 是可重用的,只要 WaitGroup 内部计数器的值恢复到零值的状态,那么它就可以被看作是新创建的 WaitGroup。之后我们可以继续把该 WaitGroup 当成是新创建的来用,然后继续调用 wg.Add,但是这一步一定要在上一轮 wg.Wait 结束之后进行。光说可能有点难理解,我们举个栗子:

package main

import (
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    go func() {
        time.Sleep(time.Millisecond)
        wg.Add(1)
        wg.Done() // 计数器减 1
        wg.Add(1) // 计数值加 1
    }()
    wg.Wait() // 主 goroutine 等待,有可能和第二个 wg.Add(1) 并发执行
}

当调用完 wg.Done() 之后,阻塞在 wg.Wait() 处的协程会被唤醒,然后这个 wg 就可以看成是新创建的 wg,因为此时内部的成员的值为零值,和新创建一个 wg 没有区别。但重点是:执行 wg.Wait() 的同时(唤醒阻塞 goroutine),子协程内部也会执行 wg.Add(1),如果这两者并发执行,那么就会 panic。

不过个人测试了几次均没有出现 panic,因为 wg.Wait() 总是在子协程的第二个 wg.Add(1) 执行之前先执行完,但如果真的不幸、这两者同时执行了,那么就会造成 panic。所以如果想要重用 WaitGroup,那么一定要等到上一轮的 wg.Wait() 执行完毕之后,再执行 wg.Add。

Docker 源码里面就犯过两次这种错误.

小结

WaitGroup 的使用场景还是很明确的,就是编排一组 goroutine。尽管使用 WaitGroup 也容易踩坑,但只要记住以下五点,便可以避免。

  • 不重用 WaitGroup,新建一个 WaitGroup 不会带来多大的资源开销,重用反而更容易出错
  • 保证所有的 Add 方法调用都在 Wait 之前
  • 不传递负数给 Add 方法,只通过 Done 来给计数值减 1
  • 不做多余的 Done 方法调用,保证 Add 的计数值和 Done 方法调用的数量是一样的
  • 不遗漏 Done 方法的调用,否则会导致 Wait 阻塞而无法返回
posted @ 2020-03-01 22:25  古明地盆  阅读(5554)  评论(0编辑  收藏  举报