Running multiple goroutines in Go to process channel data (with worker supervision and dynamic control of the goroutine count)

The implementation is in work.go and can be tweaked to fit your needs, for example by changing the element type consumed from the channel (a generic sketch of that change follows the listing).

package work

import (
	"context"
	"errors"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"
)

var (
	ErrWorkerStopped     = errors.New("worker stopped")
	ErrMessageChanClosed = errors.New("message channel closed")
)

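// workerControl carries the per-worker signalling channels: closing stopChan
// asks the worker to stop, and done is closed once the worker goroutine has
// fully exited.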
type workerControl struct {
	stopChan chan struct{}
	done     chan struct{}
}

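// WorkerFunc processes a single message; returning a non-nil error (or
// panicking) makes the worker leave its message loop and restart.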
type WorkerFunc func(ctx context.Context, msg string) error

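// WorkerManager fans messages out from a shared channel to a dynamically
// sized pool of worker goroutines, restarts a failing worker's loop until
// restartLimit is reached, and keeps the pool size at targetConcurrency via
// a supervisor goroutine.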
type WorkerManager struct {
	workerFunc        WorkerFunc
	messages          chan string
	restartLimit      int
	targetConcurrency int32 // atomic
	maxConcurrency    int

	mu              sync.Mutex
	workers         map[int]*workerControl
	nextID          int
	stopAll         chan struct{}
	workerExited    chan int
	supervisorClose chan struct{}
}

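// NewWorkerManager builds a manager around fn and msgChan and starts the
// supervisor goroutine; no workers run until SetConcurrency is called.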
func NewWorkerManager(
	fn WorkerFunc,
	msgChan chan string,
	restartLimit int,
	maxConcurrency int,
) *WorkerManager {
	wm := &WorkerManager{
		workerFunc:      fn,
		messages:        msgChan,
		restartLimit:    restartLimit,
		maxConcurrency:  maxConcurrency,
		workers:         make(map[int]*workerControl),
		stopAll:         make(chan struct{}),
		workerExited:    make(chan int, 100),
		supervisorClose: make(chan struct{}),
	}

	go wm.supervise()
	return wm
}

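// SetConcurrency scales the pool up or down to target, clamped to
// [0, maxConcurrency]. Scale-down is performed asynchronously so the caller
// is not blocked while the stopped workers finish their in-flight message.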
func (wm *WorkerManager) SetConcurrency(target int) {
	if target < 0 {
		target = 0
	}
	if wm.maxConcurrency > 0 && target > wm.maxConcurrency {
		target = wm.maxConcurrency
	}

	atomic.StoreInt32(&wm.targetConcurrency, int32(target))

	wm.mu.Lock()
	defer wm.mu.Unlock()

	current := len(wm.workers)
	if target == current {
		return
	}

	if target > current {
		for i := 0; i < target-current; i++ {
			wm.spawnWorkerLocked()
		}
		return
	}

	// Scale down safely: collect the IDs of workers to stop.
	toStop := current - target
	idsToStop := make([]int, 0, toStop)
	for id := range wm.workers {
		if len(idsToStop) >= toStop {
			break
		}
		idsToStop = append(idsToStop, id)
	}

	// Stop the selected workers asynchronously so this call does not block on <-wc.done.
	go func(ids []int) {
		for _, id := range ids {
			wm.mu.Lock()
			wc, exists := wm.workers[id]
			if !exists {
				wm.mu.Unlock()
				continue
			}
			delete(wm.workers, id)
			wm.mu.Unlock()

			close(wc.stopChan)
			<-wc.done
			log.Printf("Worker %d stopped (scale down)", id)
		}
	}(idsToStop)
}

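// spawnWorkerLocked registers and starts one worker goroutine.
// The caller must hold wm.mu.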
func (wm *WorkerManager) spawnWorkerLocked() {
	id := wm.nextID
	wm.nextID++
	wc := &workerControl{
		stopChan: make(chan struct{}),
		done:     make(chan struct{}),
	}
	wm.workers[id] = wc

	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Worker %d final panic: %v", id, r)
			}
			wm.workerExited <- id
			close(wc.done)
		}()
		wm.runWorker(id, wc)
	}()

	log.Printf("Worker %d started", id)
}

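// runWorker is the worker main loop: it consumes messages until it is told to
// stop; after a processing error (or a closed channel) it restarts the loop
// with a 1s backoff until the restart limit is reached.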
func (wm *WorkerManager) runWorker(id int, wc *workerControl) {
	restartCount := 0

	for {
		ctx, cancel := context.WithCancel(context.Background())

		// Watch for the per-worker and the global stop signal. The extra
		// ctx.Done() case releases this goroutine when the message loop below
		// exits on its own, so restarts do not leak one goroutine per attempt.
		stopSignal := make(chan struct{})
		go func() {
			select {
			case <-wc.stopChan:
				log.Printf("Worker %d received stop signal", id)
			case <-wm.stopAll:
				log.Printf("Worker %d received global stop", id)
			case <-ctx.Done():
				return // message loop already finished; nothing to signal
			}
			close(stopSignal)
			cancel()
		}()

		// Message-processing loop.
		func() {
			defer cancel()

			for {
				select {
				case msg, ok := <-wm.messages:
					if !ok {
						log.Printf("Worker %d: message channel closed", id)
						return
					}

					// Run the handler in its own goroutine so it can be
					// cancelled and its panics recovered.
					processCtx, processCancel := context.WithCancel(ctx)
					errCh := make(chan error, 1)
					go func() {
						defer func() {
							if r := recover(); r != nil {
								errCh <- fmt.Errorf("panic: %v", r)
							}
						}()
						errCh <- wm.workerFunc(processCtx, msg)
					}()

					select {
					case err := <-errCh:
						processCancel()
						if err != nil {
							log.Printf("Worker %d process error: %v", id, err)
							return // leave the message loop; may trigger a restart
						}

					case <-stopSignal:
						processCancel()
						<-errCh // wait for the in-flight handler to finish
						return

					case <-ctx.Done():
						processCancel()
						<-errCh // wait for the in-flight handler to finish
						return
					}

				case <-stopSignal:
					return
				}
			}
		}()

		// Check whether the exit was caused by a stop signal.
		select {
		case <-stopSignal:
			return // 完全退出worker
		default:
		}

		// Restart handling.
		restartCount++
		if wm.restartLimit > 0 && restartCount >= wm.restartLimit {
			log.Printf("Worker %d restart limit reached", id)
			return
		}

		log.Printf("Worker %d restarting...", id)
		select {
		case <-time.After(1 * time.Second):
		case <-wc.stopChan:
			return
		case <-wm.stopAll:
			return
		}
	}
}

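// supervise replaces workers that exit on their own (panic, restart limit,
// closed channel) whenever the pool falls below targetConcurrency.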
func (wm *WorkerManager) supervise() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("supervisor panic: %v", r)
		}
		log.Println("Supervisor exited")
	}()

	for {
		select {
		case id := <-wm.workerExited:
			wm.mu.Lock()
			delete(wm.workers, id)
			target := int(atomic.LoadInt32(&wm.targetConcurrency))
			if len(wm.workers) < target {
				wm.spawnWorkerLocked()
			}
			wm.mu.Unlock()

		case <-wm.supervisorClose:
			return
		}
	}
}

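// Stop shuts down the supervisor, broadcasts the global stop signal, and
// waits up to 5 seconds in total for all workers to exit.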
func (wm *WorkerManager) Stop() {
	// 1. Shut down the supervisor goroutine.
	close(wm.supervisorClose)

	// 2. Snapshot the current workers and clear the registry.
	wm.mu.Lock()
	workers := make(map[int]*workerControl, len(wm.workers))
	for id, wc := range wm.workers {
		workers[id] = wc
	}
	wm.workers = make(map[int]*workerControl)
	wm.mu.Unlock()

	// 3. Broadcast the global stop signal.
	close(wm.stopAll)

	// 4. Wait (up to 5s in total) for all workers to exit.
	timeout := time.After(5 * time.Second)
	for id, wc := range workers {
		select {
		case <-wc.done:
			log.Printf("Worker %d stopped", id)
		case <-timeout:
			log.Printf("Timeout waiting for worker %d", id)
			return
		}
	}
}

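// ActiveWorkers reports how many workers are currently registered.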
func (wm *WorkerManager) ActiveWorkers() int {
	wm.mu.Lock()
	defer wm.mu.Unlock()
	return len(wm.workers)
}
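
For reference, here is a minimal sketch of the element-type change mentioned at the top, assuming Go 1.18+ type parameters. The names WorkerFuncOf, ManagerOf, and Job are illustrative only; the rest of work.go would keep its current logic, with chan string replaced by chan T.

package work

import "context"

// WorkerFuncOf is the generic counterpart of WorkerFunc for an arbitrary
// element type T.
type WorkerFuncOf[T any] func(ctx context.Context, msg T) error

// ManagerOf shows only the fields that change when the element type becomes
// generic; all other WorkerManager fields and methods stay the same, with
// chan string replaced by chan T.
type ManagerOf[T any] struct {
	workerFunc WorkerFuncOf[T]
	messages   chan T
}

// Job is an example element type richer than a plain string.
type Job struct {
	ID      int
	Payload []byte
}

// Compile-time check that a Job handler matches the generic signature.
var _ WorkerFuncOf[Job] = func(ctx context.Context, j Job) error { return nil }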

Test code:

package main

import (
	"context"
	"fmt"
	"log"
	"math/rand"
	"os"
	"os/signal"
	"syscall"
	"time"
	"work" // 假设worker包路径为work
)

func main() {
	// 1. Initialize the message channel and the worker manager.
	msgChan := make(chan string, 100)
	totalMessages := 20
	var processedCount atomic.Int32

	// Define the worker function (simulates a ~10% random failure rate).
	workerFunc := func(ctx context.Context, msg string) error {
		// Simulate a random panic for ~10% of messages.
		if rand.Intn(10) == 0 {
			panic("random worker panic")
		}

		// Simulate 100-300ms of processing time.
		delay := time.Duration(100+rand.Intn(200)) * time.Millisecond
		select {
		case <-time.After(delay):
			log.Printf("Processed: %s", msg)
			processedCount.Add(1)
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// Create the WorkerManager (restart limit 3, max concurrency 5).
	wm := work.NewWorkerManager(workerFunc, msgChan, 3, 5)

	// 2. Start with 3 initial workers.
	wm.SetConcurrency(3)
	log.Println("Started with 3 workers")

	// 3. Producer goroutine.
	go func() {
		for i := 0; i < totalMessages; i++ {
			msg := fmt.Sprintf("msg-%d", i)
			select {
			case msgChan <- msg:
				log.Printf("Produced: %s", msg)
			case <-time.After(100 * time.Millisecond):
				log.Println("Message queue full, waiting...")
				i-- // retry the current message
			}
			time.Sleep(50 * time.Millisecond)
		}
		log.Println("Finished producing messages")
		// close(msgChan) // not closed here in the test; it is closed after Stop
	}()

	// 4. Demonstrate dynamic adjustment of the concurrency level.
	go func() {
		time.Sleep(1 * time.Second)
		wm.SetConcurrency(5) // scale up to 5 workers
		log.Println("Scaled up to 5 workers")

		time.Sleep(2 * time.Second)
		wm.SetConcurrency(2) // scale down to 2 workers
		log.Println("Scaled down to 2 workers")
	}()

	// 5. Monitoring goroutine.
	go func() {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()
		
		for range ticker.C {
			log.Printf(
				"[Monitor] Active workers: %d, Processed: %d/%d",
				wm.ActiveWorkers(),
				processedCount.Load(),
				totalMessages,
			)
		}
	}()

	// 6. Graceful shutdown handling.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	<-sigChan

	log.Println("\nShutting down...")
	
	// 7. Stop the workers and wait.
	start := time.Now()
	wm.Stop()
	close(msgChan) // make sure the channel is closed
	
	log.Printf("Final processed count: %d/%d", processedCount.Load(), totalMessages)
	log.Printf("Shutdown completed in %v", time.Since(start))
}