在关系型数据库中存储树形结构之闭包表(Spring Boot 3 + Kotlin)

依赖

  • PostgreSQL 15
  • Spring Boot 3.0.4
  • Spring Data JPA

注：为图省事，ID 字段直接使用字符串形式的 UUID（以 varchar 存储），数据量较大时可能带来索引与存储方面的性能问题。

实体类

节点

import jakarta.persistence.*
import org.hibernate.annotations.JdbcTypeCode
import org.hibernate.type.SqlTypes
import java.util.Objects

/**
 * A tree node. Position in the tree is stored separately, in the closure table (`TreePath`).
 *
 * Equality and hash code depend only on [nodeId], so two instances with the same ID compare
 * equal even if their name or metadata differ.
 */
@Entity
@Table(name = "t_node")
data class Node(
    /** Node ID; primary key. */
    @Id
    @Column(name = "node_id", length = 100)
    var nodeId: String = "",

    /** Node name. */
    @Column(name = "node_name", nullable = false, length = 100)
    var nodeName: String = "",

    /** Free-form node attributes, persisted as a `jsonb` column. */
    @JdbcTypeCode(SqlTypes.JSON)
    @Column(name = "metadata", nullable = false, columnDefinition = "jsonb")
    var metadata: Map<String, Any> = HashMap(),
) {
    /**
     * Two nodes are equal when they have the same [nodeId].
     *
     * Returns `false` (rather than throwing) for `null` or objects that are not a [Node].
     */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        // `is` instead of an unchecked cast: `other as Node` threw for null / foreign types.
        return other is Node && nodeId == other.nodeId
    }

    /** Consistent with [equals]: derived from [nodeId] only. */
    override fun hashCode(): Int = Objects.hash(nodeId)
}

路径

import jakarta.persistence.*

/**
 * One row of the closure table: node [anc] is an ancestor of node [des] at depth [distance].
 * The row linking a node to itself has distance 0.
 *
 * The primary key is the surrogate [id]. The `(anc, des)` pair is only enforced unique
 * through a table-level constraint.
 */
@Entity
@Table(
    name = "t_tree_path",
    uniqueConstraints = [UniqueConstraint(columnNames = ["anc", "des"])],
)
data class TreePath(
    /** Surrogate primary key. */
    @Id
    var id: String = "",

    /** ID of the ancestor node. */
    @Column(name = "anc", length = 100)
    var anc: String = "",

    /** ID of the descendant node. */
    @Column(name = "des", length = 100)
    var des: String = "",

    /** Number of edges between [anc] and [des]. */
    @Column(name = "distance", nullable = false)
    var distance: Long = 0,
)

import com.example.*
import java.util.*


/**
 * In-memory tree built from the closure table: a [node] plus its direct child subtrees.
 */
data class Tree(
    /** The node at the root of this subtree. */
    var node: Node,
    /** Direct child subtrees, in the order they were added. */
    var children: MutableList<Tree> = ArrayList(),
) {
    /**
     * Renders the tree for printing/debugging: one `nodeId=...` line per node, indented two spaces
     * per level. Nodes are listed in pre-order, with siblings in the same order as [children].
     *
     * Uses an explicit stack instead of recursion, so deep trees cannot overflow the call stack.
     */
    override fun toString(): String {
        val indentStep = 2
        val stack = ArrayDeque<Pair<Tree, Int>>()
        stack.addLast(this to 0)
        return buildString {
            while (stack.isNotEmpty()) {
                val (tree, indent) = stack.removeLast()
                append(" ".repeat(indent)).append("nodeId=").append(tree.node.nodeId).append('\n')
                // Push children in reverse so the first child is popped (printed) first;
                // pushing in list order printed siblings backwards.
                for (child in tree.children.asReversed()) {
                    stack.addLast(child to (indent + indentStep))
                }
            }
        }
    }
}

数据访问对象

NodeDao

import org.springframework.data.jpa.repository.JpaRepository
import org.springframework.stereotype.Repository
import com.example.Node


/** Spring Data JPA repository for [Node] entities, keyed by their string `nodeId`. */
@Repository
interface NodeDao : JpaRepository<Node, String>

TreePathDao

