Go channel 計數信號量

Go 併發設計的一個慣用法就是將帶緩衝 channel 用作計數信號量(counting semaphore)。

帶緩衝 channel 中的當前數據個數代表的是當前同時處於活動狀態(處理業務)的 goroutine 的數量,而帶緩衝 channel 的容量(capacity)就代表了允許同時處於活動狀態的 goroutine 的最大數量。

向帶緩衝 channel 的一個發送操作表示獲取一個信號量,而從 channel 的一個接收操作則表示釋放一個信號量。

計數信號量經常被使用於限制最大併發數。

e.g.

package main

import (
 "log"
 "sync"
 "time"
)

func main() {
 active := make(chan struct{}, 3)
 jobs := make(chan int, 10)
 go func() {
  for i := 0; i < 8; i++ {
   jobs <- i + 1
  }
  close(jobs)
 }()

 var wg sync.WaitGroup
 for j := range jobs {
  wg.Add(1)
  active <- struct{}{}
  go func(j int) {
   defer func() { <-active }()
   log.Printf("handle job: %d\n", j)
   time.Sleep(2 * time.Second)
   wg.Done()
  }(j)
 }
 wg.Wait()
}

上面的示例創建了一組 goroutines 來處理 job,同一時間允許的最多 3 個 goroutine 處於活動狀態。爲達成這一目標,我們看到示例使用了一個容量 (capacity) 爲 3 的帶緩衝 channel:active 作爲計數信號量,這意味着允許同時處於活動狀態的最大 goroutine 數量爲 3。

運行一下該示例:

2024/07/14 23:15:17 handle job: 3
2024/07/14 23:15:17 handle job: 8
2024/07/14 23:15:17 handle job: 6
2024/07/14 23:15:19 handle job: 1
2024/07/14 23:15:19 handle job: 4
2024/07/14 23:15:19 handle job: 7
2024/07/14 23:15:21 handle job: 2
2024/07/14 23:15:21 handle job: 5

從示例運行結果中的時間戳我們可以看到:雖然我們創建了很多 goroutine,但由於計數信號量的存在,同一時間內處理活動狀態 (正在處理 job) 的 goroutine 的數量最多爲 3 個。

e.g.

package main

import (
 "log"
 "math/rand"
 "time"
)

type Customer struct{ id int }
type Bar chan Customer

func (bar Bar) ServeCustomer(c Customer) {
 log.Print("++ 顧客#", c.id, "開始飲酒")
 time.Sleep(time.Second * time.Duration(3+rand.Intn(16)))
 log.Print("-- 顧客#", c.id, "離開酒吧")
 <-bar // 離開酒吧,騰出位子
}

func main() {
 rand.Seed(time.Now().UnixNano())

 bar24x7 := make(Bar, 10) // 最多同時服務10位顧客
 for customerId := 0; ; customerId++ {
  time.Sleep(time.Second * 2)
  customer := Customer{customerId}
  bar24x7 <- customer // 等待進入酒吧
  go bar24x7.ServeCustomer(customer)
 }
}

Go 在它的擴展包中也提供了信號量庫 semaphore (golang.org/x/sync/semaphore)。

func NewWeighted(n int64) *Weighted
func (s *Weighted) Acquire(ctx context.Context, n int64) error
func (s *Weighted) Release(n int64)
func (s *Weighted) TryAcquire(n int64) bool

e.g.

package main

import (
 "context"
 "log"
 "sync"
 "time"

 "golang.org/x/sync/semaphore"
)

func main() {
 sema := semaphore.NewWeighted(3)
 jobs := make(chan int, 10)
 go func() {
  for i := 0; i < 8; i++ {
   jobs <- i + 1
  }
  close(jobs)
 }()
 ctx := context.Background()
 var wg sync.WaitGroup
 for j := range jobs {
  wg.Add(1)
  sema.Acquire(ctx, 1)
  go func(j int) {
   defer sema.Release(1)
   log.Printf("handle job: %d\n", j)
   time.Sleep(2 * time.Second)
   wg.Done()
  }(j)
 }
 wg.Wait()
}

e.g.

創建和 CPU 核數一樣多的 Worker,讓它們去處理一個 4 倍數量的整數 slice。每個 Worker 一次只能處理一個整數,處理完之後,才能處理下一個。當然,這個問題的解決方案有很多種,這一次我們使用信號量,代碼如下:

package main

import (
 "context"
 "fmt"
 "log"
 "runtime"
 "time"

 "golang.org/x/sync/semaphore"
)

var (
 maxWorkers = runtime.GOMAXPROCS(0)                    // worker數量
 sema       = semaphore.NewWeighted(int64(maxWorkers)) // 信號量
 task       = make([]int, maxWorkers*4)                // 任務數,是worker的四倍
)

func main() {
 ctx := context.Background()

 for i := range task {
  // 如果沒有worker可用,會阻塞在這裏,直到某個worker被釋放
  sema.Acquire(ctx, 1)

  // 啓動worker goroutine
  go func(i int) {
   defer sema.Release(1)
   time.Sleep(100 * time.Millisecond) // 模擬一個耗時操作
   task[i] = i + 1
  }(i)
 }

 // 請求所有的worker,這樣能確保前面的worker都執行完
 if err := sema.Acquire(ctx, int64(maxWorkers)); err != nil {
  log.Printf("獲取所有的worker失敗: %v", err)
 }

 fmt.Println(task)
}

如果在實際應用中,你想等所有的 Worker 都執行完,就可以獲取最大計數值的信號量。

相比 channel 信號量的實現看起來非常簡單,而且也能應對大部分的信號量的場景,爲什麼官方擴展庫的信號量的實現不採用這種方法呢?官方的實現方式有這樣一個功能:它可以一次請求多個資源,這是通過 channel 實現的信號量所不具備的。

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/bQIf-7gxH8KflZneKUPz_Q