Go简单实现B+Tree

接上一篇文章, 使用go代码简单实现B+Tree (https://www.cnblogs.com/hezifan/p/16258914.html)

package main

import (
	"fmt"
	"encoding/json"
	"reflect"
	"strings"
	"crypto/md5"
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"io"
	"sort"
	"os"
	"strconv"
	"regexp"
)

// DataCfg is one stored record. The JSON tags are the field names used when a
// record is serialised by (*data).ToString.
type DataCfg struct {
	Id   int    `json:"id"`
	Name string `json:"name"`
	Age  int    `json:"age"`
}

// data pairs a record with the name of the field that serves as its index key.
// DataCfg is embedded, so its fields are promoted onto data.
type data struct {
	Key string
	DataCfg
}

// NewData wraps dataCfg together with key, the (case-insensitive in its first
// letter) name of the field whose value is used as the index value.
func NewData(key string, dataCfg DataCfg) *data {
	return &data{Key: key, DataCfg: dataCfg}
}

// ToString returns the record encoded as JSON. If encoding fails the error is
// printed and the empty string is returned.
func (d *data) ToString() (dataString string) {
	encoded, err := json.Marshal(d.DataCfg)
	if err != nil {
		fmt.Println(err)
		return
	}
	return string(encoded)
}

// getIndexVal returns the value of the field named by Key, looked up through
// reflection after upper-casing the first letter of Key ("id" -> "Id").
//
// It returns nil when Key is empty or names no field; previously both cases
// panicked (slicing an empty string, and calling Interface on the zero Value
// that FieldByName returns for an unknown name).
func (d *data) getIndexVal() interface{} {
	if d.Key == "" {
		return nil
	}
	fieldName := strings.ToUpper(d.Key[:1]) + d.Key[1:]
	field := reflect.ValueOf(d).Elem().FieldByName(fieldName)
	if !field.IsValid() {
		return nil
	}
	return field.Interface()
}

// ==============================================================index====================================================================

// index is one key entry inside a tree node.
type index struct {
	indexVal int     // key value
	data     DataCfg // record the key refers to; the tree only fills it in for leaf entries
	left     string  // ID of the child node to the left of this key
	right    string  // ID of the child node to the right of this key
}

// NewIndex builds an entry for indexVal with the given child node IDs and record.
func NewIndex(indexVal int, left string, right string, data DataCfg) *index {
	entry := new(index)
	entry.indexVal = indexVal
	entry.data = data
	entry.left = left
	entry.right = right
	return entry
}

// getIndexVal returns the key value of the entry.
func (idx *index) getIndexVal() int {
	return idx.indexVal
}

// getLeft returns the ID of the node to the left of the entry.
func (idx *index) getLeft() string {
	return idx.left
}

// getRight returns the ID of the node to the right of the entry.
func (idx *index) getRight() string {
	return idx.right
}

// getData returns the record stored with the entry.
func (idx *index) getData() DataCfg {
	return idx.data
}

// updateLeft replaces the ID of the node to the left of the entry.
func (idx *index) updateLeft(left string) {
	idx.left = left
}

// updateRight replaces the ID of the node to the right of the entry.
func (idx *index) updateRight(right string) {
	idx.right = right
}

// GetMd5String returns the MD5 digest of s as a 32-character lower-case hex string.
func GetMd5String(s string) string {
	digest := md5.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
 
// UniqueId returns a random identifier: 48 bytes from crypto/rand are
// base64-URL encoded and the result is passed through GetMd5String.
// It returns "" when the random source cannot be read.
// NOTE(review): the tree uses "" to mean "no node", so a failure here would
// produce a node ID that collides with that sentinel.
func UniqueId() string {
	seed := make([]byte, 48)
	_, err := io.ReadFull(rand.Reader, seed)
	if err != nil {
		return ""
	}
	encoded := base64.URLEncoding.EncodeToString(seed)
	return GetMd5String(encoded)
}

// NewIndexMap is a slice of entries that provides Len, Less and Swap so it can
// be handed to sort.Sort.
type NewIndexMap []*index

// Len reports the number of entries.
func (entries NewIndexMap) Len() int {
	return len(entries)
}

// Less orders entries by ascending key value; flipping the comparison would
// give descending order.
func (entries NewIndexMap) Less(a, b int) bool {
	return entries[b].indexVal > entries[a].indexVal
}

// Swap exchanges the entries at positions a and b.
func (entries NewIndexMap) Swap(a, b int) {
	held := entries[a]
	entries[a] = entries[b]
	entries[b] = held
}

// ================================================================btNode==================================================================

// btNode is a single node of the B+ tree.
type btNode struct {
	id       string   // node ID, the key the tree's node pool stores this node under
	parent   string   // ID of the parent node; "" for a node without a parent
	isLeaf   bool     // true for leaf nodes
	indexNum int      // number of entries currently in indexMap; compared against the tree order
	indexMap []*index // entries, kept sorted by ascending key
	next     string   // ID of the next leaf in the leaf chain (leaves only)
}

// NewBtNode creates an empty node with a fresh ID from UniqueId.
func NewBtNode(isLeaf bool, parent string) *btNode {
	node := new(btNode)
	node.id = UniqueId()
	node.isLeaf = isLeaf
	node.parent = parent
	return node
}

// getID returns the node ID.
func (node *btNode) getID() string {
	return node.id
}

// addIndex appends an entry, re-sorts the entries by ascending key and bumps
// the entry counter.
func (node *btNode) addIndex(i *index) {
	entries := append(node.indexMap, i)
	sort.Sort(NewIndexMap(entries))
	node.indexMap = entries
	node.indexNum++
}

// getIndexMap returns the node's entries (the underlying slice, not a copy).
func (node *btNode) getIndexMap() []*index {
	return node.indexMap
}

// isFull reports whether the node holds at least order entries.
func (node *btNode) isFull(order int) bool {
	return !(node.indexNum < order)
}

// deleteMap keeps only the first endIndex entries, dropping the ones that were
// moved out during a split, and resynchronises the entry counter.
func (node *btNode) deleteMap(endIndex int) {
	kept := node.indexMap[:endIndex]
	node.indexMap = kept
	node.indexNum = len(kept)
}

// updateParent sets the ID of the parent node.
func (node *btNode) updateParent(id string) {
	node.parent = id
}

// setNext sets the ID of the next sibling node.
func (node *btNode) setNext(id string) {
	node.next = id
}

// dumpNode prints every field of selectNode followed by each of its entries.
// When isExit is true the process exits with status 0 right after the entries,
// before the separator line is printed.
func dumpNode(selectNode *btNode, isExit bool) {
	node := selectNode
	fmt.Printf("id       =====> %v\n", node.id)
	fmt.Printf("parent   =====> %v\n", node.parent)
	fmt.Printf("isLeaf   =====> %v\n", node.isLeaf)
	fmt.Printf("indexNum =====> %v\n", node.indexNum)
	fmt.Printf("next     =====> %v\n", node.next)
	fmt.Println("indexMap =====> ")

	entries := node.indexMap
	for pos := 0; pos < len(entries); pos++ {
		entry := entries[pos]
		fmt.Println("    index : ")
		fmt.Printf("        indexVal =====> %v\n", entry.indexVal)
		fmt.Printf("        left     =====> %v\n", entry.left)
		fmt.Printf("        right    =====> %v\n", entry.right)
		fmt.Printf("        data     =====> %v\n", entry.data)
	}

	if isExit {
		os.Exit(0)
	}
	fmt.Println("--------------------------------------------------------------------")
}

// ==============================================================BPlusTree====================================================================

// minOrder is the smallest order the tree supports. When an internal node is
// split, its middle key moves up to the parent and only the keys after it go
// to the new sibling; with an order below 3 that sibling ends up with no keys
// and a child is orphaned, after which descending into the empty node indexes
// position -1 and panics.
const minOrder = 3

// BPlusTree is a B+ tree whose nodes live in a map keyed by node ID; nodes
// refer to each other by ID instead of by pointer.
type BPlusTree struct {
	root    string             // ID of the root node; "" while the tree is empty
	nodeMap map[string]*btNode // node pool: node ID -> node
	order   int                // a node is split once it holds this many keys
}

// NewBPlusTree creates an empty tree of the given order. Orders below
// minOrder are raised to minOrder because smaller values corrupt the tree on
// the first internal split.
func NewBPlusTree(order int) *BPlusTree {
	if order < minOrder {
		order = minOrder
	}
	return &BPlusTree{
		order:   order,
		nodeMap: make(map[string]*btNode),
	}
}

// isEmpty reports whether the tree has no root node yet.
func (bp *BPlusTree) isEmpty() bool {
	return bp.root == ""
}

// storeNode puts b into the node pool under its own ID.
func (bp *BPlusTree) storeNode(b *btNode) {
	bp.nodeMap[b.getID()] = b
}

// getNodeByID returns the node stored under nodeId, or nil when there is none
// (in particular for the empty ID).
func (bp *BPlusTree) getNodeByID(nodeId string) *btNode {
	return bp.nodeMap[nodeId]
}

// insert adds the record in dataObj to the tree, keyed by its index value.
//
// The index value must be an int. A record whose index field has another type
// (or cannot be resolved) is reported on stdout and skipped instead of
// crashing the program on an unchecked type assertion. A key that is already
// present on the search path is ignored, so existing records are never
// overwritten.
func (bp *BPlusTree) insert(dataObj *data) {
	raw := dataObj.getIndexVal()
	key, ok := raw.(int)
	if !ok {
		fmt.Printf("insert skipped: index field %q is not an int (got %T)\n", dataObj.Key, raw)
		return
	}

	if bp.isEmpty() {
		// Empty tree: the first node is a leaf and becomes the root.
		root := NewBtNode(true, "")
		root.addIndex(NewIndex(key, "", "", dataObj.DataCfg))
		bp.storeNode(root)
		bp.root = root.getID()
		return
	}

	// Walk down from the root. leaf tracks the last node that existed on the
	// way; the walk stops when the child ID to follow resolves to no node,
	// which is what the empty child IDs of leaf entries do.
	leaf := bp.getNodeByID(bp.root)
	for cur := leaf; cur != nil; {
		leaf = cur
		entries := cur.getIndexMap()

		// Skip every entry whose key is smaller than the new key.
		pos := 0
		for pos < len(entries) && key > entries[pos].getIndexVal() {
			pos++
		}

		switch {
		case pos == len(entries):
			// Larger than every key here: follow the right child of the last entry.
			cur = bp.getNodeByID(entries[pos-1].getRight())
		case key == entries[pos].getIndexVal():
			// Key already present: nothing to insert.
			return
		default:
			// First larger key found: follow its left child.
			cur = bp.getNodeByID(entries[pos].getLeft())
		}
	}

	leaf.addIndex(NewIndex(key, "", "", dataObj.DataCfg))
	if leaf.isFull(bp.order) {
		bp.split(leaf)
	}
}

// split divides a full node in two around its middle entry and registers the
// middle key in the parent, splitting the parent in turn if that fills it.
//
// For a leaf, the middle entry and everything after it are copied to the new
// sibling. For an internal node only the entries after the middle one are
// copied, and the children they point to are re-parented to the sibling.
func (bp *BPlusTree) split(node *btNode) {
	// Position of the entry whose key is promoted into the parent.
	mid := node.indexNum / 2

	// Splitting the root grows the tree by one level: a fresh internal node
	// becomes the new root and the parent of both halves.
	parentID := node.parent
	if parentID == "" {
		newRoot := NewBtNode(false, "")
		bp.storeNode(newRoot)
		parentID = newRoot.getID()
		bp.root = parentID
	}
	parent := bp.getNodeByID(parentID)

	sibling := NewBtNode(node.isLeaf, parentID)
	bp.storeNode(sibling)

	// First position that moves to the sibling: leaves keep a copy of the
	// middle entry in the sibling, internal nodes do not.
	firstMoved := mid
	if !sibling.isLeaf {
		firstMoved = mid + 1
	}

	var promoted int
	for pos, entry := range node.getIndexMap() {
		if pos == mid {
			promoted = entry.getIndexVal()
		}
		if pos < firstMoved {
			continue
		}
		sibling.addIndex(NewIndex(entry.getIndexVal(), entry.getLeft(), entry.getRight(), entry.getData()))
		if !sibling.isLeaf {
			// The children under a moved entry now belong to the sibling.
			bp.getNodeByID(entry.getLeft()).updateParent(sibling.getID())
			bp.getNodeByID(entry.getRight()).updateParent(sibling.getID())
		}
	}

	// The original node hangs under the (possibly new) parent and keeps only
	// the entries before the middle position.
	node.updateParent(parentID)
	node.deleteMap(mid)

	// Leaves form a linked list: the sibling is inserted right after node.
	if node.isLeaf {
		sibling.setNext(node.next)
		node.setNext(sibling.getID())
	}

	// Register the promoted key in the parent, pointing at both halves.
	parent.addIndex(NewIndex(promoted, node.getID(), sibling.getID(), DataCfg{}))

	// Re-link the parent's entries so that each entry's left child is the
	// right child of the entry before it.
	entries := parent.getIndexMap()
	for pos := 1; pos < len(entries); pos++ {
		entries[pos].updateLeft(entries[pos-1].getRight())
	}

	// The extra key may have filled the parent as well.
	if parent.isFull(bp.order) {
		bp.split(parent)
	}
}

// find looks up a single key. It prints the record when the key is found in a
// leaf, otherwise it prints a "not exists" message.
func (bp *BPlusTree) find(indexVal int) {
	node := bp.getNodeByID(bp.root)
	for node != nil {
		entries := node.getIndexMap()

		// Advance past every entry whose key is <= indexVal; an exact match
		// only ends the search when it happens in a leaf.
		pos := 0
		for pos < len(entries) {
			entry := entries[pos]
			if indexVal < entry.getIndexVal() {
				break
			}
			if indexVal == entry.getIndexVal() && node.isLeaf {
				fmt.Println(entry.getData())
				return
			}
			pos++
		}

		if pos < len(entries) {
			// Stopped at a larger key: continue in its left child.
			node = bp.getNodeByID(entries[pos].getLeft())
		} else {
			// Ran off the end: continue in the right child of the last entry.
			node = bp.getNodeByID(entries[pos-1].getRight())
		}
	}
	fmt.Printf("索引值 [%v] is not exists!\n", indexVal)
}

// rangeFind prints the records whose keys lie in [start, end]. It first
// descends to the leaf where start belongs, then follows the leaf chain until
// a key larger than end is met or the chain ends.
func (bp *BPlusTree) rangeFind(start, end int) {
	// Phase 1: locate the starting leaf. leaf tracks the last existing node
	// on the way down.
	leaf := bp.getNodeByID(bp.root)
	for cur := leaf; cur != nil; {
		leaf = cur
		entries := cur.getIndexMap()

		pos := 0
		for pos < len(entries) && start >= entries[pos].getIndexVal() {
			pos++
		}

		if pos < len(entries) {
			cur = bp.getNodeByID(entries[pos].getLeft())
		} else {
			cur = bp.getNodeByID(entries[pos-1].getRight())
		}
	}

	// Phase 2: walk the leaf chain and collect everything inside the range.
	var resultData []DataCfg
collect:
	for ; leaf != nil; leaf = bp.getNodeByID(leaf.next) {
		for _, entry := range leaf.getIndexMap() {
			key := entry.getIndexVal()
			if key > end {
				break collect
			}
			if key >= start {
				resultData = append(resultData, entry.getData())
			}
		}
	}
	fmt.Println(resultData)
}

// dumpNode prints the tree's root ID and order, then every node in the node
// pool with all of its entries. Nodes appear in map iteration order.
func (bp *BPlusTree) dumpNode() {
	fmt.Printf("root = %v\n", bp.root)
	fmt.Printf("order = %v\n", bp.order)
	fmt.Println("nodeMap : ")
	for nodeID, n := range bp.nodeMap {
		fmt.Printf("   id       =====> %v\n", nodeID)
		fmt.Printf("   parent   =====> %v\n", n.parent)
		fmt.Printf("   isLeaf   =====> %v\n", n.isLeaf)
		fmt.Printf("   indexNum =====> %v\n", n.indexNum)
		fmt.Printf("   next     =====> %v\n", n.next)
		fmt.Println("   indexMap =====> ")

		entries := n.indexMap
		for pos := 0; pos < len(entries); pos++ {
			entry := entries[pos]
			fmt.Println("       index : ")
			fmt.Printf("           indexVal =====> %v\n", entry.indexVal)
			fmt.Printf("           left     =====> %v\n", entry.left)
			fmt.Printf("           right    =====> %v\n", entry.right)
			fmt.Printf("           data     =====> %v\n", entry.data)
		}
		fmt.Println("--------------------------------------------------------------------")
	}
}

// bPlusTreeObj is the package-level tree that load fills and main queries.
var bPlusTreeObj *BPlusTree

// load creates an order-3 tree and inserts the sample records, indexed by
// their "id" field.
func load() {
	bPlusTreeObj = NewBPlusTree(3)

	records := []DataCfg{
		{Id: 2, Name: "hezifan", Age: 18},
		{Id: 1, Name: "wj", Age: 20},
		{Id: 3, Name: "hzf", Age: 19},
		{Id: 4, Name: "hzf", Age: 19},
		{Id: 6, Name: "hzf", Age: 19},
		{Id: 8, Name: "hzf", Age: 19},
		{Id: 7, Name: "hzf", Age: 19},
		{Id: 10, Name: "hzf", Age: 19},
		{Id: 9, Name: "hzf", Age: 19},
		{Id: 5, Name: "hzf", Age: 19},
	}

	for pos := 0; pos < len(records); pos++ {
		bPlusTreeObj.insert(NewData("id", records[pos]))
	}
}

// init builds the sample tree before main runs.
func init() {
	load()
}

// main runs an interactive single-key lookup loop: it prompts, reads a value
// from stdin, strips all whitespace, and looks the number up in the tree.
//
// The loop ends when stdin is exhausted or fails; previously the read error
// was ignored, so after EOF the loop spun forever looking up key 0. Input that
// is not an integer is rejected with a message instead of being searched as 0.
//
// TODO: offer a mode prompt (1 = single lookup, 2 = range query through
// rangeFind, e.g. bPlusTreeObj.rangeFind(6, 7)); only single lookup is wired up.
func main() {
	buf := make([]byte, 1024)
	whitespace := regexp.MustCompile("\\s+")
	for {
		fmt.Println("已选择单条查询, 请输入查询值:")
		n, err := os.Stdin.Read(buf)

		// A Read may return data together with an error, so handle the data first.
		if n > 0 {
			findMsg := whitespace.ReplaceAllString(string(buf[:n]), "")
			if key, convErr := strconv.Atoi(findMsg); convErr != nil {
				fmt.Printf("输入 [%v] 不是有效的整数, 请重新输入!\n", findMsg)
			} else {
				bPlusTreeObj.find(key)
			}
		}

		if err != nil {
			if err != io.EOF {
				fmt.Fprintln(os.Stderr, err)
			}
			return
		}
	}
}
posted @ 2022-05-17 10:06  吹_神  阅读(126)  评论(0)    收藏  举报