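3-11 Ternary Search Tree (a Trie Whose Child Nodes Form Binary Search Trees)

Ternary Search Tree

A ternary search tree (TST) is a tree data structure that combines features of a trie (prefix tree) and a binary search tree (BST). Each node stores one character and has three child pointers: the left child leads to nodes with a smaller character, the middle child leads to the node for the next character of the string, and the right child leads to nodes with a larger character.

Whereas a standard array-based trie over the lowercase alphabet keeps 26 child pointers in every node, a TST node has only 3, so far fewer pointer slots sit empty. The TST is well suited to storing sets of strings and is commonly used for spell checking, autocomplete, and near-neighbor search.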
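The space claim above can be made concrete with a rough count of pointer slots. The sketch below is only an illustration under stated assumptions: a 26-letter lowercase alphabet, a dict-based stand-in for the trie, and list-based TST nodes. The helper names count_trie_nodes and count_tst_nodes are hypothetical and do not appear elsewhere in this chapter.

```python
ALPHABET_SIZE = 26  # assumption: lowercase ASCII letters only


def count_trie_nodes(words):
    """Count the nodes of a standard trie, including the root."""
    root = {}
    count = 1  # the root node
    for word in words:
        node = root
        for ch in word:
            if ch not in node:
                node[ch] = {}
                count += 1
            node = node[ch]
    return count


def count_tst_nodes(words):
    """Count the nodes of a TST built by inserting words in order."""
    root = None
    count = 0

    def insert(node, word, i):
        nonlocal count
        if node is None:
            node = [word[i], None, None, None]  # [char, left, middle, right]
            count += 1
        if word[i] < node[0]:
            node[1] = insert(node[1], word, i)
        elif word[i] > node[0]:
            node[3] = insert(node[3], word, i)
        elif i < len(word) - 1:
            node[2] = insert(node[2], word, i + 1)
        return node

    for word in words:
        root = insert(root, word, 0)
    return count


words = ["cat", "car", "care", "dog"]
trie_slots = count_trie_nodes(words) * ALPHABET_SIZE  # 9 nodes * 26 slots
tst_slots = count_tst_nodes(words) * 3                # 8 nodes * 3 slots
print(trie_slots, tst_slots)
```

For these four words the trie has 9 nodes (234 pointer slots) and the TST has 8 nodes (24 pointer slots). The count ignores the character and flag stored in each TST node, so it is a comparison of pointer slots only, not of total memory.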

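Below is the ternary search tree that results from inserting the four words "cat", "car", "care", and "dog" in that order. In the diagram, / is a left link, | is a middle link, \ is a right link, and [E] marks the end of a word:

                c
                |  \
                a    d
                |    |
                t[E] o
               /     |
              r[E]   g[E]
              |
              e[E]

The nodes are read as follows:

  • The root stores 'c'. Its left subtree holds words whose first character is smaller than 'c', its right subtree holds words whose first character is larger than 'c' (here 'd'), and its middle child 'a' continues the match with the second character
  • t[E] means that 't' ends the word "cat"
  • r[E] means that 'r' ends the word "car"; it is the left child of 't' because 'r' < 't'
  • e[E] means that 'e' ends the word "care"
  • g[E] means that 'g' ends the word "dog"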

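Note that "cat" and "car" share the prefix "ca", and "car" and "care" share the prefix "car". Like a trie, a TST saves space by sharing common prefixes, while the number of pointers per node drops from 26 to 3.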


Node Definition

A TST node stores one character, a Boolean flag that marks the end of a word, and three child pointers (left, middle, right).

C++ Node Definition

struct TSTNode {
    char data;              // character stored in this node
    bool isEndOfWord;       // true if this node ends a word
    TSTNode* left;          // child for characters less than data
    TSTNode* middle;        // child for the next character (equal)
    TSTNode* right;         // child for characters greater than data

    TSTNode(char ch) : data(ch), isEndOfWord(false),
                       left(nullptr), middle(nullptr), right(nullptr) {}
};

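The C++ version uses a struct whose constructor initializes every field. data holds the node's character. left and right point to the subtrees whose characters are smaller and larger than data, and middle points to the subtree for the next character, which is followed when the current character equals data.

C Node Definition

#include <stdbool.h>
#include <stdlib.h>   // malloc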

typedef struct TSTNode {
    char data;                  // character stored in this node
    bool isEndOfWord;           // true if this node ends a word
    struct TSTNode* left;       // child for characters less than data
    struct TSTNode* middle;     // child for the next character (equal)
    struct TSTNode* right;      // child for characters greater than data
} TSTNode;

// Create and initialize a new TST node with given character
TSTNode* createNode(char ch) {
    TSTNode* node = (TSTNode*)malloc(sizeof(TSTNode));
    node->data = ch;
    node->isEndOfWord = false;
    node->left = NULL;
    node->middle = NULL;
    node->right = NULL;
    return node;
}

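The C version defines the struct with typedef and allocates nodes on the heap with malloc (declared in <stdlib.h>). createNode initializes every field.

Python Node Definition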

class TSTNode:
    def __init__(self, ch):
        self.data = ch              # character stored in this node
        self.is_end_of_word = False  # true if this node ends a word
        self.left = None            # child for characters less than data
        self.middle = None          # child for the next character (equal)
        self.right = None           # child for characters greater than data

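The Python version uses a class, which keeps the code short. The three child references default to None, and child nodes are created only when needed.

Go Node Definition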

type TSTNode struct {
    data   byte        // character stored in this node
    isEnd  bool        // true if this node ends a word
    left   *TSTNode    // child for characters less than data
    middle *TSTNode    // child for the next character (equal)
    right  *TSTNode    // child for characters greater than data
}

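The Go version defines the node as a struct. data is a byte holding a single character, isEnd marks the end of a word, and the three pointers lead to the subtrees for characters smaller than, equal to, and larger than data. Because data is a byte, this version compares strings byte by byte.


Insertion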

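Insertion starts at the root and compares the current character of the word with the character stored in the node:

  1. The current character is smaller than the node's character: go to the left child and compare the same character again
  2. The current character is larger than the node's character: go to the right child and compare the same character again
  3. The current character equals the node's character: go to the middle child and move on to the next character

If the child to visit does not exist, a new node is created. Once every character has been processed, the node that matched the last character is marked as the end of a word.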

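For example, insert "cat", "car", "care", "dog", "do", and "bat" in that order, where (m) denotes a middle link:

  1. Insert "cat": create the path c(m) -> a(m) -> t and mark t as the end of a word
  2. Insert "car": reuse c -> a; at t, the middle child of a, 'r' < 't', so create r as the left child of t and mark r as the end of a word
  3. Insert "care": follow c -> a to t, go left to r, then create e as the middle child of r and mark e as the end of a word
  4. Insert "dog": 'd' > 'c', so go to the right child of c; create d(m) -> o(m) -> g and mark g as the end of a word
  5. Insert "do": reuse d -> o and mark o as the end of a word
  6. Insert "bat": 'b' < 'c', so go to the left child of c; create b(m) -> a(m) -> t and mark t as the end of a word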
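The six steps above can be checked mechanically. The following sketch is one way to do so, under the assumption that the node class and recursive insert look like the Python versions in this chapter; the build helper is a hypothetical convenience added for illustration only. It builds the tree and then walks the links named in each step.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Recursive insert; assumes word is non-empty."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def build(words):
    """Hypothetical helper: insert words in order and return the root."""
    root = None
    for w in words:
        root = insert(root, w)
    return root


root = build(["cat", "car", "care", "dog", "do", "bat"])

t = root.middle.middle   # step 1: c(m) -> a(m) -> t
r = t.left               # step 2: 'r' < 't', so r hangs to the left of t
e = r.middle             # step 3: e is the middle child of r
d = root.right           # step 4: 'd' > 'c'
o = d.middle             # step 5: o ends "do"
b = root.left            # step 6: 'b' < 'c'

print(t.data, r.data, e.data, d.data, o.data, b.data)
```

Each variable names the node reached in the corresponding step, so inspecting its data and is_end_of_word fields confirms the shape described in the walkthrough.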

C++ Insertion

#include <iostream>
#include <string>
using namespace std;

struct TSTNode {
    char data;
    bool isEndOfWord;
    TSTNode* left;
    TSTNode* middle;
    TSTNode* right;

    TSTNode(char ch) : data(ch), isEndOfWord(false),
                       left(nullptr), middle(nullptr), right(nullptr) {}
};

// Insert a word into the TST (recursive helper)
TSTNode* insert(TSTNode* node, const string& word, int index) {
    char ch = word[index];

    // Create a new node if current position is empty
    if (node == nullptr)
        node = new TSTNode(ch);

    // Navigate based on character comparison
    if (ch < node->data) {
        node->left = insert(node->left, word, index);
    } else if (ch > node->data) {
        node->right = insert(node->right, word, index);
    } else {
        // Characters match, move to next character via middle child
        if (index < (int)word.length() - 1) {
            node->middle = insert(node->middle, word, index + 1);
        } else {
            node->isEndOfWord = true;  // mark end of word
        }
    }
    return node;
}

int main() {
    TSTNode* root = nullptr;

    // Insert words one by one
    string words[] = {"cat", "car", "care", "dog", "do", "bat"};
    for (const string& w : words) {
        root = insert(root, w, 0);
        cout << "Inserted: " << w << endl;
    }

    return 0;
}

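The C++ version inserts recursively. insert takes the current node, the word, and the index of the character being processed, and the comparison decides whether to descend left, middle, or right. When the last character has been matched, the current node is marked as the end of a word. The function is intended for non-empty words.

C Insertion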

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>

typedef struct TSTNode {
    char data;
    bool isEndOfWord;
    struct TSTNode* left;
    struct TSTNode* middle;
    struct TSTNode* right;
} TSTNode;

// Create and initialize a new TST node with given character
TSTNode* createNode(char ch) {
    TSTNode* node = (TSTNode*)malloc(sizeof(TSTNode));
    node->data = ch;
    node->isEndOfWord = false;
    node->left = NULL;
    node->middle = NULL;
    node->right = NULL;
    return node;
}

// Insert a word into the TST (recursive helper)
TSTNode* insert(TSTNode* node, const char* word, int index) {
    char ch = word[index];

    // Create a new node if current position is empty
    if (node == NULL)
        node = createNode(ch);

    // Navigate based on character comparison
    if (ch < node->data) {
        node->left = insert(node->left, word, index);
    } else if (ch > node->data) {
        node->right = insert(node->right, word, index);
    } else {
        // Characters match, move to next character via middle child
        if (index < (int)strlen(word) - 1) {
            node->middle = insert(node->middle, word, index + 1);
        } else {
            node->isEndOfWord = true;  // mark end of word
        }
    }
    return node;
}

int main() {
    TSTNode* root = NULL;

    // Insert words one by one
    const char* words[] = {"cat", "car", "care", "dog", "do", "bat"};
    int n = sizeof(words) / sizeof(words[0]);
    for (int i = 0; i < n; i++) {
        root = insert(root, words[i], 0);
        printf("Inserted: %s\n", words[i]);
    }

    return 0;
}

The C version follows the same logic as the C++ one, with const char* in place of string, NULL in place of nullptr, and printf in place of cout.

Python Insertion

class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    ch = word[index]

    # Create a new node if current position is empty
    if node is None:
        node = TSTNode(ch)

    # Navigate based on character comparison
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    else:
        # Characters match, move to next character via middle child
        if index < len(word) - 1:
            node.middle = insert(node.middle, word, index + 1)
        else:
            node.is_end_of_word = True  # mark end of word
    return node


if __name__ == "__main__":
    root = None

    # Insert words one by one
    words = ["cat", "car", "care", "dog", "do", "bat"]
    for w in words:
        root = insert(root, w)
        print(f"Inserted: {w}")

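The Python version is also recursive, with the index parameter defaulting to 0. The logic is identical to the C++ and C versions, only more compact.

Running the program prints: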

Inserted: cat
Inserted: car
Inserted: care
Inserted: dog
Inserted: do
Inserted: bat
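The recursive insert reads word[index] before checking anything else, so in Python it raises IndexError when given an empty string. One possible guard, shown here as a minimal sketch, is a thin wrapper that owns the root and rejects empty words. The TST class, its add method, and its size counter are assumptions made for this illustration and are not part of the original code.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Same recursive insert as above; requires a non-empty word."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


class TST:
    """Hypothetical wrapper (an assumption of this sketch) that owns the root."""

    def __init__(self):
        self.root = None
        self.size = 0  # number of distinct words stored

    def add(self, word):
        """Insert word; return False for empty or already stored words."""
        if not word:
            return False  # an empty string has no character to store
        if self._contains(word):
            return False
        self.root = insert(self.root, word)
        self.size += 1
        return True

    def _contains(self, word):
        # word is guaranteed non-empty by add()
        curr, index = self.root, 0
        while curr is not None:
            ch = word[index]
            if ch < curr.data:
                curr = curr.left
            elif ch > curr.data:
                curr = curr.right
            elif index == len(word) - 1:
                return curr.is_end_of_word
            else:
                curr = curr.middle
                index += 1
        return False


tst = TST()
for w in ["cat", "", "car", "cat"]:
    print(repr(w), tst.add(w))
```

Returning False instead of raising is a design choice of this sketch; raising ValueError for an empty word would be an equally reasonable policy.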

Go Insertion

package main

import "fmt"

type TSTNode struct {
    data   byte
    isEnd  bool
    left   *TSTNode
    middle *TSTNode
    right  *TSTNode
}

func insert(node *TSTNode, word string, index int) *TSTNode {
    ch := word[index]

    // Create a new node if current position is empty
    if node == nil {
        node = &TSTNode{data: ch}
    }

    // Navigate based on character comparison
    if ch < node.data {
        node.left = insert(node.left, word, index)
    } else if ch > node.data {
        node.right = insert(node.right, word, index)
    } else {
        // Characters match, move to next character via middle child
        if index < len(word)-1 {
            node.middle = insert(node.middle, word, index+1)
        } else {
            node.isEnd = true // mark end of word
        }
    }
    return node
}

func main() {
    var root *TSTNode

    // Insert words one by one
    words := []string{"cat", "car", "care", "dog", "do", "bat"}
    for _, w := range words {
        root = insert(root, w, 0)
        fmt.Printf("Inserted: %s\n", w)
    }
}

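The Go version also inserts recursively. insert returns the node pointer and creates a new node when it receives nil. Characters are compared as byte values to decide whether to descend left, middle, or right.

Running the program prints: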

Inserted: cat
Inserted: car
Inserted: care
Inserted: dog
Inserted: do
Inserted: bat

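Search

Search uses exactly the same comparisons as insertion: start at the root, compare the current character with the node's character, and go left, middle, or right depending on whether it is smaller, equal, or larger. The search succeeds if, after all characters have been matched, the final node is marked as the end of a word.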

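A search ends in one of three ways:

  1. Complete word found: every character matches and the last node's isEndOfWord is true
  2. Prefix present but not a complete word: every character matches, but the last node is not marked as the end of a word
  3. Not found: a null pointer is reached before all characters have been matched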

C++ Search

// Search for a word in the TST
bool search(TSTNode* root, const string& word) {
    TSTNode* curr = root;
    int index = 0;

    while (curr != nullptr && index < (int)word.length()) {
        char ch = word[index];

        if (ch < curr->data) {
            curr = curr->left;          // go left for smaller characters
        } else if (ch > curr->data) {
            curr = curr->right;         // go right for larger characters
        } else {
            // Characters match
            if (index == (int)word.length() - 1) {
                // Reached the last character
                return curr->isEndOfWord;
            }
            curr = curr->middle;        // move to next character
            index++;
        }
    }
    return false;
}

// Example usage (assuming words already inserted):
// search(root, "cat")  -> true
// search(root, "ca")   -> false (prefix exists but not a complete word)
// search(root, "cow")  -> false (not found)

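The C++ version searches iteratively, comparing one character at a time starting from the root. On reaching the last character it returns the node's isEndOfWord flag.

C Search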

// Search for a word in the TST
bool search(TSTNode* root, const char* word) {
    TSTNode* curr = root;
    int index = 0;
    int len = strlen(word);

    while (curr != NULL && index < len) {
        char ch = word[index];

        if (ch < curr->data) {
            curr = curr->left;          // go left for smaller characters
        } else if (ch > curr->data) {
            curr = curr->right;         // go right for larger characters
        } else {
            // Characters match
            if (index == len - 1) {
                // Reached the last character
                return curr->isEndOfWord;
            }
            curr = curr->middle;        // move to next character
            index++;
        }
    }
    return false;
}

The C version obtains the string length with strlen; the logic matches the iterative C++ version.

Python Search

def search(node, word):
    curr = node
    index = 0

    while curr is not None and index < len(word):
        ch = word[index]

        if ch < curr.data:
            curr = curr.left            # go left for smaller characters
        elif ch > curr.data:
            curr = curr.right           # go right for larger characters
        else:
            # Characters match
            if index == len(word) - 1:
                # Reached the last character
                return curr.is_end_of_word
            curr = curr.middle          # move to next character
            index += 1

    return False

The Python version uses the same iterative loop.

A search test on the six words inserted earlier prints:

Search "cat": Found
Search "ca": Not Found (prefix exists, but not a complete word)
Search "cow": Not Found
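The search function returns only True or False, so it cannot by itself tell a bare prefix from a missing word. The sketch below shows one way to produce output like the lines above, assuming a hypothetical classify function and Lookup enum (neither is part of the original code) that distinguish the three outcomes of a search.

```python
from enum import Enum


class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


class Lookup(Enum):
    """Hypothetical result type covering the three search outcomes."""
    FOUND = "Found"
    PREFIX_ONLY = "Not Found (prefix exists, but not a complete word)"
    NOT_FOUND = "Not Found"


def classify(root, word):
    """Same walk as search(), but report which of the three cases occurred."""
    curr, index = root, 0
    while curr is not None and index < len(word):
        ch = word[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        elif index == len(word) - 1:
            # Every character matched; the flag decides between cases 1 and 2
            return Lookup.FOUND if curr.is_end_of_word else Lookup.PREFIX_ONLY
        else:
            curr = curr.middle
            index += 1
    return Lookup.NOT_FOUND  # ran into a missing child (or empty word)


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert(root, w)

for q in ["cat", "ca", "cow"]:
    print(f'Search "{q}": {classify(root, q).value}')
```

A Boolean search can be recovered as classify(root, word) is Lookup.FOUND, so the richer return type does not change the traversal itself.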

Go Search

func search(root *TSTNode, word string) bool {
    curr := root
    index := 0

    for curr != nil && index < len(word) {
        ch := word[index]

        if ch < curr.data {
            curr = curr.left // go left for smaller characters
        } else if ch > curr.data {
            curr = curr.right // go right for larger characters
        } else {
            // Characters match
            if index == len(word)-1 {
                // Reached the last character
                return curr.isEnd
            }
            curr = curr.middle // move to next character
            index++
        }
    }
    return false
}

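The Go version searches with an iterative loop, comparing one character at a time starting from the root. On reaching the last character it checks the isEnd flag to decide whether the match is a complete word.


Prefix Search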

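Prefix search is one of the most important applications of the ternary search tree. Given a prefix string, the goal is to find every complete word in the tree that starts with that prefix.

Prefix search takes two steps:

  1. Locate the prefix node: follow the search path to the node that matches the last character of the prefix
  2. Collect all continuations: starting from that node's middle child, traverse the whole subtree (left, middle, and right links) and collect every complete word; if the prefix node itself is marked as the end of a word, the prefix is also a result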

C++ Prefix Search

#include <iostream>
#include <string>
#include <vector>
using namespace std;

// Helper: collect all words from a subtree
void collect(TSTNode* node, string prefix, vector<string>& results) {
    if (node == nullptr) return;

    // Traverse left subtree first (smaller characters)
    collect(node->left, prefix, results);

    // If current node is end of word, add to results
    if (node->isEndOfWord)
        results.push_back(prefix + node->data);

    // Traverse middle subtree (continuing the word)
    collect(node->middle, prefix + node->data, results);

    // Traverse right subtree (larger characters)
    collect(node->right, prefix, results);
}

// Find the node corresponding to the last character of prefix
TSTNode* findPrefixNode(TSTNode* root, const string& prefix) {
    TSTNode* curr = root;
    int index = 0;

    while (curr != nullptr && index < (int)prefix.length()) {
        char ch = prefix[index];

        if (ch < curr->data) {
            curr = curr->left;
        } else if (ch > curr->data) {
            curr = curr->right;
        } else {
            if (index == (int)prefix.length() - 1)
                return curr;            // found the last char of prefix
            curr = curr->middle;
            index++;
        }
    }
    return nullptr;                     // prefix not found
}

// Prefix search: find all words starting with given prefix
vector<string> prefixSearch(TSTNode* root, const string& prefix) {
    vector<string> results;

    // Step 1: find the node for the last character of prefix
    TSTNode* prefixNode = findPrefixNode(root, prefix);
    if (prefixNode == nullptr) return results;

    // Step 2: if prefix itself is a word, add it
    if (prefixNode->isEndOfWord)
        results.push_back(prefix);

    // Step 3: collect all words from the middle subtree
    collect(prefixNode->middle, prefix, results);

    return results;
}

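The C++ version splits prefix search into helper functions: findPrefixNode locates the node for the last character of the prefix, collect gathers every complete word in the subtree below that node's middle child, and prefixSearch combines the two. collect visits each node's left, middle, and right subtrees in that order (an in-order traversal), so the results come out in lexicographic order.

C Prefix Search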

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>

typedef struct TSTNode {
    char data;
    bool isEndOfWord;
    struct TSTNode* left;
    struct TSTNode* middle;
    struct TSTNode* right;
} TSTNode;

// (createNode and insert functions as defined earlier)

// Helper: collect all words from a subtree into an array
void collect(TSTNode* node, char* prefix, char results[][256], int* count) {
    if (node == NULL) return;

    // Traverse left subtree first
    collect(node->left, prefix, results, count);

    // If current node is end of word, add to results
    if (node->isEndOfWord) {
        sprintf(results[*count], "%s%c", prefix, node->data);
        (*count)++;
    }

    // Traverse middle subtree (continuing the word)
    int len = strlen(prefix);
    prefix[len] = node->data;
    prefix[len + 1] = '\0';
    collect(node->middle, prefix, results, count);
    prefix[len] = '\0';  // backtrack

    // Traverse right subtree
    collect(node->right, prefix, results, count);
}

// Find the node corresponding to the last character of prefix
TSTNode* findPrefixNode(TSTNode* root, const char* prefix) {
    TSTNode* curr = root;
    int index = 0;
    int len = strlen(prefix);

    while (curr != NULL && index < len) {
        char ch = prefix[index];

        if (ch < curr->data) {
            curr = curr->left;
        } else if (ch > curr->data) {
            curr = curr->right;
        } else {
            if (index == len - 1)
                return curr;
            curr = curr->middle;
            index++;
        }
    }
    return NULL;
}

// Prefix search: find all words starting with given prefix
int prefixSearch(TSTNode* root, const char* prefix, char results[][256]) {
    int count = 0;

    TSTNode* prefixNode = findPrefixNode(root, prefix);
    if (prefixNode == NULL) return 0;

    // If prefix itself is a word, add it
    if (prefixNode->isEndOfWord) {
        strcpy(results[count], prefix);
        count++;
    }

    // Collect all words from the middle subtree
    char buffer[256] = "";
    strcpy(buffer, prefix);
    collect(prefixNode->middle, buffer, results, &count);

    return count;
}

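The C version stores results in a fixed-size two-dimensional char array. collect appends the node's character to the prefix buffer before descending into the middle subtree and truncates the buffer again afterwards (backtracking) to build each path. With 256-byte rows, each word is limited to 255 characters, and the caller must provide enough rows for all matches.

Python Prefix Search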

def _collect(node, prefix, results):
    """Helper: collect all words from a subtree."""
    if node is None:
        return

    # Traverse left subtree first (smaller characters)
    _collect(node.left, prefix, results)

    # If current node is end of word, add to results
    if node.is_end_of_word:
        results.append(prefix + node.data)

    # Traverse middle subtree (continuing the word)
    _collect(node.middle, prefix + node.data, results)

    # Traverse right subtree (larger characters)
    _collect(node.right, prefix, results)


def _find_prefix_node(node, prefix):
    """Find the node corresponding to the last character of prefix."""
    curr = node
    index = 0

    while curr is not None and index < len(prefix):
        ch = prefix[index]

        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        else:
            if index == len(prefix) - 1:
                return curr
            curr = curr.middle
            index += 1

    return None


def prefix_search(node, prefix):
    """Prefix search: find all words starting with given prefix."""
    results = []

    prefix_node = _find_prefix_node(node, prefix)
    if prefix_node is None:
        return results

    # If prefix itself is a word, add it
    if prefix_node.is_end_of_word:
        results.append(prefix)

    # Collect all words from the middle subtree
    _collect(prefix_node.middle, prefix, results)

    return results

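The Python version stores results in a list that grows as needed. Paths are built with prefix + node.data, which creates a new string at each step, so no manual backtracking is required.

A prefix search test on the same six words prints: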

Prefix search "ca": car, care, cat
Prefix search "do": do, dog
Prefix search "b": bat
Prefix search "z": (no results)
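The test driver behind the output above is not shown in this section. The sketch below is one self-contained way to reproduce it, assuming the same node class and insert as before. It folds the two steps into a single prefix_search function with a nested collect helper, so it is a compact restatement for illustration rather than the exact code above.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def prefix_search(root, prefix):
    """Return all stored words that start with prefix, in sorted order."""
    results = []
    if not prefix:
        return results  # like the code above, an empty prefix matches nothing

    # Step 1: walk to the node that matches the last character of prefix
    curr, index = root, 0
    while curr is not None:
        ch = prefix[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        elif index == len(prefix) - 1:
            break
        else:
            curr = curr.middle
            index += 1
    if curr is None:
        return results

    # Step 2: the prefix itself, then everything below the middle child
    if curr.is_end_of_word:
        results.append(prefix)

    def collect(node, path):
        if node is None:
            return
        collect(node.left, path)
        if node.is_end_of_word:
            results.append(path + node.data)
        collect(node.middle, path + node.data)
        collect(node.right, path)

    collect(curr.middle, prefix)
    return results


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert(root, w)

for p in ["ca", "do", "b", "z"]:
    matches = prefix_search(root, p)
    shown = ", ".join(matches) if matches else "(no results)"
    print(f'Prefix search "{p}": {shown}')
```

Returning an empty list for an empty prefix mirrors the behavior of the functions above; treating the empty prefix as "all words" would be a reasonable alternative.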

Go Prefix Search

func collect(node *TSTNode, prefix string, results *[]string) {
    if node == nil {
        return
    }

    // Traverse left subtree first (smaller characters)
    collect(node.left, prefix, results)

    // If current node is end of word, add to results
    if node.isEnd {
        *results = append(*results, prefix+string(node.data))
    }

    // Traverse middle subtree (continuing the word)
    collect(node.middle, prefix+string(node.data), results)

    // Traverse right subtree (larger characters)
    collect(node.right, prefix, results)
}

func findPrefixNode(root *TSTNode, prefix string) *TSTNode {
    curr := root
    index := 0

    for curr != nil && index < len(prefix) {
        ch := prefix[index]
        if ch < curr.data {
            curr = curr.left
        } else if ch > curr.data {
            curr = curr.right
        } else {
            if index == len(prefix)-1 {
                return curr
            }
            curr = curr.middle
            index++
        }
    }
    return nil
}

func prefixSearch(root *TSTNode, prefix string) []string {
    var results []string

    prefixNode := findPrefixNode(root, prefix)
    if prefixNode == nil {
        return results
    }

    // If prefix itself is a word, add it
    if prefixNode.isEnd {
        results = append(results, prefix)
    }

    // Collect all words from the middle subtree
    collect(prefixNode.middle, prefix, &results)

    return results
}

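The Go version splits prefix search into three functions: findPrefixNode locates the node for the last character of the prefix, collect gathers every complete word in the subtree below that node's middle child, and prefixSearch combines the two. Note that a byte is converted to a string with string(node.data).


Traversing All Words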

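Traversal collects every word stored in the ternary search tree. For each node, first traverse the left subtree, then handle the node itself (record a word if the node is marked as the end of a word, and continue down the middle child with the node's character appended), and finally traverse the right subtree. This is a variant of in-order traversal, so the words are produced in lexicographic order.

C++ Traversal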

// Traverse the TST and collect all words
void traverse(TSTNode* node, string buffer, vector<string>& results) {
    if (node == nullptr) return;

    // Visit left subtree (smaller characters)
    traverse(node->left, buffer, results);

    // If current node marks end of word, save the word
    if (node->isEndOfWord)
        results.push_back(buffer + node->data);

    // Visit middle subtree, appending current character
    traverse(node->middle, buffer + node->data, results);

    // Visit right subtree (larger characters)
    traverse(node->right, buffer, results);
}

// Wrapper function to get all words sorted
vector<string> getAllWords(TSTNode* root) {
    vector<string> results;
    traverse(root, "", results);
    return results;
}

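The C++ version traverses recursively. The buffer parameter accumulates the characters along the middle links, and when a node marked as the end of a word is reached, buffer + node->data is added to the results.

C Traversal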

// Traverse the TST and collect all words
void traverse(TSTNode* node, char* buffer, int depth,
              char results[][256], int* count) {
    if (node == NULL) return;

    // Visit left subtree (smaller characters)
    traverse(node->left, buffer, depth, results, count);

    // If current node marks end of word, save the word
    if (node->isEndOfWord) {
        buffer[depth] = node->data;
        buffer[depth + 1] = '\0';
        strcpy(results[*count], buffer);
        (*count)++;
    }

    // Visit middle subtree, appending current character
    buffer[depth] = node->data;
    traverse(node->middle, buffer, depth + 1, results, count);

    // Visit right subtree (larger characters)
    traverse(node->right, buffer, depth, results, count);
}

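The C version tracks the current depth with the depth parameter and modifies buffer in place during the recursion. The '\0' terminator must be written manually before a word is saved.

Python Traversal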

def traverse(node, buffer="", results=None):
    """Traverse the TST and collect all words."""
    if results is None:
        results = []
    if node is None:
        return results

    # Visit left subtree (smaller characters)
    traverse(node.left, buffer, results)

    # If current node marks end of word, save the word
    if node.is_end_of_word:
        results.append(buffer + node.data)

    # Visit middle subtree, appending current character
    traverse(node.middle, buffer + node.data, results)

    # Visit right subtree (larger characters)
    traverse(node.right, buffer, results)

    return results

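The Python version is the most compact: the default argument buffer="" accumulates characters, and the result list is passed along through the results parameter.

Running the traversal on the same six words prints: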

All words: bat, car, care, cat, do, dog
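As a variation on the list-based traverse, the same in-order walk can be written as a generator, which yields words one at a time instead of building the whole list first. This is only a sketch under the same assumptions as before (the node class and insert from this chapter); iter_words is a hypothetical name introduced here.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def iter_words(node, buffer=""):
    """Yield every stored word in lexicographic order (hypothetical helper)."""
    if node is None:
        return
    # Smaller characters at this position come first
    yield from iter_words(node.left, buffer)
    # The word ending at this node precedes all of its extensions
    if node.is_end_of_word:
        yield buffer + node.data
    # Extensions share the current character, so append it to the buffer
    yield from iter_words(node.middle, buffer + node.data)
    # Larger characters at this position come last
    yield from iter_words(node.right, buffer)


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert(root, w)

print("All words: " + ", ".join(iter_words(root)))
```

Because the generator is lazy, a caller that needs only the first few words (for example, the top suggestions in an autocomplete box) can stop iterating early without visiting the rest of the tree.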

Go Traversal

func traverse(node *TSTNode, buffer string, results *[]string) {
    if node == nil {
        return
    }

    // Visit left subtree (smaller characters)
    traverse(node.left, buffer, results)

    // If current node marks end of word, save the word
    if node.isEnd {
        *results = append(*results, buffer+string(node.data))
    }

    // Visit middle subtree, appending current character
    traverse(node.middle, buffer+string(node.data), results)

    // Visit right subtree (larger characters)
    traverse(node.right, buffer, results)
}

func getAllWords(root *TSTNode) []string {
    var results []string
    traverse(root, "", &results)
    return results
}

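The Go version implements the in-order traversal recursively. The buffer parameter accumulates the characters along the middle links, and when a node marked as the end of a word is reached, buffer + string(node.data) is appended to the results slice. getAllWords is a wrapper that returns all words in lexicographic (byte) order.


Complete Implementation

The complete implementations below combine insertion, search, prefix search, and traversal of all words.

C++ Complete Implementation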

#include <iostream>
#include <string>
#include <vector>
using namespace std;

struct TSTNode {
    char data;
    bool isEndOfWord;
    TSTNode* left;
    TSTNode* middle;
    TSTNode* right;

    TSTNode(char ch) : data(ch), isEndOfWord(false),
                       left(nullptr), middle(nullptr), right(nullptr) {}
};

// Insert a word into the TST
TSTNode* insert(TSTNode* node, const string& word, int index = 0) {
    char ch = word[index];
    if (node == nullptr)
        node = new TSTNode(ch);

    if (ch < node->data) {
        node->left = insert(node->left, word, index);
    } else if (ch > node->data) {
        node->right = insert(node->right, word, index);
    } else {
        if (index < (int)word.length() - 1)
            node->middle = insert(node->middle, word, index + 1);
        else
            node->isEndOfWord = true;
    }
    return node;
}

// Search for a word in the TST
bool search(TSTNode* root, const string& word) {
    TSTNode* curr = root;
    int index = 0;
    while (curr != nullptr && index < (int)word.length()) {
        char ch = word[index];
        if (ch < curr->data)        curr = curr->left;
        else if (ch > curr->data)   curr = curr->right;
        else {
            if (index == (int)word.length() - 1)
                return curr->isEndOfWord;
            curr = curr->middle;
            index++;
        }
    }
    return false;
}

// Collect all words from a subtree
void collect(TSTNode* node, const string& prefix, vector<string>& results) {
    if (node == nullptr) return;
    collect(node->left, prefix, results);
    if (node->isEndOfWord)
        results.push_back(prefix + node->data);
    collect(node->middle, prefix + node->data, results);
    collect(node->right, prefix, results);
}

// Prefix search: find all words starting with given prefix
vector<string> prefixSearch(TSTNode* root, const string& prefix) {
    vector<string> results;
    TSTNode* curr = root;
    int index = 0;

    // Navigate to the last character of the prefix
    while (curr != nullptr && index < (int)prefix.length()) {
        char ch = prefix[index];
        if (ch < curr->data)        curr = curr->left;
        else if (ch > curr->data)   curr = curr->right;
        else {
            if (index == (int)prefix.length() - 1) {
                if (curr->isEndOfWord)
                    results.push_back(prefix);
                collect(curr->middle, prefix, results);
                return results;
            }
            curr = curr->middle;
            index++;
        }
    }
    return results;
}

// Traverse and collect all words
void traverse(TSTNode* node, const string& buffer, vector<string>& results) {
    if (node == nullptr) return;
    traverse(node->left, buffer, results);
    if (node->isEndOfWord)
        results.push_back(buffer + node->data);
    traverse(node->middle, buffer + node->data, results);
    traverse(node->right, buffer, results);
}

int main() {
    TSTNode* root = nullptr;

    // Insert words
    string words[] = {"cat", "car", "care", "dog", "do", "bat"};
    for (const string& w : words)
        root = insert(root, w);

    // Search test
    cout << "=== Search ===" << endl;
    string queries[] = {"cat", "ca", "cow", "dog", "do", "bat"};
    for (const string& q : queries) {
        cout << "search(\"" << q << "\"): "
             << (search(root, q) ? "Found" : "Not Found") << endl;
    }

    // Prefix search test
    cout << "\n=== Prefix Search ===" << endl;
    string prefixes[] = {"ca", "do", "b", "z"};
    for (const string& p : prefixes) {
        vector<string> matches = prefixSearch(root, p);
        cout << "prefixSearch(\"" << p << "\"): ";
        for (size_t i = 0; i < matches.size(); i++) {
            if (i > 0) cout << ", ";
            cout << matches[i];
        }
        cout << endl;
    }

    // Traverse all words
    cout << "\n=== All Words ===" << endl;
    vector<string> allWords;
    traverse(root, "", allWords);
    cout << "All words: ";
    for (size_t i = 0; i < allWords.size(); i++) {
        if (i > 0) cout << ", ";
        cout << allWords[i];
    }
    cout << endl;

    return 0;
}

C Complete Implementation

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>

typedef struct TSTNode {
    char data;
    bool isEndOfWord;
    struct TSTNode* left;
    struct TSTNode* middle;
    struct TSTNode* right;
} TSTNode;

// Create a new node
TSTNode* createNode(char ch) {
    TSTNode* node = (TSTNode*)malloc(sizeof(TSTNode));
    node->data = ch;
    node->isEndOfWord = false;
    node->left = NULL;
    node->middle = NULL;
    node->right = NULL;
    return node;
}

// Insert a word into the TST
TSTNode* insert(TSTNode* node, const char* word, int index) {
    char ch = word[index];
    if (node == NULL)
        node = createNode(ch);

    if (ch < node->data) {
        node->left = insert(node->left, word, index);
    } else if (ch > node->data) {
        node->right = insert(node->right, word, index);
    } else {
        if (index < (int)strlen(word) - 1)
            node->middle = insert(node->middle, word, index + 1);
        else
            node->isEndOfWord = true;
    }
    return node;
}

// Search for a word in the TST
bool search(TSTNode* root, const char* word) {
    TSTNode* curr = root;
    int index = 0;
    int len = strlen(word);
    while (curr != NULL && index < len) {
        char ch = word[index];
        if (ch < curr->data)        curr = curr->left;
        else if (ch > curr->data)   curr = curr->right;
        else {
            if (index == len - 1)
                return curr->isEndOfWord;
            curr = curr->middle;
            index++;
        }
    }
    return false;
}

// Collect all words from a subtree
void collect(TSTNode* node, char* prefix, char results[][256], int* count) {
    if (node == NULL) return;
    collect(node->left, prefix, results, count);
    if (node->isEndOfWord) {
        sprintf(results[*count], "%s%c", prefix, node->data);
        (*count)++;
    }
    int len = strlen(prefix);
    prefix[len] = node->data;
    prefix[len + 1] = '\0';
    collect(node->middle, prefix, results, count);
    prefix[len] = '\0';
    collect(node->right, prefix, results, count);
}

// Prefix search
int prefixSearch(TSTNode* root, const char* prefix, char results[][256]) {
    int count = 0;
    TSTNode* curr = root;
    int index = 0;
    int len = strlen(prefix);

    while (curr != NULL && index < len) {
        char ch = prefix[index];
        if (ch < curr->data)        curr = curr->left;
        else if (ch > curr->data)   curr = curr->right;
        else {
            if (index == len - 1) {
                if (curr->isEndOfWord) {
                    strcpy(results[count], prefix);
                    count++;
                }
                char buffer[256] = "";
                strcpy(buffer, prefix);
                collect(curr->middle, buffer, results, &count);
                return count;
            }
            curr = curr->middle;
            index++;
        }
    }
    return count;
}

// Traverse and collect all words
void traverse(TSTNode* node, char* buffer, int depth,
              char results[][256], int* count) {
    if (node == NULL) return;
    traverse(node->left, buffer, depth, results, count);
    if (node->isEndOfWord) {
        buffer[depth] = node->data;
        buffer[depth + 1] = '\0';
        strcpy(results[*count], buffer);
        (*count)++;
    }
    buffer[depth] = node->data;
    traverse(node->middle, buffer, depth + 1, results, count);
    traverse(node->right, buffer, depth, results, count);
}

int main() {
    TSTNode* root = NULL;

    // Insert words
    const char* words[] = {"cat", "car", "care", "dog", "do", "bat"};
    int n = sizeof(words) / sizeof(words[0]);
    for (int i = 0; i < n; i++)
        root = insert(root, words[i], 0);

    // Search test
    printf("=== Search ===\n");
    const char* queries[] = {"cat", "ca", "cow", "dog", "do", "bat"};
    int qn = sizeof(queries) / sizeof(queries[0]);
    for (int i = 0; i < qn; i++)
        printf("search(\"%s\"): %s\n", queries[i],
               search(root, queries[i]) ? "Found" : "Not Found");

    // Prefix search test
    printf("\n=== Prefix Search ===\n");
    const char* prefixes[] = {"ca", "do", "b", "z"};
    int pn = sizeof(prefixes) / sizeof(prefixes[0]);
    for (int i = 0; i < pn; i++) {
        char results[100][256];
        int cnt = prefixSearch(root, prefixes[i], results);
        printf("prefixSearch(\"%s\"): ", prefixes[i]);
        if (cnt == 0) printf("(none)");
        for (int j = 0; j < cnt; j++) {
            if (j > 0) printf(", ");
            printf("%s", results[j]);
        }
        printf("\n");
    }

    // Traverse all words
    printf("\n=== All Words ===\n");
    char allResults[100][256];
    int totalCount = 0;
    char buffer[256] = "";
    traverse(root, buffer, 0, allResults, &totalCount);
    printf("All words: ");
    for (int i = 0; i < totalCount; i++) {
        if (i > 0) printf(", ");
        printf("%s", allResults[i]);
    }
    printf("\n");

    return 0;
}

Complete Python Implementation

class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Insert a word into the TST."""
    if index >= len(word):  # empty word: nothing to insert
        return node
    ch = word[index]
    if node is None:
        node = TSTNode(ch)

    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    else:
        if index < len(word) - 1:
            node.middle = insert(node.middle, word, index + 1)
        else:
            node.is_end_of_word = True
    return node


def search(node, word):
    """Search for a word in the TST."""
    curr = node
    index = 0
    while curr is not None and index < len(word):
        ch = word[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        else:
            if index == len(word) - 1:
                return curr.is_end_of_word
            curr = curr.middle
            index += 1
    return False


def _collect(node, prefix, results):
    """Collect all words from a subtree."""
    if node is None:
        return
    _collect(node.left, prefix, results)
    if node.is_end_of_word:
        results.append(prefix + node.data)
    _collect(node.middle, prefix + node.data, results)
    _collect(node.right, prefix, results)


def prefix_search(node, prefix):
    """Prefix search: find all words starting with given prefix."""
    results = []
    curr = node
    index = 0

    # Navigate to the last character of the prefix
    while curr is not None and index < len(prefix):
        ch = prefix[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        else:
            if index == len(prefix) - 1:
                if curr.is_end_of_word:
                    results.append(prefix)
                _collect(curr.middle, prefix, results)
                return results
            curr = curr.middle
            index += 1

    return results


def traverse(node, buffer="", results=None):
    """Traverse and collect all words."""
    if results is None:
        results = []
    if node is None:
        return results
    traverse(node.left, buffer, results)
    if node.is_end_of_word:
        results.append(buffer + node.data)
    traverse(node.middle, buffer + node.data, results)
    traverse(node.right, buffer, results)
    return results


if __name__ == "__main__":
    root = None

    # Insert words
    words = ["cat", "car", "care", "dog", "do", "bat"]
    for w in words:
        root = insert(root, w)

    # Search test
    print("=== Search ===")
    for q in ["cat", "ca", "cow", "dog", "do", "bat"]:
        status = "Found" if search(root, q) else "Not Found"
        print(f'search("{q}"): {status}')

    # Prefix search test
    print("\n=== Prefix Search ===")
    for p in ["ca", "do", "b", "z"]:
        matches = prefix_search(root, p)
        print(f'prefixSearch("{p}"): {", ".join(matches) if matches else "(none)"}')

    # Traverse all words
    print("\n=== All Words ===")
    all_words = traverse(root)
    print(f"All words: {', '.join(all_words)}")

Running this program prints:

=== Search ===
search("cat"): Found
search("ca"): Not Found
search("cow"): Not Found
search("dog"): Found
search("do"): Found
search("bat"): Found

=== Prefix Search ===
prefixSearch("ca"): car, care, cat
prefixSearch("do"): do, dog
prefixSearch("b"): bat
prefixSearch("z"): (none)

=== All Words ===
All words: bat, car, care, cat, do, dog

Complete Go Implementation

package main

import "fmt"

type TSTNode struct {
    data   byte
    isEnd  bool
    left   *TSTNode
    middle *TSTNode
    right  *TSTNode
}

// Insert a word into the TST
func insert(node *TSTNode, word string, index int) *TSTNode {
    if index >= len(word) { // empty word: nothing to insert
        return node
    }
    ch := word[index]
    if node == nil {
        node = &TSTNode{data: ch}
    }

    if ch < node.data {
        node.left = insert(node.left, word, index)
    } else if ch > node.data {
        node.right = insert(node.right, word, index)
    } else {
        if index < len(word)-1 {
            node.middle = insert(node.middle, word, index+1)
        } else {
            node.isEnd = true
        }
    }
    return node
}

// Search for a word in the TST
func search(root *TSTNode, word string) bool {
    curr := root
    index := 0
    for curr != nil && index < len(word) {
        ch := word[index]
        if ch < curr.data {
            curr = curr.left
        } else if ch > curr.data {
            curr = curr.right
        } else {
            if index == len(word)-1 {
                return curr.isEnd
            }
            curr = curr.middle
            index++
        }
    }
    return false
}

// Collect all words from a subtree
func collect(node *TSTNode, prefix string, results *[]string) {
    if node == nil {
        return
    }
    collect(node.left, prefix, results)
    if node.isEnd {
        *results = append(*results, prefix+string(node.data))
    }
    collect(node.middle, prefix+string(node.data), results)
    collect(node.right, prefix, results)
}

// Prefix search: find all words starting with given prefix
func prefixSearch(root *TSTNode, prefix string) []string {
    var results []string
    curr := root
    index := 0

    // Navigate to the last character of the prefix
    for curr != nil && index < len(prefix) {
        ch := prefix[index]
        if ch < curr.data {
            curr = curr.left
        } else if ch > curr.data {
            curr = curr.right
        } else {
            if index == len(prefix)-1 {
                if curr.isEnd {
                    results = append(results, prefix)
                }
                collect(curr.middle, prefix, &results)
                return results
            }
            curr = curr.middle
            index++
        }
    }
    return results
}

// Traverse and collect all words
func traverse(node *TSTNode, buffer string, results *[]string) {
    if node == nil {
        return
    }
    traverse(node.left, buffer, results)
    if node.isEnd {
        *results = append(*results, buffer+string(node.data))
    }
    traverse(node.middle, buffer+string(node.data), results)
    traverse(node.right, buffer, results)
}

func main() {
    var root *TSTNode

    // Insert words
    words := []string{"cat", "car", "care", "dog", "do", "bat"}
    for _, w := range words {
        root = insert(root, w, 0)
    }

    // Search test
    fmt.Println("=== Search ===")
    queries := []string{"cat", "ca", "cow", "dog", "do", "bat"}
    for _, q := range queries {
        status := "Not Found"
        if search(root, q) {
            status = "Found"
        }
        fmt.Printf("search(\"%s\"): %s\n", q, status)
    }

    // Prefix search test
    fmt.Println("\n=== Prefix Search ===")
    prefixes := []string{"ca", "do", "b", "z"}
    for _, p := range prefixes {
        matches := prefixSearch(root, p)
        if len(matches) == 0 {
            fmt.Printf("prefixSearch(\"%s\"): (none)\n", p)
        } else {
            fmt.Printf("prefixSearch(\"%s\"): %s\n", p, joinStrings(matches))
        }
    }

    // Traverse all words
    fmt.Println("\n=== All Words ===")
    var allWords []string
    traverse(root, "", &allWords)
    fmt.Printf("All words: %s\n", joinStrings(allWords))
}

func joinStrings(ss []string) string {
    result := ""
    for i, s := range ss {
        if i > 0 {
            result += ", "
        }
        result += s
    }
    return result
}

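The complete Go implementation covers insertion, search, prefix search, and traversal of all words. The joinStrings helper joins a string slice with ", "; the standard library's strings.Join(ss, ", ") would do the same job. The core logic matches the C++ and Python versions and manipulates the tree through pointers.

Running this program prints: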

=== Search ===
search("cat"): Found
search("ca"): Not Found
search("cow"): Not Found
search("dog"): Found
search("do"): Found
search("bat"): Found

=== Prefix Search ===
prefixSearch("ca"): car, care, cat
prefixSearch("do"): do, dog
prefixSearch("b"): bat
prefixSearch("z"): (none)

=== All Words ===
All words: bat, car, care, cat, do, dog

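Properties of Ternary Search Trees

Space Efficiency

The most notable advantage of a ternary search tree is its space efficiency. Each node of a standard Trie maintains 26 child pointers (for lowercase letters) and allocates all of them even when most are null. A TST node has only 3 pointers (left, middle, right), so the per-node storage overhead is much lower: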

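| Data structure | Pointers per node | Unused pointers |
|---|---|---|
| Standard Trie | 26 | Many (most of the 26 are null) |
| Ternary search tree | 3 | Few in absolute terms (at most 3 per node) |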
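The difference can be made concrete with a rough count. The sketch below is only an illustration under stated assumptions: it counts conceptual pointer slots (3 per TST node, 26 per array-based Trie node) for the six sample words used in this article, and the helper names `tst_insert`, `count_tst_nodes`, and `count_trie_nodes` are made up for this example.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = self.middle = self.right = None


def tst_insert(node, word, index=0):
    """Same insertion logic as the full implementation above."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = tst_insert(node.left, word, index)
    elif ch > node.data:
        node.right = tst_insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = tst_insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def count_tst_nodes(node):
    """Count every node reachable from `node`."""
    if node is None:
        return 0
    return (1 + count_tst_nodes(node.left)
              + count_tst_nodes(node.middle)
              + count_tst_nodes(node.right))


def count_trie_nodes(words):
    """A Trie has one node per distinct prefix, including the empty root."""
    return len({w[:i] for w in words for i in range(len(w) + 1)})


words = ["cat", "car", "care", "dog", "do", "bat"]
root = None
for w in words:
    root = tst_insert(root, w)

tst_nodes = count_tst_nodes(root)
trie_nodes = count_trie_nodes(words)
tst_slots = 3 * tst_nodes      # left, middle, right
trie_slots = 26 * trie_nodes   # one slot per lowercase letter

print(f"TST:  {tst_nodes} nodes, {tst_slots} pointer slots")    # TST:  11 nodes, 33 pointer slots
print(f"Trie: {trie_nodes} nodes, {trie_slots} pointer slots")  # Trie: 12 nodes, 312 pointer slots
```

The exact ratio depends on the word set and on how the Trie stores its children; a Trie that keeps children in a hash map instead of a fixed array would narrow the gap.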

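Time Complexity

| Operation | Average time | Worst-case time |
|---|---|---|
| Insert | O(L + log N) | O(L · min(N, σ)) |
| Search | O(L + log N) | O(L · min(N, σ)) |
| Prefix Search | O(L + log N + M) | O(L · min(N, σ) + M) |

Here L is the length of the string, N is the number of strings stored in the tree, σ is the size of the character set, and M is the total number of characters in the matching results.

When the tree is reasonably balanced, a lookup needs about L + log N character comparisons, which is close to O(L) once L dominates. In the worst case the left/right links at each level degenerate into a linked list (for example, when keys are inserted in sorted order), so every character may cost up to min(N, σ) comparisons and the total degrades to O(L · min(N, σ)).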
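To illustrate how insertion order affects lookup cost, the following sketch builds two trees over the 26 single-letter strings: one inserted in sorted order and one inserted median-first. The helpers `nodes_visited` and `median_first` are hypothetical names introduced only for this experiment; `insert` repeats the logic from the full implementation.

```python
import string


class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = self.middle = self.right = None


def insert(node, word, index=0):
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def nodes_visited(node, word):
    """Count the nodes touched while searching for `word`."""
    visited, index = 0, 0
    while node is not None and index < len(word):
        visited += 1
        ch = word[index]
        if ch < node.data:
            node = node.left
        elif ch > node.data:
            node = node.right
        else:
            if index == len(word) - 1:
                return visited
            node = node.middle
            index += 1
    return visited


def median_first(keys):
    """Order sorted keys so that each subtree root is inserted first."""
    if not keys:
        return []
    mid = len(keys) // 2
    return [keys[mid]] + median_first(keys[:mid]) + median_first(keys[mid + 1:])


letters = list(string.ascii_lowercase)

sorted_root = None
for w in letters:                 # sorted order: every insert goes right
    sorted_root = insert(sorted_root, w)

balanced_root = None
for w in median_first(letters):   # median-first order: balanced tree
    balanced_root = insert(balanced_root, w)

print(nodes_visited(sorted_root, "z"))    # 26
print(nodes_visited(balanced_root, "z"))  # 4
```

With sorted insertion the search for "z" walks the whole right-hand chain, while the median-first tree never needs more than 5 steps for any letter.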

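Trie vs TST vs BST Comparison

| Property | Trie | Ternary search tree (TST) | Binary search tree (BST) |
|---|---|---|---|
| Pointers per node | 26 (fixed) | 3 | 2 (or 3 with a parent pointer) |
| Lookup complexity | O(L) | O(L + log N) when balanced | O(L log N) when balanced |
| Prefix search | Built in | Built in | Not directly supported |
| Space overhead | High | Medium | Low |
| Shared common prefixes | Fully shared | Fully shared | Not shared |
| Suitable character set | Small, fixed character set | Any character set | Any data type |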

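The ternary search tree strikes a good balance between the Trie and the BST: it keeps the Trie's prefix sharing and prefix search while approaching the space efficiency of a BST. The advantage grows when the character set is large (such as Unicode) or the strings are sparse.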
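As a small illustration of the large-character-set point, the sketch below stores a mix of Chinese and ASCII strings in the same Python TST without changing the node layout. It relies on Python comparing single-character strings by Unicode code point; the word list is an arbitrary example, and `collect` mirrors the `_collect` helper from the full implementation.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = self.middle = self.right = None


def insert(node, word, index=0):
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:          # compared by Unicode code point
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def collect(node, prefix, results):
    """In-order walk: left subtree, this node's word, middle, right."""
    if node is None:
        return
    collect(node.left, prefix, results)
    if node.is_end_of_word:
        results.append(prefix + node.data)
    collect(node.middle, prefix + node.data, results)
    collect(node.right, prefix, results)


words = ["数据", "数组", "树", "图", "data"]
root = None
for w in words:
    root = insert(root, w)

all_words = []
collect(root, "", all_words)
print(all_words == sorted(words))  # True
```

Each node still carries exactly three links, whereas an array-based Trie would need one child slot per possible character.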

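Practical Applications

Ternary search trees appear in a wide range of engineering scenarios:

  • Spell Checking: quickly test whether a word is in the dictionary; can also be used for fuzzy matching
  • Autocomplete: after the user types a prefix, a prefix search returns the candidate words in real time
  • IP Routing Table: use prefix search to perform longest-prefix matching
  • Genomic Sequence Analysis: search for pattern strings in DNA sequences
  • Full-text Search: serve as an auxiliary data structure for an inverted index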
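The longest-prefix matching mentioned in the list above can be sketched on top of the same node structure. This is a string-based illustration only (real routing tables match on address bits), and `longest_prefix` is a hypothetical helper that is not part of the implementations shown earlier.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = self.middle = self.right = None


def insert(node, word, index=0):
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def longest_prefix(root, query):
    """Return the longest stored word that is a prefix of `query` ('' if none)."""
    best = 0                      # length of the longest match seen so far
    node, index = root, 0
    while node is not None and index < len(query):
        ch = query[index]
        if ch < node.data:
            node = node.left
        elif ch > node.data:
            node = node.right
        else:
            index += 1            # query[:index] is now matched
            if node.is_end_of_word:
                best = index
            node = node.middle
    return query[:best]


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert(root, w)

print(longest_prefix(root, "doghouse"))  # dog
print(longest_prefix(root, "dot"))       # do
print(longest_prefix(root, "careful"))   # care
```

The walk is the same as an ordinary search, except that it remembers the last end-of-word node it passed instead of requiring the whole query to match.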
posted @ 2026-04-16 20:52  游翔