import org.springframework.data.jpa.repository.JpaRepository
import org.springframework.data.jpa.repository.Modifying
import org.springframework.data.jpa.repository.Query
import org.springframework.data.repository.query.Param
import org.springframework.stereotype.Repository
import java.util.UUID
import java.util.Optional
import com.example.*



/**
 * Repository for closure-table rows ([TreePath]).
 *
 * Each row (anc, des, distance) states that `anc` is an ancestor of `des` at depth `distance`.
 * A node's row to itself has distance 0. None of the delete/insert queries below touch `t_node`.
 */
@Repository
interface TreePathDao : JpaRepository<TreePath, String> {


    /** Derived query: the path row from [anc] to [des], if one exists. */
    fun findByAncAndDes(anc: String,des:String): Optional<TreePath>

    /**
     * Deletes a subtree's paths: every row whose descendant lies in the subtree rooted at [nodeId]
     * (including [nodeId] itself). This covers the subtree's internal paths and the paths from
     * outside ancestors into it.
     *
     * @return number of deleted rows
     */
    @Modifying
    @Query(
        """DELETE 
            FROM TreePath t 
            WHERE t.des IN (SELECT t2.des FROM TreePath t2 WHERE t2.anc=:nodeId)
        """, nativeQuery = false
    )
    fun deleteTree(@Param("nodeId") nodeId: String): Int


    /**
     * Detaches the subtree rooted at [nodeId] from its current ancestors. Deletes rows whose
     * descendant is in the subtree and whose ancestor is a proper ancestor of [nodeId].
     * The subtree's internal paths (including [nodeId]'s own rows) are kept.
     *
     * @return number of deleted rows
     */
    @Modifying
    @Query(
        """DELETE FROM  
                TreePath t 
                WHERE 
                EXISTS (SELECT 1 FROM TreePath d WHERE d.anc=:nodeId and t.des=d.des) 
                    AND 
                EXISTS   (SELECT 1 FROM TreePath a WHERE a.des=:nodeId AND a.anc!=a.des and t.anc=a.anc)
        """, nativeQuery = false
    )
    fun detach(@Param("nodeId") nodeId: String): Int


    /**
     * Grafts the subtree rooted at [srcNodeId] under [desNodeId]. This is native PostgreSQL SQL.
     *
     * For every ancestor-or-self row `super` of [desNodeId] and every descendant-or-self row
     * `sub` of [srcNodeId], it inserts the path `super.anc -> sub.des` with distance
     * `super.distance + sub.distance + 1`. New row ids come from `gen_random_uuid()`; the `::`
     * cast is written `\:\:` so it is not parsed as a named parameter.
     * Intended to run after [detach].
     *
     * @return number of inserted rows
     */
    @Modifying
    @Query(
        """INSERT INTO  t_tree_path(id,anc,des,distance)
                SELECT 
                    gen_random_uuid()\:\:text id,
                    super.anc, 
                    sub.des, 
                    super.distance + sub.distance +1 distance
                FROM 
                    t_tree_path super 
                    CROSS JOIN 
                    t_tree_path sub 
                WHERE 
                    super.des=:desNodeId AND sub.anc=:srcNodeId
        """, nativeQuery = true
    )
    fun graft(@Param("srcNodeId") srcNodeId: String, @Param("desNodeId") desNodeId: String): Int

    /** The direct parent of [nodeId] (the path at distance 1); empty for a node without a parent. */
    @Query(
        """SELECT n FROM TreePath t 
            INNER JOIN Node n 
            ON t.anc=n.nodeId
            WHERE t.des=:nodeId AND t.distance=1
        """, nativeQuery = false
    )
    fun parent(@Param("nodeId") nodeId: String): Optional<Node>

    /** All proper ancestors of [nodeId] (excluding itself). No ORDER BY, so order is unspecified. */
    @Query(
        """SELECT n FROM TreePath t 
            INNER JOIN Node n 
            ON t.anc=n.nodeId 
            WHERE t.des=:nodeId AND t.anc!=:nodeId
        """, nativeQuery = false
    )
    fun ancestors(@Param("nodeId") nodeId: String): List<Node>

