揭祕 Go 併發利器 WaitGroup

在 Go 語言的併發編程世界中,WaitGroup 是一個至關重要的工具,它爲開發者提供了一種簡單而有效的方式來管理和同步多個協程的執行。本文將深入揭祕 WaitGroup 的實現原理、注意事項、使用示例。

什麼是 WaitGroup

WaitGroup 是 Go 標準庫中 sync 包提供的一種同步原語,用於等待一組(可能是併發的)操作完成。它的主要作用是讓主協程(即調用 WaitGroup 相關方法的協程)能夠等待其他協程完成任務後再繼續執行,確保所有併發操作都按預期完成。

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    // 模擬一些工作
    for i := 0; i < 5; i++ {
        fmt.Printf("Worker %d working... step %d\n", id, i)
    }
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup

    // 設置等待組計數器爲 3,表示有三個協程需要等待
    wg.Add(3)

    // 啓動三個協程
    go worker(1, &wg)
    go worker(2, &wg)
    go worker(3, &wg)

    // 等待所有協程完成
    wg.Wait()
    fmt.Println("All workers are done.")
}

WaitGroup 的核心方法

WaitGroup 對外提供了 Add、Done、Wait 三個方法,這三個方法需要搭配使用。

Add 方法

func main() {
    var wg sync.WaitGroup

    // 設置等待組計數器爲 3,表示有三個協程需要等待
    wg.Add(3)
    
    // ...
}

Done 方法

Wait 方法

func main() {
    var wg sync.WaitGroup

    // 設置等待組計數器爲 3,表示有三個協程需要等待
    wg.Add(3)

    // 啓動三個協程
    go worker(1, &wg)
    go worker(2, &wg)
    go worker(3, &wg)

    // 等待所有協程完成
    wg.Wait()
    fmt.Println("All workers are done.")
}

WaitGroup 的實現原理

WaitGroup 結構體

type WaitGroup struct {
    noCopy noCopy

    state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
    sema  uint32
}

WaitGroup 是一個結構體,裏面有 state 和 sema 兩個核心字段,其中:

func main() {
    var wg sync.WaitGroup
    wg.Add(3)
    go worker(1, &wg)
    go worker(2, &wg)
    go worker(3, &wg)
    wg.Wait()
}

當 worker 沒有執行完時,state 的內存模型如下圖所示:

state 的一些操作:

int32(state >> 32) // 取高 32 位的值
uint32(state) // 取低 32 位的值

wg.state.Add(uint64(delta) << 32) // 高 32 位加減
wg.state.CompareAndSwap(state, state+1) // 低 32 位操作

sema 的初始值是 0,調用 Wait 方法會進行 P 操作,如果此時沒有 V 操作,Wait 方法就會阻塞,然後子協程執行完會調用 Done 方法,當 state 高 32 位爲 0 時,就會進行 V 操作,這時 Wait 方法就會被喚醒繼續執行。

Add 方法源碼

func (wg *WaitGroup) Add(delta int) {
    // 省略 race 相關代碼
    
    // 1. 原子更新 state 的值
    state := wg.state.Add(uint64(delta) << 32)
    
    // 2. 通過位操作獲取子協程計數器和主協程計數器的值
    v := int32(state >> 32)
    w := uint32(state)
    
    // 省略 race 相關代碼
    
    // 3. v < 0:拋出 panic , v 的值不可以是負值
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    
    // 4. Wait 和 Add() 不能同時被調用,否則會拋出 panic
    // 4.1 w != 0 說明 Wait 方法已經被調用但是還沒返回
    // 4.2 delta > 0 && v == int32(delta) 說明調用了 Add() 方法
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    
    // 5. 執行到此處可能有兩種情況
    // 5.1 Wait 方法還沒被調用(w == 0),此時調用 Add() 或者 Done() 都直接返回,不需要進行 V 操作
    // 5.2 Wait 方法已經被調用(w != 0),此時只能調用 Done(), 若果 v > 0 說明子協程沒有全部執行完,可以直接返回,不需要進行 V 操作
    if v > 0 || w == 0 {
       return
    }
    
    // 6. 執行到這裏說明 v == 0 && w != 0,所有子協程都已經執行完,
    // v == 0 時調用 Wait() 並不會更改state,再次檢查 state 防止有併發調用 Add,
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    
    // 7. 執行 V 操作,釋放所有因調用 Wait 而阻塞的協程
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

關鍵邏輯已經添加註釋了,總結一下 Add 主要操作:

  1. 更新 state,通常 Add 是增加操作,Done 是減少操作;

  2. 校驗 state,防止有 v < 0 或者錯誤調用 Add 函數的情況;

  3. 如果 state 符合預期,執行 V 操作,釋放所有因調用 Wait 而阻塞的協程。

Done 方法源碼

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

Done() 的底層實現是調用了 Add() 方法,參數是 -1 代表完成一個任務。

Wait 方法源碼

func (wg *WaitGroup) Wait() {
    // 省略 race 相關代碼
    
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)
       if v == 0 {
          // 1. Counter is 0, no need to wait.
          // 省略 race 相關代碼
          return
       }
       
       // 2. Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          // 省略 race 相關代碼
          
          // 3. P 操作,阻塞當前進程
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             // 4. 當前協程已經被喚醒,此時應該 v == 0,Wait 沒有返回前不可以複用 WaitGroup
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          // 省略 race 相關代碼
          return
       }
    }
}

