golang前缀树过滤敏感字

前缀树

package xxxx

const defaultMask = '*'

type (
	TrieOption func(trie *trieNode)

	Trie interface {
		Filter(text string) (string, []string, bool)
		FindKeywords(text string) []string
	}

	trieNode struct {
		node
		mask rune
	}

	scope struct {
		start int
		stop  int
	}
)

// 开始实现前缀树,把敏感字从数据库/配置读出来,塞进去
func NewTrie(words []string, opts ...TrieOption) Trie {
	n := new(trieNode)

	for _, opt := range opts {
		opt(n)
	}
	if n.mask == 0 {
		n.mask = defaultMask
	}
	for _, word := range words {
		n.add(word)
	}

	n.build()

	return n
}

func (n *trieNode) Filter(text string) (sentence string, keywords []string, found bool) {
	chars := []rune(text)
	if len(chars) == 0 {
		return text, nil, false
	}

	scopes := n.find(chars)
	keywords = n.collectKeywords(chars, scopes)

	for _, match := range scopes {
		//替换
		n.replaceWithAsterisk(chars, match.start, match.stop)
	}

	return string(chars), keywords, len(keywords) > 0
}

func (n *trieNode) FindKeywords(text string) []string {
	chars := []rune(text)
	if len(chars) == 0 {
		return nil
	}

	scopes := n.find(chars)
	return n.collectKeywords(chars, scopes)
}

func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string {
	set := make(map[string]struct{})
	for _, v := range scopes {
		set[string(chars[v.start:v.stop])] = struct{}{}
	}

	var i int
	keywords := make([]string, len(set))
	for k := range set {
		keywords[i] = k
		i++
	}

	return keywords
}

func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
	for i := start; i < stop; i++ {
		chars[i] = n.mask
	}
}

// 自定义替换字符
func WithMask(mask rune) TrieOption {
	return func(n *trieNode) {
		n.mask = mask
	}
}

使用

trie := NewTrie([]string{
		"bc",
		"cd",
	})
	output, keywords, found := trie.Filter("abcd")
	fmt.Println(output)
    fmt.Println(keywords)
    fmt.Println(found)
//a***
//[bc cd]
//true

压缩前缀树

package main

import (
	"fmt"
)

const defaultMask = '*'

type (
	// Trie 定义前缀树接口
	Trie interface {
		Filter(text string) (string, []string, bool)
		FindKeywords(text string) []string
	}

	// compressedTrieNode 表示压缩前缀树的节点
	compressedTrieNode struct {
		children map[string]*compressedTrieNode
		isEnd    bool
		mask     rune
	}

	// scope 用于记录匹配到的敏感词在文本中的位置范围
	scope struct {
		start int
		stop  int
	}

	// TrieOption 用于配置前缀树的选项
	TrieOption func(trie *compressedTrieNode)
)

// NewCompressedTrie 创建一个新的压缩前缀树实例
func NewCompressedTrie(words []string, opts ...TrieOption) Trie {
	n := &compressedTrieNode{
		children: make(map[string]*compressedTrieNode),
		mask:     defaultMask,
	}

	// 应用选项
	for _, opt := range opts {
		opt(n)
	}

	// 添加所有敏感词到压缩前缀树
	for _, word := range words {
		if word != "" {
			n.add(word)
		}
	}

	return n
}

// add 将单词添加到压缩前缀树中
func (n *compressedTrieNode) add(word string) {
	runes := []rune(word)
	wLen := len(runes)

	// 遍历当前节点的所有子节点
	for prefix, child := range n.children {
		prefixRunes := []rune(prefix)
		// 找到公共前缀长度
		commonLen := n.findCommonPrefixRunes(prefixRunes, runes)

		if commonLen == 0 {
			// 没有公共前缀,继续查找下一个子节点
			continue
		}

		if commonLen == len(prefixRunes) && commonLen == wLen {
			// 完全匹配,标记为单词结束
			child.isEnd = true
			return
		}

		if commonLen == len(prefixRunes) {
			// 前缀完全匹配,递归添加剩余部分
			child.add(string(runes[commonLen:]))
			return
		}

		if commonLen < len(prefixRunes) {
			// 需要分裂当前节点
			splitNode := &compressedTrieNode{
				children: make(map[string]*compressedTrieNode),
				mask:     n.mask,
			}

			// 调整父子关系
			splitNode.children[string(prefixRunes[commonLen:])] = child
			splitNode.isEnd = child.isEnd
			delete(n.children, prefix)
			n.children[string(prefixRunes[:commonLen])] = splitNode

			if commonLen < wLen {
				// 添加剩余部分
				splitNode.add(string(runes[commonLen:]))
			} else {
				// 恰好是公共前缀长度
				splitNode.isEnd = true
			}
			return
		}
	}

	// 没有找到匹配的前缀,创建新的子节点
	n.children[word] = &compressedTrieNode{
		children: make(map[string]*compressedTrieNode),
		isEnd:    true,
		mask:     n.mask,
	}
}

// findCommonPrefixRunes 查找两个rune切片的公共前缀长度
func (n *compressedTrieNode) findCommonPrefixRunes(s1, s2 []rune) int {
	minLen := len(s1)
	if len(s2) < minLen {
		minLen = len(s2)
	}

	for i := 0; i < minLen; i++ {
		if s1[i] != s2[i] {
			return i
		}
	}

	return minLen
}