    /**
     * Path rows from every proper ancestor to [nodeId].
     * `distance>0` is redundant with `anc != nodeId` for well-formed data.
     */
    @Query(
        """SELECT t FROM TreePath t 
            WHERE t.des=:nodeId AND t.anc!=:nodeId and t.distance>0
        """, nativeQuery = false
    )
    fun ancestorsPath(@Param("nodeId") nodeId: String): List<TreePath>

    /** All proper descendants of [nodeId] (the whole subtree, excluding itself). */
    @Query(
        """SELECT n FROM TreePath t 
            INNER JOIN Node n 
            ON t.des=n.nodeId 
            WHERE t.anc=:nodeId AND t.des!=:nodeId
        """, nativeQuery = false
    )
    fun descendants(@Param("nodeId") nodeId: String): List<Node>

    /** Direct children of [nodeId] (paths at distance 1). */
    @Query(
        """SELECT n FROM TreePath t 
            INNER JOIN Node n 
            ON t.des=n.nodeId 
            WHERE t.anc=:nodeId AND t.distance=1
        """, nativeQuery = false
    )
    fun children(@Param("nodeId") nodeId: String): List<Node>

}

初始化树根


import org.springframework.boot.context.event.ApplicationReadyEvent
import org.springframework.context.event.EventListener
import org.springframework.stereotype.Component
import java.util.UUID
import jakarta.transaction.Transactional
import com.example.*


/**
 * Startup hook that seeds the tree with its root node when the node table is empty.
 */
@Component
class EventHook(
    val nodeDao: NodeDao,
    val treePathDao: TreePathDao
) {

    /**
     * Runs on [ApplicationReadyEvent]. If no node exists yet, it saves the root node (name "root",
     * empty metadata) and the root's self path (distance 0). If any node already exists, it
     * writes nothing.
     */
    @EventListener(ApplicationReadyEvent::class)
    @Transactional
    fun initTree() {
        if (nodeDao.count() != 0L) {
            return
        }
        val rootId = "00000000-00000000-00000000-00000000"
        val rootNode = Node(nodeId = rootId, nodeName = "root", metadata = HashMap<String, Any>())
        nodeDao.save(rootNode)
        treePathDao.save(
            TreePath(
                id = UUID.randomUUID().toString(),
                anc = rootNode.nodeId,
                des = rootNode.nodeId,
                distance = 0,
            )
        )
    }
}

异常类

/** Thrown when a node is created with an ID that is already in use. */
class NodeExistsException(nodeId: String) : RuntimeException("Node(id=$nodeId) already exists")

/** Thrown when an operation refers to a node ID that does not exist. */
class NodeNotExistsException(nodeId: String) : RuntimeException("Node(id=$nodeId) not exists")

服务类

TreeService

import com.example.*


/**
 * Operations on the node tree, which is stored as a closure table.
 *
 * `Optional` is fully qualified because this file's snippet only imports `com.example.*`.
 */
interface TreeService {

    /** Looks up the path record linking ancestor [ancNodeId] to descendant [desNodeId]. */
    fun treePath(ancNodeId: String, desNodeId: String): java.util.Optional<TreePath>

    /** Creates [node] as a child of the node identified by [parentId]. */
    fun createNode(node: Node, parentId: String): java.util.Optional<Node>

    /** Deletes node [nodeId] together with its subtree; returns an affected-row count. */
    fun deleteNode(nodeId: String): Int

    /** Updates an existing node. */
    fun update(node: Node): Node

    /** Builds the in-memory [Tree] rooted at [nodeId]. */
    fun buildTree(nodeId: String): Tree

    /** Moves the subtree rooted at [nodeId] under [newParentId]; returns an affected-row count. */
    fun move(nodeId: String, newParentId: String): Int

    /** Returns the direct children of [nodeId]. */
    fun children(nodeId: String): List<Node>

    /** Returns all descendants of [nodeId]. */
    fun descendants(nodeId: String): List<Node>

    /** Returns the root node of the tree. */
    fun rootNode(): Node

}

TreeServiceImpl

import org.springframework.stereotype.Service
import org.springframework.transaction.annotation.Isolation
import org.springframework.transaction.annotation.Transactional
import java.util.UUID
import java.util.*
import com.example.*


