package main
import (
"errors"
"math/rand"
)
// CompareResult is the value a skip-list key comparator returns for
// compare(left, right).
//
// TODO: the skip list is not yet safe for concurrent use.
type CompareResult int

// Comparator results; they take the values -1, 0 and 1 respectively.
const (
	LEFT_LT_RIGHT CompareResult = iota - 1 // left < right
	LEFT_EQ_RIGHT                          // left == right
	LEFT_GT_RIGHT                          // left > right
)

// Skip-list tuning parameters (untyped constants).
const (
	// SKIP_LIST_MAX_LEVEL is the maximum number of levels; the header node
	// allocates this many next pointers.
	SKIP_LIST_MAX_LEVEL = 32
	// SKIP_LIST_INIT_LEVEL is meant to be the header's initial number of usable
	// levels. NOTE(review): nothing in this file references it; lists start at level 0.
	SKIP_LIST_INIT_LEVEL = 3
	// P is the probability of promoting a new node one more level (the value Redis uses).
	P = 0.25
)
// NilCheck validates a key (or other argument) by rejecting a nil interface
// value with the error "element is nil"; any other value yields nil.
//
// Only an interface that is itself nil is detected: a typed nil such as a nil
// func or nil pointer stored in element compares != nil and is accepted.
func NilCheck(element interface{}) error {
	if element != nil {
		return nil
	}
	return errors.New("element is nil")
}
// ISkipList is the ordered key/value map API returned by NewSkipList.
type ISkipList interface {
	// IsEmpty reports whether the list holds no entries.
	IsEmpty() bool
	// Size returns the number of entries.
	Size() int
	// Get looks up key; the SkipList implementation returns (nil, nil) when
	// key is absent.
	Get(key interface{}) (interface{}, error)
	// Set stores value under key. SkipList returns the replaced value when key
	// already existed, and value itself when a new entry was inserted.
	Set(key, value interface{}) (interface{}, error)
	// Remove deletes key; SkipList returns the removed value, or (nil, nil)
	// when key is absent.
	Remove(key interface{}) (interface{}, error)
}
// SkipList is an ordered key/value map implemented as a skip list. Keys are
// ordered solely by compareCallBack; keys comparing LEFT_EQ_RIGHT are treated
// as the same entry. SkipList is not safe for concurrent use.
type SkipList struct {
	compareCallBack func(left interface{}, right interface{}) CompareResult
	header          skipListNode // sentinel head node; its key and value are never used
	size            int          // number of stored entries
	level           int          // number of levels currently in use (0 when empty)
}

// skipListNode is a single entry of a SkipList.
type skipListNode struct {
	key   interface{}
	value interface{}
	// next[i] is the following node on level i, so len(next) is the node's
	// height. The header allocates SKIP_LIST_MAX_LEVEL slots (modeled on Redis).
	next []*skipListNode
}

// NewSkipList returns an empty skip list whose keys are ordered by
// compareCallBack, which must return LEFT_LT_RIGHT, LEFT_EQ_RIGHT or
// LEFT_GT_RIGHT consistently with a total order on the keys it will see.
//
// An error is returned when compareCallBack is nil.
func NewSkipList(compareCallBack func(left interface{}, right interface{}) CompareResult) (ISkipList, error) {
	// Compare the func directly: wrapping a nil func in interface{} (as
	// NilCheck does) produces a non-nil interface, so NilCheck never caught it.
	if compareCallBack == nil {
		return nil, errors.New("compareCallBack is nil")
	}
	sl := &SkipList{compareCallBack: compareCallBack}
	sl.header.next = make([]*skipListNode, SKIP_LIST_MAX_LEVEL)
	return sl, nil
}

// IsEmpty reports whether the list holds no entries.
func (sl *SkipList) IsEmpty() bool {
	return sl.size == 0
}

// Size returns the number of entries in the list.
func (sl *SkipList) Size() int {
	return sl.size
}

// checkKey validates the comparator and key shared by Get, Set and Remove.
func (sl *SkipList) checkKey(key interface{}) error {
	if sl.compareCallBack == nil {
		return errors.New("compareCallBack is nil")
	}
	if key == nil {
		return errors.New("element is nil")
	}
	return nil
}