// findCommonPrefix 查找两个字符串的公共前缀长度
// 注意:这个方法主要用于ASCII字符串的比较,对于包含中文等Unicode字符的字符串,应使用findCommonPrefixRunes
func (n *compressedTrieNode) findCommonPrefix(s1, s2 string) int {
	s1Runes := []rune(s1)
	s2Runes := []rune(s2)
	return n.findCommonPrefixRunes(s1Runes, s2Runes)
}

// Filter 过滤文本中的敏感词,返回过滤后的文本、敏感词列表和是否包含敏感词
func (n *compressedTrieNode) Filter(text string) (sentence string, keywords []string, found bool) {
	chars := []rune(text)
	if len(chars) == 0 {
		return text, nil, false
	}

	scopes := n.find(chars)
	keywords = n.collectKeywords(chars, scopes)

	// 替换敏感词
	for _, match := range scopes {
		n.replaceWithMask(chars, match.start, match.stop)
	}

	return string(chars), keywords, len(keywords) > 0
}

// FindKeywords 查找文本中的所有敏感词
func (n *compressedTrieNode) FindKeywords(text string) []string {
	chars := []rune(text)
	if len(chars) == 0 {
		return nil
	}

	scopes := n.find(chars)
	return n.collectKeywords(chars, scopes)
}

// find 在文本中查找所有敏感词的位置
func (n *compressedTrieNode) find(chars []rune) []scope {
	var scopes []scope
	length := len(chars)

	for i := 0; i < length; i++ {

		foundScope := n.searchFromPosition(chars, i)
		if foundScope.stop > i {
			scopes = append(scopes, foundScope)
		}
	}

	return scopes
}

// searchFromPosition 从指定位置开始搜索敏感词
func (n *compressedTrieNode) searchFromPosition(chars []rune, start int) scope {
	current := n
	length := len(chars)
	currentPos := start
	lastMatchEnd := start

	for currentPos < length {
		matched := false

		for prefix, child := range current.children {
			prefixRunes := []rune(prefix)
			prefixLen := len(prefixRunes)

			// 检查剩余字符是否足够匹配前缀
			if currentPos+prefixLen > length {
				continue
			}

			// 检查是否匹配前缀
			match := true
			for i := 0; i < prefixLen; i++ {
				if prefixRunes[i] != chars[currentPos+i] {
					match = false
					break
				}
			}

			if match {
				// 找到匹配的前缀
				current = child
				currentPos += prefixLen
				matched = true

				// 如果是单词结束,更新最后匹配位置
				if child.isEnd {
					lastMatchEnd = currentPos
				}
				break
			}
		}

		if !matched {
			break
		}
	}

	return scope{start: start, stop: lastMatchEnd}
}

// collectKeywords 收集所有匹配的敏感词
func (n *compressedTrieNode) collectKeywords(chars []rune, scopes []scope) []string {
	set := make(map[string]struct{})
	for _, v := range scopes {
		if v.stop > v.start {
			set[string(chars[v.start:v.stop])] = struct{}{}
		}
	}

	// 将集合转换为切片
	keywords := make([]string, 0, len(set))
	for k := range set {
		keywords = append(keywords, k)
	}

	return keywords
}

// replaceWithMask 用掩码字符替换指定范围内的字符
func (n *compressedTrieNode) replaceWithMask(chars []rune, start, stop int) {
	for i := start; i < stop; i++ {
		chars[i] = n.mask
	}
}

// WithMask 设置自定义的掩码字符
func WithMask(mask rune) TrieOption {
	return func(n *compressedTrieNode) {
		n.mask = mask
	}
}

func main() {
	// 创建敏感词列表
	sensitiveWords := []string{
		"不良",
		"不良信息",
		"敏感词",
		"过滤",
		"测试",
	}

	// 1. 创建默认的压缩前缀树实例
	trie := NewCompressedTrie(sensitiveWords)

	// 测试文本
	testText := "这是一段包含敏感词和不良信息的测试文本,需要进行过滤处理。"

	// 2. 过滤敏感词
	filteredText, keywordsFound, hasSensitiveWord := trie.Filter(testText)

	// 输出结果
	fmt.Println("原始文本:", testText)
	fmt.Println("过滤后文本:", filteredText)
	fmt.Println("包含敏感词:", hasSensitiveWord)
	fmt.Println("检测到的敏感词:", keywordsFound)

	// 3. 仅查找敏感词不进行替换
	onlyKeywords := trie.FindKeywords(testText)
	fmt.Println("仅查找敏感词结果:", onlyKeywords)

	// 4. 使用自定义掩码字符
	customTrie := NewCompressedTrie(sensitiveWords, WithMask('x'))
	customFilteredText, _, _ := customTrie.Filter(testText)
	fmt.Println("使用自定义掩码过滤后文本:", customFilteredText)

	// 5. 测试没有敏感词的文本
	cleanText := "这是一段正常的文本内容。"
	cleanFiltered, cleanKeywords, hasCleanSensitive := trie.Filter(cleanText)
	fmt.Println("\n正常文本过滤结果:")
	fmt.Println("原始文本:", cleanText)
	fmt.Println("过滤后文本:", cleanFiltered)
	fmt.Println("包含敏感词:", hasCleanSensitive)
	fmt.Println("检测到的敏感词:", cleanKeywords)
}
posted @ 2024-06-04 11:44  朝阳1  阅读(34)  评论(0)    收藏  举报