golang开启多协程处理channel数据(实现协程守护,动态控制协程数)
work.go 中实现, 可根据实际需要微调。比如更改消费 channel中元素类型。
package work
import (
"context"
"errors"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
)
// Sentinel errors exposed by this package. NOTE(review): neither value is
// returned by code in this file — presumably they are intended for callers
// or worker functions to return/compare with errors.Is; confirm usage.
var (
	ErrWorkerStopped     = errors.New("worker stopped")
	ErrMessageChanClosed = errors.New("message channel closed")
)
// workerControl holds the lifecycle channels for a single worker goroutine.
type workerControl struct {
	stopChan chan struct{} // closed by the manager to ask this worker to stop
	done     chan struct{} // closed by the worker goroutine once it has fully exited
}
// WorkerFunc processes a single message. Returning a non-nil error (or
// panicking — the manager converts a panic into an error) ends the worker's
// current processing pass and may trigger a restart. ctx is canceled when
// the worker is asked to stop.
type WorkerFunc func(ctx context.Context, msg string) error
// WorkerManager runs a dynamically sized pool of goroutines that all consume
// from one message channel, restarting workers whose processing pass fails.
type WorkerManager struct {
	workerFunc        WorkerFunc
	messages          chan string
	restartLimit      int   // max restarts per worker; <= 0 means unlimited
	targetConcurrency int32 // desired worker count; accessed atomically
	maxConcurrency    int   // cap applied by SetConcurrency; <= 0 means no cap

	mu      sync.Mutex // guards workers and nextID
	workers map[int]*workerControl
	nextID  int

	stopAll         chan struct{} // closed by Stop to signal every worker at once
	workerExited    chan int      // exited worker IDs, drained by supervise
	supervisorClose chan struct{} // closed by Stop to end the supervise loop
}
// NewWorkerManager builds a manager that consumes msgChan with fn and starts
// its supervisor goroutine. No workers run until SetConcurrency is called
// with a positive target.
func NewWorkerManager(
	fn WorkerFunc,
	msgChan chan string,
	restartLimit int,
	maxConcurrency int,
) *WorkerManager {
	m := &WorkerManager{
		workerFunc:      fn,
		messages:        msgChan,
		restartLimit:    restartLimit,
		maxConcurrency:  maxConcurrency,
		workers:         map[int]*workerControl{},
		stopAll:         make(chan struct{}),
		workerExited:    make(chan int, 100),
		supervisorClose: make(chan struct{}),
	}
	go m.supervise()
	return m
}
// SetConcurrency adjusts the number of running workers toward target,
// clamped to [0, maxConcurrency] when maxConcurrency > 0. Scale-up happens
// synchronously under the lock; scale-down is handed to a background
// goroutine so this call never blocks waiting for in-flight messages.
func (wm *WorkerManager) SetConcurrency(target int) {
	if target < 0 {
		target = 0
	}
	if wm.maxConcurrency > 0 && target > wm.maxConcurrency {
		target = wm.maxConcurrency
	}
	// Publish the target first so the supervisor uses the new value when
	// deciding whether to respawn an exited worker.
	atomic.StoreInt32(&wm.targetConcurrency, int32(target))
	wm.mu.Lock()
	defer wm.mu.Unlock()
	current := len(wm.workers)
	if target == current {
		return
	}
	if target > current {
		for i := 0; i < target-current; i++ {
			wm.spawnWorkerLocked()
		}
		return
	}
	// Scale down: pick workers to retire. Map iteration order is random,
	// so the selection is effectively arbitrary.
	toStop := current - target
	idsToStop := make([]int, 0, toStop)
	for id := range wm.workers {
		if len(idsToStop) >= toStop {
			break
		}
		idsToStop = append(idsToStop, id)
	}
	// Stop them asynchronously, one at a time: each worker is removed from
	// the map BEFORE its stop channel is closed (so the supervisor will not
	// respawn it), then we wait for it to finish its current message.
	go func(ids []int) {
		for _, id := range ids {
			wm.mu.Lock()
			wc, exists := wm.workers[id]
			if !exists {
				// Worker already exited and was reaped by the supervisor.
				wm.mu.Unlock()
				continue
			}
			delete(wm.workers, id)
			wm.mu.Unlock()
			close(wc.stopChan)
			<-wc.done
			log.Printf("Worker %d stopped (scale down)", id)
		}
	}(idsToStop)
}
// spawnWorkerLocked registers and starts one new worker goroutine.
// The caller must hold wm.mu.
//
// Fix: the exit notification used to be an unconditional blocking send to
// wm.workerExited performed BEFORE close(wc.done). After Stop() the
// supervisor no longer drains that channel, so once its buffer filled the
// worker goroutine blocked forever and wc.done was never closed — deadlocking
// Stop and scale-down waiters. We now close done first and abandon the
// notification once supervisorClose is closed.
func (wm *WorkerManager) spawnWorkerLocked() {
	id := wm.nextID
	wm.nextID++
	wc := &workerControl{
		stopChan: make(chan struct{}),
		done:     make(chan struct{}),
	}
	wm.workers[id] = wc
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Worker %d final panic: %v", id, r)
			}
			// Release anyone waiting on this worker before notifying the
			// supervisor, so waiters never depend on the supervisor running.
			close(wc.done)
			// Best-effort notification: after Stop() nobody reads the
			// channel, so bail out instead of blocking forever.
			select {
			case wm.workerExited <- id:
			case <-wm.supervisorClose:
			}
		}()
		wm.runWorker(id, wc)
	}()
	log.Printf("Worker %d started", id)
}
// runWorker is the long-lived loop for one worker: it repeatedly runs a
// message-processing pass and restarts it (bounded by restartLimit) when the
// pass ends because of a handler error, panic, or channel closure.
//
// Fix: the per-iteration stop-monitor goroutine previously had no exit path
// other than wc.stopChan/wm.stopAll, so every restart leaked one goroutine
// for the lifetime of the worker. iterDone now lets it exit when the pass
// ends for any other reason.
func (wm *WorkerManager) runWorker(id int, wc *workerControl) {
	restartCount := 0
	for {
		ctx, cancel := context.WithCancel(context.Background())
		// stopSignal is closed when this worker (or the whole manager) is
		// asked to stop; iterDone is closed when the pass ends normally so
		// the monitor goroutine does not leak.
		stopSignal := make(chan struct{})
		iterDone := make(chan struct{})
		go func() {
			select {
			case <-wc.stopChan:
				log.Printf("Worker %d received stop signal", id)
			case <-wm.stopAll:
				log.Printf("Worker %d received global stop", id)
			case <-iterDone:
				return
			}
			close(stopSignal)
			cancel()
		}()
		// One message-processing pass.
		func() {
			defer cancel()
			defer close(iterDone)
			for {
				select {
				case msg, ok := <-wm.messages:
					if !ok {
						log.Printf("Worker %d: message channel closed", id)
						return
					}
					// Run the handler in its own goroutine so we can react
					// to stop signals while it is in flight; a panic is
					// converted into an error on errCh.
					processCtx, processCancel := context.WithCancel(ctx)
					errCh := make(chan error, 1)
					go func() {
						defer func() {
							if r := recover(); r != nil {
								errCh <- fmt.Errorf("panic: %v", r)
							}
						}()
						errCh <- wm.workerFunc(processCtx, msg)
					}()
					select {
					case err := <-errCh:
						processCancel()
						if err != nil {
							log.Printf("Worker %d process error: %v", id, err)
							return // end this pass; may trigger a restart
						}
					case <-stopSignal:
						processCancel()
						<-errCh // wait for the in-flight handler to finish
						return
					case <-ctx.Done():
						processCancel()
						<-errCh // wait for the in-flight handler to finish
						return
					}
				case <-stopSignal:
					return
				}
			}
		}()
		// A stop request ends the worker outright — no restart.
		select {
		case <-stopSignal:
			return
		default:
		}
		restartCount++
		if wm.restartLimit > 0 && restartCount >= wm.restartLimit {
			log.Printf("Worker %d restart limit reached", id)
			return
		}
		log.Printf("Worker %d restarting...", id)
		// Brief backoff before the next pass, aborted immediately by stop.
		select {
		case <-time.After(1 * time.Second):
		case <-wc.stopChan:
			return
		case <-wm.stopAll:
			return
		}
	}
}
// supervise reaps exited workers and respawns them while the live count is
// below the current target. It runs until Stop closes supervisorClose.
//
// Fix: when a worker-exit notification and supervisorClose were ready at the
// same time, select could pick the notification and respawn a worker during
// shutdown, leaving an untracked goroutine. Respawning is now skipped once
// shutdown has started.
func (wm *WorkerManager) supervise() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("supervisor panic: %v", r)
		}
		log.Println("Supervisor exited")
	}()
	for {
		select {
		case id := <-wm.workerExited:
			wm.mu.Lock()
			delete(wm.workers, id)
			target := int(atomic.LoadInt32(&wm.targetConcurrency))
			// Never respawn once shutdown has started, even if this exit
			// notification raced ahead of supervisorClose.
			select {
			case <-wm.stopAll:
			case <-wm.supervisorClose:
			default:
				if len(wm.workers) < target {
					wm.spawnWorkerLocked()
				}
			}
			wm.mu.Unlock()
		case <-wm.supervisorClose:
			return
		}
	}
}
// Stop shuts the manager down: it ends the supervisor, broadcasts a stop
// signal to every worker, and waits up to five seconds in total for them to
// finish. Stop must be called at most once — a second call would close
// already-closed channels and panic.
func (wm *WorkerManager) Stop() {
	// 1. End the supervisor first so no workers are respawned from here on.
	close(wm.supervisorClose)
	// 2. Snapshot the live workers and detach them from the manager.
	wm.mu.Lock()
	snapshot := make(map[int]*workerControl, len(wm.workers))
	for id, wc := range wm.workers {
		snapshot[id] = wc
	}
	wm.workers = make(map[int]*workerControl)
	wm.mu.Unlock()
	// 3. Broadcast the stop signal to every worker at once.
	close(wm.stopAll)
	// 4. Wait for each worker against one shared five-second deadline.
	deadline := time.After(5 * time.Second)
	for id, wc := range snapshot {
		select {
		case <-wc.done:
			log.Printf("Worker %d stopped", id)
		case <-deadline:
			log.Printf("Timeout waiting for worker %d", id)
			return
		}
	}
}
// ActiveWorkers reports how many workers are currently registered.
func (wm *WorkerManager) ActiveWorkers() int {
	wm.mu.Lock()
	n := len(wm.workers)
	wm.mu.Unlock()
	return n
}
测试代码:
package main
import (
	"context"
	"fmt"
	"log"
	"math/rand"
	"os"
	"os/signal"
	"sync/atomic"
	"syscall"
	"time"

	"work" // 假设worker包路径为work
)
// main is a demo driver for the work package: it produces 20 messages,
// processes them with a flaky worker function, scales the pool up and down,
// and shuts down gracefully on SIGINT/SIGTERM.
func main() {
	// 1. Message channel and progress counter. (Uses sync/atomic, which the
	// original import block was missing.)
	msgChan := make(chan string, 100)
	totalMessages := 20
	var processedCount atomic.Int32
	// Worker function: ~10% random panic, then 100-300ms of simulated work.
	workerFunc := func(ctx context.Context, msg string) error {
		if rand.Intn(10) == 0 {
			panic("random worker panic")
		}
		delay := time.Duration(100+rand.Intn(200)) * time.Millisecond
		select {
		case <-time.After(delay):
			log.Printf("Processed: %s", msg)
			processedCount.Add(1)
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	// Manager: at most 3 restarts per worker, at most 5 concurrent workers.
	wm := work.NewWorkerManager(workerFunc, msgChan, 3, 5)
	// 2. Start with 3 workers.
	wm.SetConcurrency(3)
	log.Println("Started with 3 workers")
	// 3. Producer goroutine.
	go func() {
		for i := 0; i < totalMessages; i++ {
			msg := fmt.Sprintf("msg-%d", i)
			select {
			case msgChan <- msg:
				log.Printf("Produced: %s", msg)
			case <-time.After(100 * time.Millisecond):
				log.Println("Message queue full, waiting...")
				i-- // retry the current message
			}
			time.Sleep(50 * time.Millisecond)
		}
		log.Println("Finished producing messages")
		// close(msgChan) // left open here on purpose; Stop drives shutdown
	}()
	// 4. Demonstrate dynamic scaling.
	go func() {
		time.Sleep(1 * time.Second)
		wm.SetConcurrency(5) // scale up to 5 workers
		log.Println("Scaled up to 5 workers")
		time.Sleep(2 * time.Second)
		wm.SetConcurrency(2) // scale down to 2 workers
		log.Println("Scaled down to 2 workers")
	}()
	// 5. Periodic progress monitor (plain range over the ticker replaces the
	// original single-case select).
	go func() {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			log.Printf(
				"[Monitor] Active workers: %d, Processed: %d/%d",
				wm.ActiveWorkers(),
				processedCount.Load(),
				totalMessages,
			)
		}
	}()
	// 6. Graceful shutdown on SIGINT/SIGTERM.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	<-sigChan
	log.Println("\nShutting down...")
	// 7. Stop workers, close the channel, and report.
	start := time.Now()
	wm.Stop()
	close(msgChan) // 确保通道关闭
	log.Printf("Final processed count: %d/%d", processedCount.Load(), totalMessages)
	log.Printf("Shutdown completed in %v", time.Since(start))
}

浙公网安备 33010602011771号