/**
 * Closure-table implementation of [TreeService].
 *
 * Each node has a self row (distance 0) in `t_tree_path` plus one row per proper ancestor.
 * Ancestor and descendant lookups are therefore single joins rather than recursive walks.
 */
@Service
class TreeServiceImpl(
    var nodeDao: NodeDao,
    var treePathDao: TreePathDao
) : TreeService {

    /** Returns the path row from [ancNodeId] to [desNodeId], if any. */
    override fun treePath(ancNodeId: String, desNodeId: String): Optional<TreePath> =
        treePathDao.findByAncAndDes(anc = ancNodeId, des = desNodeId)

    /**
     * Saves [node] under [parentId] together with its closure rows:
     * - one row per proper ancestor of the parent (distance + 1),
     * - the parent row (distance 1),
     * - the self row (distance 0).
     *
     * @return the saved node (always present)
     * @throws NodeExistsException if a node with `node.nodeId` already exists
     * @throws NodeNotExistsException if [parentId] does not exist
     */
    @Transactional
    override fun createNode(node: Node, parentId: String): Optional<Node> {
        if (nodeDao.existsById(node.nodeId)) {
            throw NodeExistsException(node.nodeId)
        }
        // Without this check the parent row below would point at a missing node, leaving
        // an orphan that a tree built from the root can never reach.
        if (!nodeDao.existsById(parentId)) {
            throw NodeNotExistsException(parentId)
        }
        val ancestorPaths = treePathDao.ancestorsPath(parentId).map { ancPath ->
            TreePath(id = newPathId(), anc = ancPath.anc, des = node.nodeId, distance = ancPath.distance + 1)
        }
        val ownPaths = listOf(
            TreePath(id = newPathId(), anc = node.nodeId, des = node.nodeId, distance = 0),
            TreePath(id = newPathId(), anc = parentId, des = node.nodeId, distance = 1),
        )
        treePathDao.saveAll(ancestorPaths + ownPaths)
        return Optional.of(nodeDao.save(node))
    }

    /**
     * Deletes [nodeId], every node in its subtree, and all paths touching that subtree.
     *
     * @return deleted path rows + deleted nodes (descendants plus the node itself)
     * @throws NodeNotExistsException if [nodeId] does not exist
     */
    @Transactional
    override fun deleteNode(nodeId: String): Int {
        val node = nodeDao.findById(nodeId).orElseThrow { NodeNotExistsException(nodeId) }
        // Collect the subtree first: once its paths are deleted it can no longer be queried.
        val descendants = treePathDao.descendants(nodeId)
        val deletedPaths = treePathDao.deleteTree(nodeId)
        nodeDao.deleteAll(descendants)
        nodeDao.delete(node)
        return deletedPaths + descendants.size + 1
    }

    /**
     * Saves the new state of an existing node.
     *
     * @throws NodeNotExistsException if no node with `node.nodeId` exists
     */
    @Transactional
    override fun update(node: Node): Node {
        if (!nodeDao.existsById(node.nodeId)) {
            throw NodeNotExistsException(node.nodeId)
        }
        return nodeDao.save(node)
    }

    /**
     * Builds the tree rooted at [nodeId] breadth-first.
     *
     * Issues one children query per node (N+1), which is fine for small trees.
     *
     * @throws NodeNotExistsException if [nodeId] does not exist
     */
    override fun buildTree(nodeId: String): Tree {
        val root = nodeDao.findById(nodeId).orElseThrow { NodeNotExistsException(nodeId) }
        val rootTree = Tree(node = root)
        val queue: Queue<Tree> = LinkedList()
        queue.offer(rootTree)
        while (queue.isNotEmpty()) {
            val tree = queue.poll()
            val subTrees = treePathDao.children(tree.node.nodeId).map { Tree(node = it) }
            tree.children.addAll(subTrees)
            queue.addAll(subTrees)
        }
        return rootTree
    }

    /**
     * Moves the subtree rooted at [nodeId] under [newParentId].
     *
     * @return rows deleted by detach + rows inserted by graft (0 when the node is moved onto itself)
     * @throws NodeNotExistsException if either node does not exist
     * @throws IllegalArgumentException if [newParentId] lies inside the subtree (would create a cycle)
     */
    @Transactional(readOnly = false, isolation = Isolation.SERIALIZABLE)
    override fun move(nodeId: String, newParentId: String): Int {
        if (nodeId == newParentId) {
            return 0
        }
        if (!nodeDao.existsById(nodeId)) {
            throw NodeNotExistsException(nodeId)
        }
        if (!nodeDao.existsById(newParentId)) {
            throw NodeNotExistsException(newParentId)
        }
        // newParent is a proper descendant iff a path nodeId -> newParentId exists (ids differ
        // here), so one lookup replaces loading the whole subtree.
        if (treePathDao.findByAncAndDes(anc = nodeId, des = newParentId).isPresent) {
            throw IllegalArgumentException("newParent cannot be sub node")
        }
        val detached = treePathDao.detach(nodeId)
        val grafted = treePathDao.graft(nodeId, newParentId)
        return detached + grafted
    }

    /** Direct children of [nodeId]. */
    override fun children(nodeId: String): List<Node> = treePathDao.children(nodeId)

    /** All proper descendants of [nodeId]. */
    override fun descendants(nodeId: String): List<Node> = treePathDao.descendants(nodeId = nodeId)

    /**
     * The root node created at startup.
     *
     * @throws NodeNotExistsException if the root has not been initialized
     */
    override fun rootNode(): Node =
        nodeDao.findById(ROOT_NODE_ID).orElseThrow { NodeNotExistsException(ROOT_NODE_ID) }

    /** New surrogate id for a path row, in the same UUID-string format the startup hook uses. */
    private fun newPathId(): String = UUID.randomUUID().toString()

    private companion object {
        /** Fixed id of the tree root. */
        const val ROOT_NODE_ID = "00000000-00000000-00000000-00000000"
    }

}

