Golang WaitGroup实现原理

基本用法

sync.WaitGroup 是一个结构体类型,该类型对外只提供了三个方法, , Add(delta int)用来增加任务数量,Done()用来完成单个任务,Wait()用来等待所有任务完成,一共Add多少,Wait就需要等待多少次Done,在Done全部完成之前Wait()会阻塞当前线程。

1
2
3
4
5
6
7
8
9
10
11
12
func main () {
wg := sync.WaitGroup {}
wg.Add (10)
for i := 0; i < 10; i++ {
go func (JobSeq int) {
defer wg.Done ()
// do something
fmt.Printf ("the job: %d done.\n", JobSeq)
}(i)
}
wg.Wait ()
}

输出结果

1
2
3
4
5
6
7
8
9
10
the job: 9 done.
the job: 0 done.
the job: 2 done.
the job: 8 done.
the job: 4 done.
the job: 5 done.
the job: 7 done.
the job: 1 done.
the job: 3 done.
the job: 6 done.

如果不使用wg.Wait主协程可能会在子协程完成之前退出,造成子协程提前结束的情况

实际应用中,如果一个任务需要多次查询数据库,每个查询之前没有数据依赖,我们就可以开多个协程并发查询,每个协程查询完成以后执行一次Done,主协程可以通过Wait等待所有数据加载完毕

实现原理

1
2
3
4
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}

WaitGroup结构体里只有两个字段noCopy, state1

noCopy 是一个特殊的结构体,在编译阶段,go vet 工具会检查 noCopy 字段,避免对象被复制
state1 字段是一个长度是3的uint32数组,该字段设计非常巧妙,同时包含了三个含义 worker计数器,waiter计数器,和信号量

state1读取逻辑如下

1
2
3
4
5
6
7
8
// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}

先看函数返回结果,statep是一个uint64的指针,因为state1[0]state1[1]是两个地址连续的uint32,所以一个64位的指针包含了两个数组元素

函数内容代码逻辑有点难理解,这里根据state1字段的对齐方式来选择计数器的存储位置,如果是64位对齐的话,wg.state1[0],wg.state1[1]能一次性读取到,如果是32位对齐的话wg.state1[1],wg.state1[2]能一次性读取到,因为我们需要把两个计数器返回到一个变量中去,为了优化cpu读取效率,需要做一次内存对齐

wg.Add

Add操作核心代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state()
// 修改statep高32位,也就是worker计数器
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32) // waiter计数器
w := uint32(state) // worker计数器
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 2. 判断计数器
if v > 0 || w == 0 {
return
}

// 当 worker计数器降低到0时
// 重置 waiter计数器,并释放信号量
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(semap, false)
}
}

func (wg *WaitGroup) Done() {
wg.Add(-1)
}

wg.Done

1
2
3
4
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
wg.Add(-1)
}

wg.Wait

Wait逻辑是先修改计数器,然后等待信号量,核心代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
func (wg *WaitGroup) Wait() {
statep, semap := wg.state()
for {
state := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
return
}

// 增加waiter计数器
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 获取信号量
runtime_Semacquire(semap)
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}

// 信号量获取成功
return
}
}
}