多模式匹配的Trie实现

业务场景

这种需求一般用于敏感词过滤等场景, 输入是大文本, 需要快速判断是否存在匹配的模式串(敏感词), 或者在其中找出所有匹配的模式串. 对于模式串数量不超过5000的场景, 直接用暴力查找速度也能接受, 对于更大规模的模式串, 需要对匹配进行优化.

实现原理

带Fail Next回溯的Trie树结构是常见的实现方法, 算法原理可以自行查找"多模式匹配算法". 在实际使用中, 对于中文的模式串, 因为中文字库很大, 数万的级别, 如果使用单个中文文字作为每个节点的子节点数组, 那么这个数组尺寸会非常大, 同时这个Trie树的深度很小, 最长的中文词字数不过19. 这样造成了很多空间浪费. 在这个实现中, 将字符串统一使用十六进制数组表示, 这样每个节点的子节点数组大小只有16, 同时最大深度变成114, 虽然在计算Fail Next时需要花费更多时间, 但是在空间效率上提升了很多.

模式清洗

对于输入的模式串, 统一转换为byte[], 再转换为十六进制的 int[]

Trie树构造

遍历所有的模式串, 将int[]添加入Trie树, 每个int对应其中的一个node, 将byte[]值写入最后一个节点(叶子节点).

Fail Next构造

Next的定义: 当匹配下一个节点失败时, 模式串应该跳到哪个节点继续匹配.

初始值: Root的Next为空, 第一层的Next都为Root

计算某节点的Next: 取此节点的父节点的Next为Node,

  • 若Node中编号index的子节点存在, 则此子节点就是Next
  • 若不存在, 那么再将Node的Next设为Node, 继续刚才的逻辑
  • 若Node的Next为空, 则以此Node为Next (此时这个Node应当为Root)

对整个Trie树的next赋值必须以广度遍历的方式进行, 因为每一个next的计算, 要基于上层已经设置的next.

文本查找

对输入的文本, 也需要转换为十六进制int[]进行查找. 在每一步, 无论是匹配成功, 还是匹配失败, 都要查看当前节点的next, 以及next的next, 是否是叶子节点, 否则会错过被大模式串包围的小模式串.

代码实现

TrieNode

public class TrieNode {
    private byte[] value;
    private int freq;
    private TrieNode parent;
    private TrieNode next;
    private TrieNode[] children;
}

TrieMatch

/**
 * Efficient multi-pattern matching approach with Trie algorithm
 *
 * Code example:
 * ```
 * TrieMatch trie = new TrieMatch();
 * trie.initialize("webdict_with_freq.txt");
 * Set<String> results = trie.match("any UTF-8 string");
 * ```
 */
public class TrieMatch {
    private static final Logger logger = LoggerFactory.getLogger(TrieMatch.class);
    /** Trie size */
    private int size;
    /** Trie depth */
    private int depth;
    /** Trie root node */
    private TrieNode root;
    /** Queue for span traversal */
    private Queue<TrieNode> queue;

    public TrieMatch() {
        root = new TrieNode();
        queue = new ArrayDeque<>();
    }

    public TrieNode getRoot() { return root; }
    public int getSize() { return size; }
    public int getDepth() { return depth; }

    public Set<String> match(String content) {
        byte[] bytes = content.getBytes(StandardCharsets.UTF_8);
        int[] hex = bytesToHex(bytes);
        Set<byte[]> arrays = match(hex);
        Set<String> output = new LinkedHashSet<>();
        for (byte[] array : arrays) {
            if (array == null) {
                continue;
            }
            String string = new String(array, StandardCharsets.UTF_8);
            output.add(string);
        }
        return output;
    }

    /**
     * Traverse the Trie tree to find all matched nodes.
     */
    public Set<byte[]> match(int[] hex) {
        Set<byte[]> output = new LinkedHashSet<>();
        TrieNode node = root;
        for (int i = 0; i < hex.length;) {

            if (node.getChildren() != null) {
                TrieNode forward = node.getChildren()[hex[i]];
                if (forward != null) {
                    if (forward.getValue() != null) {
                        output.add(forward.getValue());
                    }
                    TrieNode possible = node.getNext();
                    while (possible != null && possible.getValue() != null) {
                        output.add(possible.getValue());
                        possible = possible.getNext();
                    }
                    node = forward;
                    i++;
                    continue;
                }
            }
            // Move to 'next' node when unmatched
            node = node.getNext();
            if (node == null) {
                node = root;
                i++;
            } else {
                TrieNode possible = node;
                while (possible != null && possible.getValue() != null) {
                    output.add(possible.getValue());
                    possible = possible.getNext();
                }
            }
        }
        return output;
    }