// search returns the node whose key equals key, or nil if there is none.
// When update is non-nil, update[i] (for i < sl.level) receives the last node
// on level i whose key is smaller than key, i.e. the node to splice after.
func (sl *SkipList) search(key interface{}, update []*skipListNode) *skipListNode {
	if sl.level == 0 {
		return nil
	}
	// Walk by pointer; the previous version copied every visited node by value.
	node := &sl.header
	for i := sl.level - 1; i >= 0; i-- {
		for {
			next := node.next[i]
			if next == nil || sl.compareCallBack(key, next.key) != LEFT_GT_RIGHT {
				break
			}
			node = next
		}
		if update != nil {
			update[i] = node
		}
	}
	if candidate := node.next[0]; candidate != nil && sl.compareCallBack(key, candidate.key) == LEFT_EQ_RIGHT {
		return candidate
	}
	return nil
}

// Get returns the value stored under key, or (nil, nil) when key is absent.
// It returns an error if key is nil or the list has no comparator.
func (sl *SkipList) Get(key interface{}) (interface{}, error) {
	if err := sl.checkKey(key); err != nil {
		return nil, err
	}
	if node := sl.search(key, nil); node != nil {
		return node.value, nil
	}
	return nil, nil
}

// Set stores value under key. If key already exists its value is replaced and
// the previous value is returned; otherwise a new entry is inserted and value
// itself is returned. It returns an error if key is nil or the list has no
// comparator.
func (sl *SkipList) Set(key interface{}, value interface{}) (interface{}, error) {
	if err := sl.checkKey(key); err != nil {
		return nil, err
	}
	var update [SKIP_LIST_MAX_LEVEL]*skipListNode
	if node := sl.search(key, update[:]); node != nil {
		old := node.value
		node.value = value
		return old, nil
	}

	height := randomLevel()
	// Levels at or above the current top hold no nodes, so the header is the
	// predecessor there.
	for i := sl.level; i < height; i++ {
		update[i] = &sl.header
	}
	newNode := &skipListNode{key: key, value: value, next: make([]*skipListNode, height)}
	for i := 0; i < height; i++ {
		newNode.next[i] = update[i].next[i]
		update[i].next[i] = newNode
	}

	sl.size++
	if height > sl.level {
		sl.level = height
	}
	return value, nil
}

// Remove deletes key and returns the value it held, or (nil, nil) when key is
// absent. It returns an error if key is nil or the list has no comparator.
func (sl *SkipList) Remove(key interface{}) (interface{}, error) {
	if err := sl.checkKey(key); err != nil {
		return nil, err
	}
	var update [SKIP_LIST_MAX_LEVEL]*skipListNode
	target := sl.search(key, update[:])
	if target == nil {
		return nil, nil
	}

	// Unlink only on the levels the target occupies: its next slice is shorter
	// than sl.level whenever it is not one of the tallest nodes.
	for i, next := range target.next {
		if update[i].next[i] == target {
			update[i].next[i] = next
		}
	}

	sl.size--
	// Drop levels that became empty (down to 0) so searches start at the real top.
	for sl.level > 0 && sl.header.next[sl.level-1] == nil {
		sl.level--
	}
	return target.value, nil
}
// randomLevel picks the height of a new node. Starting at 1, it draws a
// uniform number in [0,1) before every step and keeps growing while the draw
// is below P, never exceeding SKIP_LIST_MAX_LEVEL.
func randomLevel() int {
	height := 1
	for {
		// Draw first, then test the cap: this keeps the RNG consumption
		// identical to the short-circuit `draw < P && height < max` form.
		coin := rand.Float64()
		if coin >= P || height >= SKIP_LIST_MAX_LEVEL {
			return height
		}
		height++
	}
}
// intCallBack is a comparator for int keys. It panics (type assertion) if
// either argument is not an int; left is asserted before right.
func intCallBack(left interface{}, right interface{}) CompareResult {
	l, r := left.(int), right.(int)
	switch {
	case l < r:
		return LEFT_LT_RIGHT
	case l == r:
		return LEFT_EQ_RIGHT
	default:
		return LEFT_GT_RIGHT
	}
}
// main builds a demo skip list keyed by ints and inserts a fixed set of
// key/value pairs in order.
func main() {
	// Errors are deliberately ignored: intCallBack is non-nil and every key
	// below is a non-nil int, so the nil checks in NewSkipList/Set cannot fail.
	list, _ := NewSkipList(intCallBack)
	entries := []struct{ key, value int }{
		{1, 12}, {2, 23}, {4, 34}, {5, 45}, {7, 45}, {9, 45},
		{-1, 45}, {100, 45}, {34, 45}, {101, 45}, {200, 45},
	}
	for _, e := range entries {
		list.Set(e.key, e.value)
	}
}