实现阻塞读且并发安全的map

实现阻塞读且并发安全的 map

需要实现以下接口

type sp interface {
    // 存入k-v, 此方法不会阻塞, 时刻都可以立即执行并返回.
    Put(key string, val interface{})  
    // 读取一个 k-v, 如果 key 不存在则阻塞, 等待 key 存在或者超时.
    Get(key string, timeout time.Duration) (interface{}, bool)
}

看到这个接口要求, 第一个反应就是用 channel 来进行信息的传递, 大概的思路差不多就是: 一个 Put 操作会向所有监听当前 key 的 Get 操作的 channel 中发送数据. 好了, 有了一个大概的思路, 我们开始实现这个接口.

首先定义 map 结构体和 entry 的结构体:

type Map struct {
	mux     sync.Mutex        // 保护 hashmap
	hashmap map[string]*entry // hashmap
}

type entry struct {
	mux sync.Mutex      // 用来保护 entry
	valid bool          // 结果是否有效, true 有效, false 无效
	value interface{}   // 结果
	chans []chan Result // Get 操作监听的通道
}

type Result struct {
	val  interface{} // 读取的结果
	read bool        // 读到了: true; 没读到: false
}

这里使用了读写锁来保护 Map 和 entry, 用来保证并发安全. chans 就是 Get 操作监听的通道, 这里使用了一个 valid 标志位来判断这个值是否有效, 为什么这么设计后文会讲.

我们先考虑 Put 的实现方法:

  • Put 首先如果 key 对应的 entry 不存在, 那么就需要 new 一个实例. 如果已经存在了, 需要将 value 和 valid 更新.
  • 接着将这个值发送给所有的 Get 监听通道, 发送一个关闭一个, 最后清空监听的通道.

来看看具体的实现:

func (this *Map) Put(key string, value interface{}) {
	go func() {
		// 1
		this.mux.Lock()
		en, exist := this.hashmap[key]
		if !exist {
			en = &entry{
				mux:   sync.Mutex{},
				chans: make([]chan Result, 0),
			}
			this.hashmap[key] = en
		}
		this.mux.Unlock()

		// 2
		en.mux.Lock()
		en.value = value // 赋值
		en.valid = true  // 结果有效
		// 向阻塞接收方发送消息
		for _, ch := range en.chans {
			ch <- Result{val: value, read: true} // 返回读取的信息
			close(ch)                            // 从发送方将 channel 全部关闭
		}
		en.chans = make([]chan Result, 0) // 发送完消息后清除掉所有的 channel
		en.mux.Unlock()
	}()
}

我们接着考虑 Get 的实现方法:

  • 首先 Get 方法需要查询 hashmap 中是否存在 key 对应的 entry, 如果不存在那么 new 一个实例.
  • 接着查看 entry 中的值是否有效, 如果有效那么直接返回.
  • 如果无效, 那么就将自己监听结果的 channel 注册, 然后监听两个 channel (timer 和 监听结果的) 即可.

来看看具体实现:

func (this *Map) Get(key string, timeout time.Duration) (interface{}, bool) {
	// 1
	this.mux.Lock()
	en, exist := this.hashmap[key]
	if !exist {
		en = &entry{
			mux:   sync.Mutex{},
			chans: make([]chan Result, 0),
		}
		this.hashmap[key] = en
	}
	this.mux.Unlock()

	// 2
	en.mux.Lock()
	if en.valid {
		en.mux.Unlock()
		return en.value, true
	}

	// 3
	readChan := make(chan Result, 1)      // 设置缓冲为1的 channel
	en.chans = append(en.chans, readChan) // 增加一个异步读取队列
	en.mux.Unlock()

        // 4
	timer := time.NewTimer(timeout)
	select {
	case <-timer.C:
		return nil, false
	case result := <-readChan:
		return result.val, result.read
	}
}

测试代码:

func NewSP() SP {
	return &Map{
		mux:     sync.RWMutex{},
		hashmap: make(map[string]*entry),
	}
}

func TestMap(t *testing.T) {
	testCases := []func(){
		func() {
			wg := &sync.WaitGroup{}
			sp := NewSP()
			wg.Add(1)
			go func() {
				sp.Put("key1", "value1")
				time.Sleep(3 * time.Second)
				sp.Put("key2", "value2")
				wg.Done()
			}()
			wg.Add(1)
			go func() {
				res1, ok1 := sp.Get("key1", 1*time.Second)
				if !ok1 {
					t.Fail()
				}
				fmt.Println("res1: ", res1)
				res2, ok2 := sp.Get("key2", 1*time.Second)
				if ok2 {
					t.Fail()
				}
				fmt.Println("res2: ", res2)
				res3, ok3 := sp.Get("key2", 5*time.Second)
				if !ok3 {
					t.Fail()
				}
				fmt.Println("res3: ", res3)
				res4, ok4 := sp.Get("key1", 5*time.Second)
				if !ok4 {
					t.Fail()
				}
				fmt.Println("res4: ", res4)
				wg.Done()
			}()
			wg.Wait()
		},
	}

	for _, f := range testCases {
		f()
	}
}

测试结果:

res1:  value1
res2:  <nil>
res3:  value2
res4:  value1
posted @ 2023-03-03 10:02  kohn  阅读(218)  评论(0)    收藏  举报