测试类

import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest
import java.util.*
import kotlin.collections.HashMap
import com.example.*


/**
 * Integration test: builds a small tree under the root node and checks its shape.
 *
 * Requires the configured database and an existing root node.
 */
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.NONE)
// Roll back after the test so the fixed fixture ids ("a", "b", ...) don't cause
// NodeExistsException when the test is run again.
@org.springframework.transaction.annotation.Transactional
class TreeTests {

    @Autowired
    private lateinit var treeService: TreeService

    @Test
    fun test() {
        val rootNode = treeService.rootNode()
        val a = createNode("a", parentId = rootNode.nodeId)
        val a1 = createNode("a1", parentId = a.nodeId)
        val a2 = createNode("a2", parentId = a.nodeId)

        val a21 = createNode("a21", parentId = a2.nodeId)
        val a22 = createNode("a22", parentId = a2.nodeId)

        val b = createNode("b", parentId = rootNode.nodeId)

//        val r = treeService.move(nodeId = a2.nodeId, newParentId = b.nodeId)
//        treeService.deleteNode("a")

        val tree = treeService.buildTree(rootNode.nodeId)
        println(tree)

        // The root may already have other children, so only check that ours are present.
        val rootChildIds = tree.children.map { it.node.nodeId }
        org.junit.jupiter.api.Assertions.assertTrue(rootChildIds.containsAll(listOf(a.nodeId, b.nodeId)))

        val aTree = tree.children.first { it.node.nodeId == a.nodeId }
        org.junit.jupiter.api.Assertions.assertEquals(
            setOf(a1.nodeId, a2.nodeId),
            aTree.children.map { it.node.nodeId }.toSet(),
        )

        val a2Tree = aTree.children.first { it.node.nodeId == a2.nodeId }
        org.junit.jupiter.api.Assertions.assertEquals(
            setOf(a21.nodeId, a22.nodeId),
            a2Tree.children.map { it.node.nodeId }.toSet(),
        )
    }

    /** Creates a node with empty metadata whose name equals its id, under [parentId]. */
    private fun createNode(id: String, parentId: String): Node =
        treeService.createNode(Node(nodeId = id, nodeName = id, metadata = HashMap()), parentId = parentId).get()


}
`println(tree)` 的输出：
nodeId=00000000-00000000-00000000-00000000
  nodeId=b
  nodeId=a
    nodeId=a2
      nodeId=a22
      nodeId=a21
    nodeId=a1
posted @ 2023-04-06 11:55  Y'Shaarj  阅读(302)  评论(0)    收藏  举报