    public void print() {
        queue.clear();
        queue.add(root);
        TrieNode node;
        while ((node = queue.poll()) != null) {
            logger.debug(node.toString());
            if (node.getChildren() != null) {
                TrieNode[] children = node.getChildren();
                for (int i = 0; i < children.length; i++) {
                    if (children[i] != null) {
                        queue.add(children[i]);
                    }
                }
            }
        }
    }

    public void initialize(String filepath) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(filepath));
            String line;
            while ((line = reader.readLine()) != null) {
                String[] array = line.split("\\s+");
                if (array.length != 2) {
                    logger.debug("Error: " + line);
                    continue;
                }
                int freq = Integer.parseInt(array[1]);
                append(array[0], freq);
            }
            reader.close();

            queue.clear();
            queue.add(root);
            TrieNode node;
            while ((node = queue.poll()) != null) {
                fillNext(node);
            }

        } catch (IOException e) {
            logger.debug(e.getMessage());
        }
    }

    private void append(String word, int freq) {
        byte[] bytes = word.getBytes(StandardCharsets.UTF_8);
        int[] hex = bytesToHex(bytes);
        append(hex, freq);
    }

    private void append(int[] hex, int freq) {
        if (hex.length > depth) { depth = hex.length; }
        TrieNode parent = root;
        for (int i = 0; i < hex.length; i++) {
            int index = hex[i];
            if (index > 16) {
                logger.debug("Error: index exceeds 16");
                continue;
            }
            if (parent.getChildren() == null) {
                parent.setChildren(new TrieNode[16]);
            }
            TrieNode pos = parent.getChildren()[index];
            if (pos == null) {
                size++;
                pos = new TrieNode();
                pos.setParent(parent);
                parent.getChildren()[index] = pos;
            }
            if (i == hex.length - 1) {
                pos.setValue(hexToBytes(hex));
                pos.setFreq(freq);
            }
            parent = pos;
        }
    }

    private void fillNext(TrieNode node) {
        if (node.getChildren() != null) {
            TrieNode[] children = node.getChildren();
            for (int i = 0; i < children.length; i++) {
                if (children[i] != null) {
                    TrieNode next = getNext(node, i);
                    children[i].setNext(next);
                }
            }
            for (int i = 0; i < children.length; i++) {
                if (children[i] != null) {
                    queue.add(children[i]);
                }
            }
        }
    }

    /**
     * Definition of 'next': When failed matching this node, the patten should continue from which one
     * Initialize: root.next = null, [direct descendent].next = root
     * Calculate: Set node = parent.next (at the moment parent.next should have been set)
     * - if node.children[index] exists, then this child is the next
     * - if not, then set node = node.next, continue above searching
     * - if node.next is null, it should have reach the root, just return this node
     */
    private TrieNode getNext(TrieNode node, int index) {
        if (node.getNext() == null) { // This should be root
            return node;
        }
        node = node.getNext();
        if (node.getChildren() != null) {
            TrieNode next = node.getChildren()[index];
            if (next != null) {
                return next;
            } else {
                return getNext(node, index);
            }
        } else {
            return getNext(node, index);
        }
    }

    private int[] bytesToHex(byte[] bytes) {
        int[] ints = new int[bytes.length * 2];
        for (int i = 0; i < bytes.length; i++) {
            ints[i * 2 + 1] = bytes[i] & 0x0f;
            ints[i * 2] = (bytes[i] >> 4) & 0x0f;
        }
        return ints;
    }

    private byte[] hexToBytes(int[] hex) {
        byte[] bytes = new byte[hex.length / 2];
        for (int i = 0; i < bytes.length; i++) {
            int a = (hex[i * 2 ] << 4) + hex[i * 2 + 1];
            bytes[i] = (byte)a;
        }
        return bytes;
    }
}

  

 

posted on 2020-04-28 16:08  Milton  阅读(499)  评论(0)    收藏  举报

导航