【Go-多线程】Golang的channel实现消息任务的批量处理

【Go-多线程】使用 Golang 的 channel 实现消息的批量处理。当消息量特别大时，首选 Kafka 之类的消息队列（message queue）；本文介绍的是一种更轻量的进程内方案。

channelx.go

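// 这个方案需要实现以下几点：
// 1. 消息聚合后批量处理（每批最多 BatchSize 条），核心：
//    (1) 带缓冲的 channel 相当于一个 FIFO 队列；
//    (2) 多个常驻的 worker goroutine 并发消费，提高吞吐；
//    (3) goroutine 之间是并行的，但每个 goroutine 内部是串行的，且每个 goroutine 持有自己的 batch，所以操作 batch 不需要加锁。
// 2. 延迟处理：一批消息没有攒满时，最多等待 LingerTime（从该批第一条消息开始计时）后也会被处理。
//    注意：这里使用 time.Timer 而不是 time.After。在 for-select 循环里每次调用 time.After 都会新建一个定时器，
//    在 Go 1.23 之前，这些定时器在到期前不会被 GC 回收，LingerTime 较长时（默认 1 分钟）会造成内存堆积；
//    而且每轮循环都会重新计时，无法表达“从该批第一条消息开始计时”的语义。
// 3. 自定义错误处理（ErrorHandler）。
// 4. 并发处理（Workers 个 worker goroutine）。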

package channelx

import (
	"runtime"
	"sync"
	"time"
)

type (
	// Aggregator groups items submitted through Enqueue/TryEnqueue into
	// batches and passes each batch to a BatchProcessFunc, using a fixed set
	// of worker goroutines (see Start, Stop and SafeStop). Create it with
	// NewAggregator: the zero value has no queue, quit channel or WaitGroup
	// and is not usable.
	Aggregator struct {
		batchProcessor BatchProcessFunc // user callback that receives each batch
		eventQueue     chan interface{} // buffered queue between producers and workers
		quit           chan struct{}    // closed by Stop/SafeStop to tell workers to exit
		wg             *sync.WaitGroup  // lets Stop/SafeStop wait for the workers
		option         AggregatorOption // effective configuration built by NewAggregator
	}

	// AggregatorOption configures an Aggregator. NewAggregator fills in the
	// defaults noted below before applying any SetAggregatorOptionFunc.
	AggregatorOption struct {
		// BatchSize is the maximum number of items passed to a single
		// BatchProcessFunc call; a batch is processed as soon as it reaches
		// this size. NewAggregator defaults it to 8.
		BatchSize int

		// Workers is the number of worker goroutines launched by Start.
		// NewAggregator defaults it to runtime.NumCPU().
		Workers int

		// ChannelBufferSize is the capacity of the internal queue.
		// NewAggregator raises it to Workers when it is not larger than Workers.
		ChannelBufferSize int

		// LingerTime is how long an incomplete batch may wait, counted from
		// its first item, before it is processed anyway. NewAggregator
		// defaults it to 1 minute.
		LingerTime time.Duration

		// ErrorHandler, if set, is started in a new goroutine with the failed
		// items whenever the BatchProcessFunc returns an error. If nil, the
		// failed batch is only logged (when Logger is set) and then dropped.
		ErrorHandler ErrorHandlerFunc

		// Logger receives diagnostic messages; nil disables logging.
		Logger Logger
	}

	// BatchProcessFunc processes one batch of queued items. A non-nil error
	// marks the whole batch as failed.
	BatchProcessFunc func([]interface{}) error

	// SetAggregatorOptionFunc receives the current options and returns the
	// modified copy; NewAggregator applies these functions in order.
	SetAggregatorOptionFunc func(option AggregatorOption) AggregatorOption

	// ErrorHandlerFunc is called for a failed batch with the error, the failed
	// items, the aggregator's BatchProcessFunc and the aggregator itself
	// (presumably so a handler can retry or re-enqueue; confirm with callers).
	ErrorHandlerFunc func(err error, items []interface{}, batchProcessFunc BatchProcessFunc, aggregator *Aggregator)
)

// NewAggregator creates an Aggregator that hands queued items to
// batchProcessor in batches. It starts from the defaults BatchSize 8,
// Workers runtime.NumCPU() and LingerTime 1 minute, then applies optionFuncs
// in order. After that the options are normalised:
//   - BatchSize below 1 becomes 1, since a batch needs room for at least one
//     item;
//   - Workers below 1 becomes 1, otherwise Start would launch no worker and
//     queued items would never be processed;
//   - ChannelBufferSize is raised to Workers when it is not larger.
//
// No goroutines run until Start is called.
func NewAggregator(batchProcessor BatchProcessFunc, optionFuncs ...SetAggregatorOptionFunc) *Aggregator {
	option := AggregatorOption{
		BatchSize:  8,
		Workers:    runtime.NumCPU(),
		LingerTime: 1 * time.Minute,
	}

	for _, optionFunc := range optionFuncs {
		option = optionFunc(option)
	}

	if option.BatchSize < 1 {
		option.BatchSize = 1
	}
	if option.Workers < 1 {
		option.Workers = 1
	}
	if option.ChannelBufferSize <= option.Workers {
		option.ChannelBufferSize = option.Workers
	}

	return &Aggregator{
		eventQueue:     make(chan interface{}, option.ChannelBufferSize),
		option:         option,
		quit:           make(chan struct{}),
		wg:             new(sync.WaitGroup),
		batchProcessor: batchProcessor,
	}
}

// TryEnqueue attempts to add item to the queue without blocking. When the
// buffer is full it yields the processor once (runtime.Gosched) to give the
// workers a chance to drain it, then makes a second attempt. It reports
// whether the item was accepted; a rejected item is dropped. With a Logger
// configured, both the first miss and the final rejection are logged as
// warnings.
func (agt *Aggregator) TryEnqueue(item interface{}) bool {
	offer := func() bool {
		select {
		case agt.eventQueue <- item:
			return true
		default:
			return false
		}
	}

	if offer() {
		return true
	}
	if agt.option.Logger != nil {
		agt.option.Logger.Warnc("Aggregator", nil, "Event queue is full and try reschedule")
	}

	runtime.Gosched()

	if offer() {
		return true
	}
	if agt.option.Logger != nil {
		agt.option.Logger.Warnc("Aggregator", nil, "Event queue is still full and %+v is skipped.", item)
	}
	return false
}

// Enqueue adds item to the queue, blocking while the buffer is full until a
// worker makes room. If no worker is consuming (Start not called, or after
// Stop), it blocks forever once the buffer is full. Use TryEnqueue for a
// non-blocking variant.
func (agt *Aggregator) Enqueue(item interface{}) {
	queue := agt.eventQueue
	queue <- item
}

// Start launches option.Workers worker goroutines that consume the queue.
// Each worker is registered with the WaitGroup here, before its goroutine is
// spawned: registering inside the goroutine would race with wg.Wait, so a
// Stop issued right after Start could return before the workers had even
// started. Calling Start again launches an additional set of workers.
func (agt *Aggregator) Start() {
	for i := 0; i < agt.option.Workers; i++ {
		agt.wg.Add(1)
		go agt.work(i)
	}
}

// Stop closes the quit channel and blocks until every worker has exited.
// Each worker flushes the batch it is currently holding before exiting, but
// items still waiting in the queue may be left unprocessed. Stop must be
// called at most once and not in addition to SafeStop: closing the quit
// channel a second time panics.
func (agt *Aggregator) Stop() {
	close(agt.quit)
	agt.wg.Wait()
}

// SafeStop is like Stop, but first polls (every 50ms) until the queue is
// empty, so every item enqueued before the call has been received by a
// worker, and is therefore flushed, before quit is closed. Producers must
// have stopped enqueuing: items added concurrently can still be left behind.
// If no worker is running while the queue is non-empty, SafeStop never
// returns.
func (agt *Aggregator) SafeStop() {
	if len(agt.eventQueue) > 0 {
		ticker := time.NewTicker(50 * time.Millisecond)
		for len(agt.eventQueue) > 0 {
			<-ticker.C
		}
		ticker.Stop()
	}
	close(agt.quit)
	agt.wg.Wait()
}

// work is the body of one worker goroutine. It runs consume until consume
// reports a normal quit, restarting it in a loop (rather than recursively)
// after a recovered panic. The WaitGroup slot taken by Start is released
// exactly once, when the worker really exits. index identifies the worker
// and is currently unused.
func (agt *Aggregator) work(index int) {
	defer agt.wg.Done()
	for {
		if agt.consume() {
			return
		}
		// consume recovered from a panic; start over with a fresh batch.
	}
}

// consume collects queued items into a batch and hands it to batchProcess
// when it reaches BatchSize items, when LingerTime has elapsed since its
// first item, or when quit is closed. It returns true after the final flush
// on quit, and false if it recovered from a panic (for example one raised by
// the batch processor); the batch in progress at that point is dropped.
func (agt *Aggregator) consume() (stopped bool) {
	defer func() {
		if r := recover(); r != nil {
			if agt.option.Logger != nil {
				agt.option.Logger.Errorc("Aggregator", nil, "recover worker as bad thing happens %+v", r)
			}
			stopped = false
		}
	}()

	batchSize := agt.option.BatchSize
	if batchSize < 1 {
		// A non-empty batch must always be able to reach the size limit,
		// otherwise nothing but the linger timer would ever flush it.
		batchSize = 1
	}
	batch := make([]interface{}, 0, batchSize)

	// lingerC is nil (and so never selected) while no partial batch is
	// pending. A fresh timer is created for every batch instead of
	// Stop/drain/Reset on one shared timer: a stale expiry from an earlier
	// batch can then never be delivered, on any Go version, and there is no
	// drain that could block forever.
	var linger *time.Timer
	var lingerC <-chan time.Time
	disarm := func() {
		if linger != nil {
			linger.Stop()
			linger, lingerC = nil, nil
		}
	}
	defer disarm()

	flush := func() {
		disarm()
		if len(batch) > 0 {
			agt.batchProcess(batch)
			// Allocate a new slice rather than reusing this one: batchProcess
			// may pass the old one to an ErrorHandler goroutine.
			batch = make([]interface{}, 0, batchSize)
		}
	}

	for {
		select {
		case item := <-agt.eventQueue:
			batch = append(batch, item)
			switch {
			case len(batch) >= batchSize:
				flush()
			case linger == nil:
				// First item of a new batch: start its linger countdown.
				linger = time.NewTimer(agt.option.LingerTime)
				lingerC = linger.C
			}
		case <-lingerC:
			linger, lingerC = nil, nil
			flush()
		case <-agt.quit:
			flush()
			return true
		}
	}
}

// batchProcess hands one batch to the user's BatchProcessFunc and reports the
// outcome. On success it logs the item count (when a Logger is set). On
// failure it logs the error and then, if an ErrorHandler is configured, runs
// it in a new goroutine with the failed items; without a handler the batch is
// logged as skipped and dropped.
//
// The ErrorHandler goroutine is registered with agt.wg, so Stop and SafeStop
// wait for in-flight error handling instead of returning while a handler is
// still running. As a consequence an ErrorHandler must not call Stop/SafeStop
// itself, nor block on Enqueue once the aggregator is stopping, or the stop
// never completes.
func (agt *Aggregator) batchProcess(items []interface{}) {
	err := agt.batchProcessor(items)
	if err == nil {
		if agt.option.Logger != nil {
			agt.option.Logger.Infoc("Aggregator", "%d items have been sent.", len(items))
		}
		return
	}

	if agt.option.Logger != nil {
		agt.option.Logger.Errorc("Aggregator", err, "error happens")
	}

	if agt.option.ErrorHandler == nil {
		if agt.option.Logger != nil {
			agt.option.Logger.Errorc("Aggregator", err, "error happens in batchProcess and is skipped")
		}
		return
	}

	// Registered while the calling worker still holds its own WaitGroup slot,
	// so the counter is non-zero and this Add cannot race with wg.Wait.
	agt.wg.Add(1)
	go func() {
		defer agt.wg.Done()
		agt.option.ErrorHandler(err, items, agt.batchProcessor, agt)
	}()
}

测试

//aggregator库的使用示例
package channelx

import (
	"errors"
	"sync"
	"testing"
	"time"
)

// TestAggregator_Basic pushes 1000 items through an aggregator with default
// options using the non-blocking TryEnqueue, then checks that every item
// TryEnqueue accepted reached the batch processor once SafeStop returned.
func TestAggregator_Basic(t *testing.T) {
	var mu sync.Mutex
	processed := 0

	batchProcess := func(items []interface{}) error {
		mu.Lock()
		processed += len(items)
		mu.Unlock()
		t.Logf("handler %d items", len(items))
		return nil
	}

	aggregator := NewAggregator(batchProcess)

	aggregator.Start()

	accepted := 0
	for i := 0; i < 1000; i++ {
		if aggregator.TryEnqueue(i) {
			accepted++
		}
	}

	aggregator.SafeStop()

	mu.Lock()
	defer mu.Unlock()
	if processed != accepted {
		t.Errorf("processed %d items, want %d (the number TryEnqueue accepted)", processed, accepted)
	}
}

// TestAggregator_Complex runs 100 items through 2 workers with BatchSize 4
// and a short LingerTime. Batches that are not exactly 4 items long are
// reported as errors and must reach the ErrorHandler.
func TestAggregator_Complex(t *testing.T) {
	wg := &sync.WaitGroup{}
	wg.Add(100)

	// handlers counts ErrorHandler invocations still running. The aggregator
	// starts each handler in its own goroutine, so the test waits for them
	// explicitly: calling t.Logf after the test has returned panics.
	handlers := &sync.WaitGroup{}

	batchProcess := func(items []interface{}) error {
		defer wg.Add(-len(items))
		time.Sleep(20 * time.Millisecond)
		if len(items) != 4 {
			// Registered before the error is returned (so before the handler
			// goroutine exists) and before the deferred wg release above.
			handlers.Add(1)
			return errors.New("len(items) != 4")
		}
		return nil
	}

	errorHandler := func(err error, items []interface{}, batchProcessFunc BatchProcessFunc, aggregator *Aggregator) {
		defer handlers.Done()
		if err == nil {
			// FailNow may only be called from the test goroutine; this runs
			// in a goroutine started by the aggregator, so use Error.
			t.Error("error handler called with a nil error")
		}
		t.Logf("Receive error, item size is %d", len(items))
	}

	aggregator := NewAggregator(batchProcess, func(option AggregatorOption) AggregatorOption {
		option.BatchSize = 4
		option.Workers = 2
		option.LingerTime = 8 * time.Millisecond
		option.Logger = NewConsoleLogger()
		option.ErrorHandler = errorHandler
		return option
	})

	aggregator.Start()

	for i := 0; i < 100; i++ {
		for !aggregator.TryEnqueue(i) {
			time.Sleep(10 * time.Millisecond)
		}
	}

	aggregator.SafeStop()
	wg.Wait()
	handlers.Wait()
}

// TestAggregator_LingerTimeOut feeds one item every 100ms into a single
// worker with BatchSize 4 and a LingerTime of 150ms. A batch can never fill
// up before its linger timer fires, so partial batches must be flushed by
// the timer.
func TestAggregator_LingerTimeOut(t *testing.T) {
	// 20 items keep the runtime around 2s while still producing many
	// linger-driven flushes.
	const total = 20

	wg := &sync.WaitGroup{}
	wg.Add(total)

	var mu sync.Mutex
	partial := 0

	batchProcess := func(items []interface{}) error {
		defer wg.Add(-len(items))
		if len(items) != 4 {
			mu.Lock()
			partial++
			mu.Unlock()
			t.Log("linger time out")
		}
		return nil
	}

	aggregator := NewAggregator(batchProcess, func(option AggregatorOption) AggregatorOption {
		option.BatchSize = 4
		option.Workers = 1
		option.LingerTime = 150 * time.Millisecond
		option.Logger = NewConsoleLogger()
		return option
	})

	aggregator.Start()

	for i := 0; i < total; i++ {
		// Enqueue blocks instead of dropping, so wg.Wait below cannot hang
		// on an item that TryEnqueue rejected.
		aggregator.Enqueue(i)
		time.Sleep(100 * time.Millisecond)
	}

	aggregator.SafeStop()
	wg.Wait()

	mu.Lock()
	defer mu.Unlock()
	if partial == 0 {
		t.Error("expected at least one partial batch flushed by LingerTime")
	}
}
posted @ 2024-06-05 17:42  朝阳1  阅读(112)  评论(0)    收藏  举报