Wait 的主要邏輯:

  1. 判斷 v 的值,如果 v == 0 , Wait 可以直接返回,v == 0 說明沒有需要等待的子協程;

  2. 使用 CompareAndSwap 進行 state + 1 操作,如果執行成功進行下面步驟,如果不成功開啓新一輪 Wait 邏輯;

  3. 如果 state + 1 操作成功後,需要進行 P 操作,阻塞當前進程;

  4. 當前協程已經被喚醒,再次校驗 state 的值,此時應該 state == 0,Wait 沒有返回前不可以複用 WaitGroup。

WaitGroup 的注意事項

正確使用 add 和 done 方法:

確保在啓動協程之前正確地調用 add 方法來設置等待的協程數量,並且在協程完成任務後及時調用 done 方法。避免忘記調用 done 方法導致程序永遠阻塞在 wait 上,或者超量調用 done 方法導致計數器變爲負數而引發 panic。

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer func(){
        // 調用了兩次 Done() 方法
        wg.Done()
        wg.Done()
    }()
        
    
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    go worker(1, &wg)

    wg.Wait()
    fmt.Println("All workers are done.")
}

$ go run main.go 
Worker 1 done

panic: sync: negative WaitGroup counter

goroutine 18 [running]:
sync.(*WaitGroup).Add(0xc00007e020, 0xffffffffffffffff)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:74 +0x139
sync.(*WaitGroup).Done(...)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:99
main.worker.func1(0xc00007e020)
        /box/main.go:11 +0x4c
main.worker(0x1, 0xc00007e020)
        /box/main.go:16 +0xf2
created by main.main
        /box/main.go:22 +0x78
panic: sync: WaitGroup is reused before previous Wait has returned

goroutine 1 [running]:
sync.(*WaitGroup).Wait(0xc00007e020)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:132 +0xad
main.main()
        /box/main.go:24 +0x86

Exited with error status 2

合理複用 WaitGroup:

WaitGroup 對象可以在所有協程完成後重用。但是在重用時,要確保之前的 wait 方法已經返回,否則可能會出現不可預期的行爲。

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer func(){
        // 調用了兩次 Done() 方法
        wg.Done()
        go func(){
            wg.Add(1)
        }()
    }()
        
    
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    go worker(1, &wg)

    wg.Wait()
    fmt.Println("All workers are done.")
}

$ go run main.go
Worker 1 done

panic: sync: WaitGroup is reused before previous Wait has returned

goroutine 1 [running]:
sync.(*WaitGroup).Wait(0xc000016060)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:132 +0xad
main.main()
        /box/main.go:27 +0x86

Exited with error status 2

不要複製 WaitGroup:

WaitGroup 實例是不期望被複制的,如果複製後需要當做不同的實例看待,如果錯誤的使用了複製後的實例,可能造成協程泄漏:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    go func(){
        // 複製一個 wg 
        wg := wg
        defer wg.Done()
        
        fmt.Printf("Worker done\n")
    }()

    wg.Wait()
    fmt.Println("All workers are done.")
}

& go run main.go 
Worker done

fatal error: all goroutines are asleep - deadlock!

goroutine 1 [semacquire]:
sync.runtime_Semacquire(0xc000016068)
        /usr/local/go-1.13.5/src/runtime/sema.go:56 +0x42
sync.(*WaitGroup).Wait(0xc000016060)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:130 +0x64
main.main()
        /box/main.go:20 +0x7d

Exited with error status 2

WaitGroup 的示例

並行計算

假設我們需要計算一個大型數組中每個元素的平方值,可以將數組分成多個部分,每個部分由一個協程來處理。使用 WaitGroup 可以確保所有協程都完成計算後再彙總結果。

package main

import (
    "fmt"
    "sync"
)

func square(wg *sync.WaitGroup, slice []int, result chan<- int) {
    defer wg.done()
    for _, v := range slice {
        result <- v * v
    }
}

func main() {
    var wg sync.WaitGroup
    slice := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    result := make(chan int, len(slice))

    // 將數組分成 3 個部分,每個部分由一個協程處理
    wg.Add(3)
    go square(&wg, slice[:3], result)
    go square(&wg, slice[3:6], result)
    go square(&wg, slice[6:], result)

    // 等待所有協程完成
    wg.Wait()
    close(result)

    // 彙總結果
    var total int
    for v := range result {
        total += v
    }
    fmt.Println(total)
}

併發請求

在開發中,當需要同時處理多個 RPC 請求時,可以使用 WaitGroup 來確保所有請求都處理完成後再返回響應。

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup

    wg.Add(3)
    go func() {
       defer wg.Done()
       fmt.Println("call rpc1 ...")
    }()
    go func() {
       defer wg.Done()
       fmt.Println("call rpc2 ...")
    }()
    go func() {
       defer wg.Done()
       fmt.Println("call rpc3 ...")
    }()

    // 等待所有請求處理完成
    wg.Wait()
    fmt.Println("所有請求處理完成")
}

總結

WaitGroup 是 Go 語言中非常強大的併發控制工具,它能夠幫助開發者輕鬆地管理和同步多個協程的執行,確保併發操作的正確執行順序。通過正確地使用 WaitGroup,開發者可以編寫出高效、可靠的併發程序,充分發揮 Go 語言的併發優勢。在實際應用中,我們需要深入理解 WaitGroup 的工作原理和使用方法,避免常見的錯誤,以確保程序的正確性和性能。

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/i7RL6wrfNit9Rlp9b8kAGw