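3-11 Ternary Search Tree (a Trie Whose Children Form Binary Search Trees)
Ternary Search Tree
A ternary search tree (TST) is a tree data structure that combines features of a trie (prefix tree) and a binary search tree (BST). Each node stores one character and has three child pointers: the left child leads to nodes with a smaller character, the middle child leads to the node for the next character of the string, and the right child leads to nodes with a larger character.
Whereas a standard trie over lowercase letters keeps 26 child pointers in every node, a TST node has only 3, which saves considerable space when most trie slots would be empty. TSTs are well suited to storing sets of strings and are used for spell checking, autocomplete, and near-neighbor search.
Below is a ternary search tree holding the four words "cat", "car", "care", and "dog":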
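        c
        |     \
        a      d
        |      |
        t[E]   o
       /       |
    r[E]       g[E]
    |
    e[E]
A vertical bar is a middle link (next character of the word); a slash or backslash is a left or right link (a smaller or larger character at the same position). [E] marks the end of a word.
The nodes read as follows:
- The root stores 'c'. Its left subtree would hold first characters smaller than 'c' (empty here), its right subtree holds first characters larger than 'c' (here 'd'), and its middle child continues the match with 'a'.
- t[E]: 't' ends the word "cat".
- r[E]: 'r' ends the word "car"; it hangs to the left of 't' because 'r' < 't'.
- e[E]: 'e' ends the word "care".
- g[E]: 'g' ends the word "dog".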
Note that "cat" and "car" share the prefix "ca", and "car" and "care" share the prefix "car". Like a trie, a TST saves space through shared prefixes, but each node carries 3 pointers instead of 26.
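To make the shape concrete, the following minimal sketch builds the same four-word tree and walks its links by hand. It assumes a Python node class and a recursive insert helper equivalent to the ones this article defines in later sections; the variable names a, t, r, e, d, g and shape are only illustrative.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch               # character stored in this node
        self.is_end_of_word = False  # True if a word ends here
        self.left = None             # smaller character, same position
        self.middle = None           # next character of the word
        self.right = None            # larger character, same position


def insert(node, word, index=0):
    """Insert a non-empty word and return the (possibly new) subtree root."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


root = None
for w in ["cat", "car", "care", "dog"]:
    root = insert(root, w)

a = root.middle       # 'a' continues "ca..."
t = a.middle          # 't' ends "cat"
r = t.left            # 'r' < 't', ends "car"
e = r.middle          # 'e' ends "care"
d = root.right        # 'd' > 'c'
g = d.middle.middle   # 'g' ends "dog"

# (character, is_end_of_word) for each node on the paths above
shape = [(n.data, n.is_end_of_word) for n in (root, a, t, r, e, d, g)]
print(shape)
```

The end-of-word flags land on t, r, e, and g, which matches the four stored words.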
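Node Definition
A TST node stores one character, a Boolean flag marking whether a word ends at this node, and three child pointers (left, middle, right).
C++ Node Definition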
struct TSTNode {
char data; // character stored in this node
bool isEndOfWord; // true if this node ends a word
TSTNode* left; // child for characters less than data
TSTNode* middle; // child for the next character (equal)
TSTNode* right; // child for characters greater than data
TSTNode(char ch) : data(ch), isEndOfWord(false),
left(nullptr), middle(nullptr), right(nullptr) {}
};
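The C++ version uses a struct with a constructor that initializes every field. data holds the node's character, and left, middle, and right point to the subtrees for characters less than, equal to, and greater than it.
C Node Definition
#include <stdbool.h>
#include <stdlib.h> // malloc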
typedef struct TSTNode {
char data; // character stored in this node
bool isEndOfWord; // true if this node ends a word
struct TSTNode* left; // child for characters less than data
struct TSTNode* middle; // child for the next character (equal)
struct TSTNode* right; // child for characters greater than data
} TSTNode;
// Create and initialize a new TST node with given character
TSTNode* createNode(char ch) {
TSTNode* node = (TSTNode*)malloc(sizeof(TSTNode));
node->data = ch;
node->isEndOfWord = false;
node->left = NULL;
node->middle = NULL;
node->right = NULL;
return node;
}
The C version defines the struct with typedef and allocates nodes with malloc (declared in <stdlib.h>). createNode initializes every field; note that it does not check whether malloc returned NULL.
Python Node Definition
class TSTNode:
    def __init__(self, ch):
        self.data = ch               # character stored in this node
        self.is_end_of_word = False  # True if this node ends a word
        self.left = None             # child for characters less than data
        self.middle = None           # child for the next character (equal)
        self.right = None            # child for characters greater than data
The Python version uses a class. The three child references default to None, and child nodes are created only when needed.
Go Node Definition
type TSTNode struct {
data byte // character stored in this node
isEnd bool // true if this node ends a word
left *TSTNode // child for characters less than data
middle *TSTNode // child for the next character (equal)
right *TSTNode // child for characters greater than data
}
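The Go version uses a struct. data is a byte holding a single character, isEnd marks the end of a word, and the three pointers lead to the subtrees for characters less than, equal to, and greater than the current one.
Insertion
Insertion starts at the root and compares the current character of the word with the character stored in the node:
- The current character is smaller than the node's character: follow the left child
- The current character is larger than the node's character: follow the right child
- The current character equals the node's character: follow the middle child and move on to the next character
If the child to follow does not exist, a new node is created. After the last character has been processed, the node holding it is marked as the end of a word. The implementations below assume a non-empty word.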
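The three rules can also be applied in a loop instead of recursively. The sketch below is one possible iterative variant, not the version this article implements; the name insert_iterative and the choice to ignore empty words are assumptions made for illustration.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert_iterative(root, word):
    """Insert word without recursion and return the root of the tree."""
    if not word:                      # assumption: empty words are ignored
        return root
    if root is None:
        root = TSTNode(word[0])
    node, index = root, 0
    while True:
        ch = word[index]
        if ch < node.data:            # rule 1: smaller, go left
            if node.left is None:
                node.left = TSTNode(ch)
            node = node.left
        elif ch > node.data:          # rule 2: larger, go right
            if node.right is None:
                node.right = TSTNode(ch)
            node = node.right
        elif index == len(word) - 1:  # last character matched
            node.is_end_of_word = True
            return root
        else:                         # rule 3: equal, go to the next character
            index += 1
            if node.middle is None:
                node.middle = TSTNode(word[index])
            node = node.middle


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert_iterative(root, w)

print(root.data, root.left.data, root.right.data)
```

The loop avoids deep recursion on very long words, at the cost of handling the root and each missing child explicitly.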
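For example, insert "cat", "car", "care", "dog", "do", and "bat" in that order (here (m) marks a middle link):
- Insert "cat": create the path c(m) -> a(m) -> t and mark t as the end of a word.
- Insert "car": reuse c -> a. At t, the middle child of a, 'r' < 't', so create r as the left child of t and mark r as the end of a word.
- Insert "care": reuse c -> a -> r, create e as the middle child of r, and mark e as the end of a word.
- Insert "dog": 'd' > 'c', so follow the right child of c, create d(m) -> o(m) -> g, and mark g as the end of a word.
- Insert "do": reuse d -> o and mark o as the end of a word.
- Insert "bat": 'b' < 'c', so follow the left child of c, create b(m) -> a(m) -> t, and mark t as the end of a word.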
C++ Insertion
#include <iostream>
#include <string>
using namespace std;
struct TSTNode {
char data;
bool isEndOfWord;
TSTNode* left;
TSTNode* middle;
TSTNode* right;
TSTNode(char ch) : data(ch), isEndOfWord(false),
left(nullptr), middle(nullptr), right(nullptr) {}
};
// Insert a word into the TST (recursive helper)
TSTNode* insert(TSTNode* node, const string& word, int index) {
char ch = word[index];
// Create a new node if current position is empty
if (node == nullptr)
node = new TSTNode(ch);
// Navigate based on character comparison
if (ch < node->data) {
node->left = insert(node->left, word, index);
} else if (ch > node->data) {
node->right = insert(node->right, word, index);
} else {
// Characters match, move to next character via middle child
if (index < (int)word.length() - 1) {
node->middle = insert(node->middle, word, index + 1);
} else {
node->isEndOfWord = true; // mark end of word
}
}
return node;
}
int main() {
TSTNode* root = nullptr;
// Insert words one by one
string words[] = {"cat", "car", "care", "dog", "do", "bat"};
for (const string& w : words) {
root = insert(root, w, 0);
cout << "Inserted: " << w << endl;
}
return 0;
}
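The C++ version inserts recursively. insert takes the current node, the word to insert, and the index of the current character. The comparison decides whether to descend into the left, middle, or right subtree. When index reaches the last character of the word, the current node is marked as the end of a word.
C Insertion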
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
typedef struct TSTNode {
char data;
bool isEndOfWord;
struct TSTNode* left;
struct TSTNode* middle;
struct TSTNode* right;
} TSTNode;
// Create and initialize a new TST node with given character
TSTNode* createNode(char ch) {
TSTNode* node = (TSTNode*)malloc(sizeof(TSTNode));
node->data = ch;
node->isEndOfWord = false;
node->left = NULL;
node->middle = NULL;
node->right = NULL;
return node;
}
// Insert a word into the TST (recursive helper)
TSTNode* insert(TSTNode* node, const char* word, int index) {
char ch = word[index];
// Create a new node if current position is empty
if (node == NULL)
node = createNode(ch);
// Navigate based on character comparison
if (ch < node->data) {
node->left = insert(node->left, word, index);
} else if (ch > node->data) {
node->right = insert(node->right, word, index);
} else {
// Characters match, move to next character via middle child
if (index < (int)strlen(word) - 1) {
node->middle = insert(node->middle, word, index + 1);
} else {
node->isEndOfWord = true; // mark end of word
}
}
return node;
}
int main() {
TSTNode* root = NULL;
// Insert words one by one
const char* words[] = {"cat", "car", "care", "dog", "do", "bat"};
int n = sizeof(words) / sizeof(words[0]);
for (int i = 0; i < n; i++) {
root = insert(root, words[i], 0);
printf("Inserted: %s\n", words[i]);
}
return 0;
}
The C version follows the same logic as the C++ one, with const char* in place of string, NULL in place of nullptr, and printf in place of cout. Because strlen(word) is re-evaluated in every recursive call, passing the length as a parameter would avoid repeated scans on long words.
Python Insertion
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None

def insert(node, word, index=0):
    ch = word[index]
    # Create a new node if current position is empty
    if node is None:
        node = TSTNode(ch)
    # Navigate based on character comparison
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    else:
        # Characters match, move to next character via middle child
        if index < len(word) - 1:
            node.middle = insert(node.middle, word, index + 1)
        else:
            node.is_end_of_word = True  # mark end of word
    return node

if __name__ == "__main__":
    root = None
    # Insert words one by one
    words = ["cat", "car", "care", "dog", "do", "bat"]
    for w in words:
        root = insert(root, w)
        print(f"Inserted: {w}")
The Python version is a recursive function whose index parameter defaults to 0. The logic is identical to the C++ and C versions, only more compact.
Running the program prints:
Inserted: cat
Inserted: car
Inserted: care
Inserted: dog
Inserted: do
Inserted: bat
Go Insertion
package main
import "fmt"
type TSTNode struct {
data byte
isEnd bool
left *TSTNode
middle *TSTNode
right *TSTNode
}
func insert(node *TSTNode, word string, index int) *TSTNode {
ch := word[index]
// Create a new node if current position is empty
if node == nil {
node = &TSTNode{data: ch}
}
// Navigate based on character comparison
if ch < node.data {
node.left = insert(node.left, word, index)
} else if ch > node.data {
node.right = insert(node.right, word, index)
} else {
// Characters match, move to next character via middle child
if index < len(word)-1 {
node.middle = insert(node.middle, word, index+1)
} else {
node.isEnd = true // mark end of word
}
}
return node
}
func main() {
var root *TSTNode
// Insert words one by one
words := []string{"cat", "car", "care", "dog", "do", "bat"}
for _, w := range words {
root = insert(root, w, 0)
fmt.Printf("Inserted: %s\n", w)
}
}
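The Go version also inserts recursively. insert returns a node pointer and creates a new node when it receives nil. Characters are compared as byte values to decide between the left, middle, and right subtree.
Running the program prints: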
Inserted: cat
Inserted: car
Inserted: care
Inserted: dog
Inserted: do
Inserted: bat
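Search
Search uses exactly the same comparisons as insertion: starting at the root, compare the current character with the node's character and move to the left, middle, or right child for smaller, equal, or larger. The search succeeds if every character has been matched and the node holding the last character is marked as the end of a word.
A search ends in one of three ways:
- Complete word found: every character matches and the last node's isEndOfWord is true
- Prefix present but not a complete word: every character matches, but the last node is not marked as the end of a word
- Not found: a null pointer is reached during matching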
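The three outcomes can be told apart explicitly. The sketch below is an illustrative variant of search that returns a label instead of a Boolean; the function name classify and the labels "word", "prefix", and "missing" are assumptions, and the node class and insert helper repeat the Python definitions used in this article.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Insert a non-empty word and return the (possibly new) subtree root."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def classify(root, word):
    """Return "word", "prefix", or "missing" for the given query."""
    node, index = root, 0
    while node is not None and index < len(word):
        ch = word[index]
        if ch < node.data:
            node = node.left
        elif ch > node.data:
            node = node.right
        elif index == len(word) - 1:
            # All characters matched; the flag decides between the two hits.
            return "word" if node.is_end_of_word else "prefix"
        else:
            node = node.middle
            index += 1
    return "missing"  # fell off the tree (or the query was empty)


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert(root, w)

for q in ["cat", "ca", "cow"]:
    print(q, classify(root, q))
```

A Boolean search, as in the implementations that follow, simply treats "prefix" and "missing" alike as not found.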
C++ Search
// Search for a word in the TST
bool search(TSTNode* root, const string& word) {
TSTNode* curr = root;
int index = 0;
while (curr != nullptr && index < (int)word.length()) {
char ch = word[index];
if (ch < curr->data) {
curr = curr->left; // go left for smaller characters
} else if (ch > curr->data) {
curr = curr->right; // go right for larger characters
} else {
// Characters match
if (index == (int)word.length() - 1) {
// Reached the last character
return curr->isEndOfWord;
}
curr = curr->middle; // move to next character
index++;
}
}
return false;
}
// Example usage (assuming words already inserted):
// search(root, "cat") -> true
// search(root, "ca") -> false (prefix exists but not a complete word)
// search(root, "cow") -> false (not found)
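The C++ version searches iteratively, comparing characters step by step from the root. On reaching the last character it returns the node's isEndOfWord flag.
C Search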
// Search for a word in the TST
bool search(TSTNode* root, const char* word) {
TSTNode* curr = root;
int index = 0;
int len = strlen(word);
while (curr != NULL && index < len) {
char ch = word[index];
if (ch < curr->data) {
curr = curr->left; // go left for smaller characters
} else if (ch > curr->data) {
curr = curr->right; // go right for larger characters
} else {
// Characters match
if (index == len - 1) {
// Reached the last character
return curr->isEndOfWord;
}
curr = curr->middle; // move to next character
index++;
}
}
return false;
}
The C version gets the string length with strlen once, before the loop; otherwise it matches the iterative C++ version.
Python Search
def search(node, word):
    curr = node
    index = 0
    while curr is not None and index < len(word):
        ch = word[index]
        if ch < curr.data:
            curr = curr.left    # go left for smaller characters
        elif ch > curr.data:
            curr = curr.right   # go right for larger characters
        else:
            # Characters match
            if index == len(word) - 1:
                # Reached the last character
                return curr.is_end_of_word
            curr = curr.middle  # move to next character
            index += 1
    return False
The Python version uses the same iterative loop.
Running a search test on the sample tree prints:
Search "cat": Found
Search "ca": Not Found (prefix exists, but not a complete word)
Search "cow": Not Found
Go Search
func search(root *TSTNode, word string) bool {
curr := root
index := 0
for curr != nil && index < len(word) {
ch := word[index]
if ch < curr.data {
curr = curr.left // go left for smaller characters
} else if ch > curr.data {
curr = curr.right // go right for larger characters
} else {
// Characters match
if index == len(word)-1 {
// Reached the last character
return curr.isEnd
}
curr = curr.middle // move to next character
index++
}
}
return false
}
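The Go version searches with an iterative loop, comparing characters step by step from the root. On reaching the last character it returns the isEnd flag, which tells whether the match is a complete word.
Prefix Search
Prefix search is one of the most important applications of a ternary search tree. Given a prefix string, it finds every complete word in the tree that starts with that prefix.
Prefix search takes two steps:
- Locate the prefix node: follow the search path to the node holding the last character of the prefix
- Collect the completions: if that node is marked as the end of a word, the prefix itself is a result; then traverse the whole subtree under its middle child (left, middle, and right links) and collect every complete word
C++ Prefix Search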
#include <iostream>
#include <string>
#include <vector>
using namespace std;
// Helper: collect all words from a subtree
void collect(TSTNode* node, string prefix, vector<string>& results) {
if (node == nullptr) return;
// Traverse left subtree first (smaller characters)
collect(node->left, prefix, results);
// If current node is end of word, add to results
if (node->isEndOfWord)
results.push_back(prefix + node->data);
// Traverse middle subtree (continuing the word)
collect(node->middle, prefix + node->data, results);
// Traverse right subtree (larger characters)
collect(node->right, prefix, results);
}
// Find the node corresponding to the last character of prefix
TSTNode* findPrefixNode(TSTNode* root, const string& prefix) {
TSTNode* curr = root;
int index = 0;
while (curr != nullptr && index < (int)prefix.length()) {
char ch = prefix[index];
if (ch < curr->data) {
curr = curr->left;
} else if (ch > curr->data) {
curr = curr->right;
} else {
if (index == (int)prefix.length() - 1)
return curr; // found the last char of prefix
curr = curr->middle;
index++;
}
}
return nullptr; // prefix not found
}
// Prefix search: find all words starting with given prefix
vector<string> prefixSearch(TSTNode* root, const string& prefix) {
vector<string> results;
// Step 1: find the node for the last character of prefix
TSTNode* prefixNode = findPrefixNode(root, prefix);
if (prefixNode == nullptr) return results;
// Step 2: if prefix itself is a word, add it
if (prefixNode->isEndOfWord)
results.push_back(prefix);
// Step 3: collect all words from the middle subtree
collect(prefixNode->middle, prefix, results);
return results;
}
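The C++ version splits the work into two helpers: findPrefixNode locates the node for the last character of the prefix, and collect gathers every complete word from that node's middle subtree. collect visits left, middle, and right in that order, so the results come out in lexicographic order.
C Prefix Search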
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
typedef struct TSTNode {
char data;
bool isEndOfWord;
struct TSTNode* left;
struct TSTNode* middle;
struct TSTNode* right;
} TSTNode;
// (createNode and insert functions as defined earlier)
// Helper: collect all words from a subtree into an array
void collect(TSTNode* node, char* prefix, char results[][256], int* count) {
if (node == NULL) return;
// Traverse left subtree first
collect(node->left, prefix, results, count);
// If current node is end of word, add to results
if (node->isEndOfWord) {
sprintf(results[*count], "%s%c", prefix, node->data);
(*count)++;
}
// Traverse middle subtree (continuing the word)
int len = strlen(prefix);
prefix[len] = node->data;
prefix[len + 1] = '\0';
collect(node->middle, prefix, results, count);
prefix[len] = '\0'; // backtrack
// Traverse right subtree
collect(node->right, prefix, results, count);
}
// Find the node corresponding to the last character of prefix
TSTNode* findPrefixNode(TSTNode* root, const char* prefix) {
TSTNode* curr = root;
int index = 0;
int len = strlen(prefix);
while (curr != NULL && index < len) {
char ch = prefix[index];
if (ch < curr->data) {
curr = curr->left;
} else if (ch > curr->data) {
curr = curr->right;
} else {
if (index == len - 1)
return curr;
curr = curr->middle;
index++;
}
}
return NULL;
}
// Prefix search: find all words starting with given prefix
int prefixSearch(TSTNode* root, const char* prefix, char results[][256]) {
int count = 0;
TSTNode* prefixNode = findPrefixNode(root, prefix);
if (prefixNode == NULL) return 0;
// If prefix itself is a word, add it
if (prefixNode->isEndOfWord) {
strcpy(results[count], prefix);
count++;
}
// Collect all words from the middle subtree
char buffer[256] = "";
strcpy(buffer, prefix);
collect(prefixNode->middle, buffer, results, &count);
return count;
}
The C version stores results in a fixed-size two-dimensional character array, so words are limited to 255 characters and the caller must supply enough rows; the code checks neither bound. collect builds the path by appending to the prefix buffer and backtracking after the middle subtree returns.
Python Prefix Search
def _collect(node, prefix, results):
    """Helper: collect all words from a subtree."""
    if node is None:
        return
    # Traverse left subtree first (smaller characters)
    _collect(node.left, prefix, results)
    # If current node is end of word, add to results
    if node.is_end_of_word:
        results.append(prefix + node.data)
    # Traverse middle subtree (continuing the word)
    _collect(node.middle, prefix + node.data, results)
    # Traverse right subtree (larger characters)
    _collect(node.right, prefix, results)

def _find_prefix_node(node, prefix):
    """Find the node corresponding to the last character of prefix."""
    curr = node
    index = 0
    while curr is not None and index < len(prefix):
        ch = prefix[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        else:
            if index == len(prefix) - 1:
                return curr
            curr = curr.middle
            index += 1
    return None

def prefix_search(node, prefix):
    """Prefix search: find all words starting with given prefix."""
    results = []
    prefix_node = _find_prefix_node(node, prefix)
    if prefix_node is None:
        return results
    # If prefix itself is a word, add it
    if prefix_node.is_end_of_word:
        results.append(prefix)
    # Collect all words from the middle subtree
    _collect(prefix_node.middle, prefix, results)
    return results
The Python version stores results in a list that grows as needed. Because prefix + node.data creates a new string for each call, no manual backtracking is required.
Running a prefix search test on the sample tree prints:
Prefix search "ca": car, care, cat
Prefix search "do": do, dog
Prefix search "b": bat
Prefix search "z": (no results)
Go Prefix Search
func collect(node *TSTNode, prefix string, results *[]string) {
if node == nil {
return
}
// Traverse left subtree first (smaller characters)
collect(node.left, prefix, results)
// If current node is end of word, add to results
if node.isEnd {
*results = append(*results, prefix+string(node.data))
}
// Traverse middle subtree (continuing the word)
collect(node.middle, prefix+string(node.data), results)
// Traverse right subtree (larger characters)
collect(node.right, prefix, results)
}
func findPrefixNode(root *TSTNode, prefix string) *TSTNode {
curr := root
index := 0
for curr != nil && index < len(prefix) {
ch := prefix[index]
if ch < curr.data {
curr = curr.left
} else if ch > curr.data {
curr = curr.right
} else {
if index == len(prefix)-1 {
return curr
}
curr = curr.middle
index++
}
}
return nil
}
func prefixSearch(root *TSTNode, prefix string) []string {
var results []string
prefixNode := findPrefixNode(root, prefix)
if prefixNode == nil {
return results
}
// If prefix itself is a word, add it
if prefixNode.isEnd {
results = append(results, prefix)
}
// Collect all words from the middle subtree
collect(prefixNode.middle, prefix, &results)
return results
}
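The Go version uses three functions: findPrefixNode locates the node for the last character of the prefix, collect gathers every complete word from that node's middle subtree, and prefixSearch combines the two. Note that a byte is converted to a string with string(node.data).
Traversing All Words
Traversal collects every word stored in the tree. For each node it first traverses the left subtree, then handles the node itself (recording a word if the node is marked, and descending the middle link with the node's character appended), and finally traverses the right subtree. It is essentially a variant of in-order traversal, so the words come out in lexicographic order.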
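The ordering property can be checked directly. The sketch below is a small experiment, not part of the original implementations: it inserts the six sample words in every possible order and compares the traversal result with Python's sorted(). The generator name iter_words and the helper all_words are assumptions for illustration.

```python
from itertools import permutations


class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Insert a non-empty word and return the (possibly new) subtree root."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def iter_words(node, buffer=""):
    """Yield stored words: left subtree, this node, middle, then right."""
    if node is None:
        return
    yield from iter_words(node.left, buffer)
    if node.is_end_of_word:
        yield buffer + node.data
    yield from iter_words(node.middle, buffer + node.data)
    yield from iter_words(node.right, buffer)


def all_words(words):
    """Build a tree from words (in the given order) and list its contents."""
    root = None
    for w in words:
        root = insert(root, w)
    return list(iter_words(root))


sample = ["cat", "car", "care", "dog", "do", "bat"]
# The tree shape changes with insertion order, the traversal result does not.
always_sorted = all(all_words(p) == sorted(sample) for p in permutations(sample))
print(all_words(sample), always_sorted)
```

A word is reported before its extensions ("car" before "care") because the node is checked before its middle subtree is visited.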
C++ Traversal
// Traverse the TST and collect all words
void traverse(TSTNode* node, string buffer, vector<string>& results) {
if (node == nullptr) return;
// Visit left subtree (smaller characters)
traverse(node->left, buffer, results);
// If current node marks end of word, save the word
if (node->isEndOfWord)
results.push_back(buffer + node->data);
// Visit middle subtree, appending current character
traverse(node->middle, buffer + node->data, results);
// Visit right subtree (larger characters)
traverse(node->right, buffer, results);
}
// Wrapper function to get all words sorted
vector<string> getAllWords(TSTNode* root) {
vector<string> results;
traverse(root, "", results);
return results;
}
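The C++ version traverses recursively; buffer accumulates the characters along the middle links. When a node is marked as the end of a word, buffer + node->data is added to the results.
C Traversal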
// Traverse the TST and collect all words
void traverse(TSTNode* node, char* buffer, int depth,
char results[][256], int* count) {
if (node == NULL) return;
// Visit left subtree (smaller characters)
traverse(node->left, buffer, depth, results, count);
// If current node marks end of word, save the word
if (node->isEndOfWord) {
buffer[depth] = node->data;
buffer[depth + 1] = '\0';
strcpy(results[*count], buffer);
(*count)++;
}
// Visit middle subtree, appending current character
buffer[depth] = node->data;
traverse(node->middle, buffer, depth + 1, results, count);
// Visit right subtree (larger characters)
traverse(node->right, buffer, depth, results, count);
}
The C version tracks the current depth with a depth parameter and modifies buffer in place during recursion. When saving a word, the '\0' terminator must be added by hand.
Python Traversal
def traverse(node, buffer="", results=None):
    """Traverse the TST and collect all words."""
    if results is None:
        results = []
    if node is None:
        return results
    # Visit left subtree (smaller characters)
    traverse(node.left, buffer, results)
    # If current node marks end of word, save the word
    if node.is_end_of_word:
        results.append(buffer + node.data)
    # Visit middle subtree, appending current character
    traverse(node.middle, buffer + node.data, results)
    # Visit right subtree (larger characters)
    traverse(node.right, buffer, results)
    return results
The Python version is the most compact: the default argument buffer="" accumulates characters, and the result list is passed down through the results parameter.
Running the traversal on the sample tree prints:
All words: bat, car, care, cat, do, dog
Go Traversal
func traverse(node *TSTNode, buffer string, results *[]string) {
if node == nil {
return
}
// Visit left subtree (smaller characters)
traverse(node.left, buffer, results)
// If current node marks end of word, save the word
if node.isEnd {
*results = append(*results, buffer+string(node.data))
}
// Visit middle subtree, appending current character
traverse(node.middle, buffer+string(node.data), results)
// Visit right subtree (larger characters)
traverse(node.right, buffer, results)
}
func getAllWords(root *TSTNode) []string {
var results []string
traverse(root, "", &results)
return results
}
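The Go version implements the in-order traversal recursively; buffer accumulates the characters along the middle links. When a node is marked as the end of a word, buffer + string(node.data) is appended to the result slice. getAllWords is a wrapper that returns all words in lexicographic (byte) order.
Complete Implementation
Complete implementations follow, covering insertion, search, prefix search, and traversal of all words.
C++ Complete Implementation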
#include <iostream>
#include <string>
#include <vector>
using namespace std;
struct TSTNode {
char data;
bool isEndOfWord;
TSTNode* left;
TSTNode* middle;
TSTNode* right;
TSTNode(char ch) : data(ch), isEndOfWord(false),
left(nullptr), middle(nullptr), right(nullptr) {}
};
// Insert a word into the TST
TSTNode* insert(TSTNode* node, const string& word, int index = 0) {
char ch = word[index];
if (node == nullptr)
node = new TSTNode(ch);
if (ch < node->data) {
node->left = insert(node->left, word, index);
} else if (ch > node->data) {
node->right = insert(node->right, word, index);
} else {
if (index < (int)word.length() - 1)
node->middle = insert(node->middle, word, index + 1);
else
node->isEndOfWord = true;
}
return node;
}
// Search for a word in the TST
bool search(TSTNode* root, const string& word) {
TSTNode* curr = root;
int index = 0;
while (curr != nullptr && index < (int)word.length()) {
char ch = word[index];
if (ch < curr->data) curr = curr->left;
else if (ch > curr->data) curr = curr->right;
else {
if (index == (int)word.length() - 1)
return curr->isEndOfWord;
curr = curr->middle;
index++;
}
}
return false;
}
// Collect all words from a subtree
void collect(TSTNode* node, const string& prefix, vector<string>& results) {
if (node == nullptr) return;
collect(node->left, prefix, results);
if (node->isEndOfWord)
results.push_back(prefix + node->data);
collect(node->middle, prefix + node->data, results);
collect(node->right, prefix, results);
}
// Prefix search: find all words starting with given prefix
vector<string> prefixSearch(TSTNode* root, const string& prefix) {
vector<string> results;
TSTNode* curr = root;
int index = 0;
// Navigate to the last character of the prefix
while (curr != nullptr && index < (int)prefix.length()) {
char ch = prefix[index];
if (ch < curr->data) curr = curr->left;
else if (ch > curr->data) curr = curr->right;
else {
if (index == (int)prefix.length() - 1) {
if (curr->isEndOfWord)
results.push_back(prefix);
collect(curr->middle, prefix, results);
return results;
}
curr = curr->middle;
index++;
}
}
return results;
}
// Traverse and collect all words
void traverse(TSTNode* node, const string& buffer, vector<string>& results) {
if (node == nullptr) return;
traverse(node->left, buffer, results);
if (node->isEndOfWord)
results.push_back(buffer + node->data);
traverse(node->middle, buffer + node->data, results);
traverse(node->right, buffer, results);
}
int main() {
TSTNode* root = nullptr;
// Insert words
string words[] = {"cat", "car", "care", "dog", "do", "bat"};
for (const string& w : words)
root = insert(root, w);
// Search test
cout << "=== Search ===" << endl;
string queries[] = {"cat", "ca", "cow", "dog", "do", "bat"};
for (const string& q : queries) {
cout << "search(\"" << q << "\"): "
<< (search(root, q) ? "Found" : "Not Found") << endl;
}
// Prefix search test
cout << "\n=== Prefix Search ===" << endl;
string prefixes[] = {"ca", "do", "b", "z"};
for (const string& p : prefixes) {
vector<string> matches = prefixSearch(root, p);
cout << "prefixSearch(\"" << p << "\"): ";
for (size_t i = 0; i < matches.size(); i++) {
if (i > 0) cout << ", ";
cout << matches[i];
}
cout << endl;
}
// Traverse all words
cout << "\n=== All Words ===" << endl;
vector<string> allWords;
traverse(root, "", allWords);
cout << "All words: ";
for (size_t i = 0; i < allWords.size(); i++) {
if (i > 0) cout << ", ";
cout << allWords[i];
}
cout << endl;
return 0;
}
C Complete Implementation
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
typedef struct TSTNode {
char data;
bool isEndOfWord;
struct TSTNode* left;
struct TSTNode* middle;
struct TSTNode* right;
} TSTNode;
// Create a new node
TSTNode* createNode(char ch) {
TSTNode* node = (TSTNode*)malloc(sizeof(TSTNode));
node->data = ch;
node->isEndOfWord = false;
node->left = NULL;
node->middle = NULL;
node->right = NULL;
return node;
}
// Insert a word into the TST
TSTNode* insert(TSTNode* node, const char* word, int index) {
char ch = word[index];
if (node == NULL)
node = createNode(ch);
if (ch < node->data) {
node->left = insert(node->left, word, index);
} else if (ch > node->data) {
node->right = insert(node->right, word, index);
} else {
if (index < (int)strlen(word) - 1)
node->middle = insert(node->middle, word, index + 1);
else
node->isEndOfWord = true;
}
return node;
}
// Search for a word in the TST
bool search(TSTNode* root, const char* word) {
TSTNode* curr = root;
int index = 0;
int len = strlen(word);
while (curr != NULL && index < len) {
char ch = word[index];
if (ch < curr->data) curr = curr->left;
else if (ch > curr->data) curr = curr->right;
else {
if (index == len - 1)
return curr->isEndOfWord;
curr = curr->middle;
index++;
}
}
return false;
}
// Collect all words from a subtree
void collect(TSTNode* node, char* prefix, char results[][256], int* count) {
if (node == NULL) return;
collect(node->left, prefix, results, count);
if (node->isEndOfWord) {
sprintf(results[*count], "%s%c", prefix, node->data);
(*count)++;
}
int len = strlen(prefix);
prefix[len] = node->data;
prefix[len + 1] = '\0';
collect(node->middle, prefix, results, count);
prefix[len] = '\0';
collect(node->right, prefix, results, count);
}
// Prefix search
int prefixSearch(TSTNode* root, const char* prefix, char results[][256]) {
int count = 0;
TSTNode* curr = root;
int index = 0;
int len = strlen(prefix);
while (curr != NULL && index < len) {
char ch = prefix[index];
if (ch < curr->data) curr = curr->left;
else if (ch > curr->data) curr = curr->right;
else {
if (index == len - 1) {
if (curr->isEndOfWord) {
strcpy(results[count], prefix);
count++;
}
char buffer[256] = "";
strcpy(buffer, prefix);
collect(curr->middle, buffer, results, &count);
return count;
}
curr = curr->middle;
index++;
}
}
return count;
}
// Traverse and collect all words
void traverse(TSTNode* node, char* buffer, int depth,
char results[][256], int* count) {
if (node == NULL) return;
traverse(node->left, buffer, depth, results, count);
if (node->isEndOfWord) {
buffer[depth] = node->data;
buffer[depth + 1] = '\0';
strcpy(results[*count], buffer);
(*count)++;
}
buffer[depth] = node->data;
traverse(node->middle, buffer, depth + 1, results, count);
traverse(node->right, buffer, depth, results, count);
}
int main() {
TSTNode* root = NULL;
// Insert words
const char* words[] = {"cat", "car", "care", "dog", "do", "bat"};
int n = sizeof(words) / sizeof(words[0]);
for (int i = 0; i < n; i++)
root = insert(root, words[i], 0);
// Search test
printf("=== Search ===\n");
const char* queries[] = {"cat", "ca", "cow", "dog", "do", "bat"};
int qn = sizeof(queries) / sizeof(queries[0]);
for (int i = 0; i < qn; i++)
printf("search(\"%s\"): %s\n", queries[i],
search(root, queries[i]) ? "Found" : "Not Found");
// Prefix search test
printf("\n=== Prefix Search ===\n");
const char* prefixes[] = {"ca", "do", "b", "z"};
int pn = sizeof(prefixes) / sizeof(prefixes[0]);
for (int i = 0; i < pn; i++) {
char results[100][256];
int cnt = prefixSearch(root, prefixes[i], results);
printf("prefixSearch(\"%s\"): ", prefixes[i]);
for (int j = 0; j < cnt; j++) {
if (j > 0) printf(", ");
printf("%s", results[j]);
}
printf("\n");
}
// Traverse all words
printf("\n=== All Words ===\n");
char allResults[100][256];
int totalCount = 0;
char buffer[256] = "";
traverse(root, buffer, 0, allResults, &totalCount);
printf("All words: ");
for (int i = 0; i < totalCount; i++) {
if (i > 0) printf(", ");
printf("%s", allResults[i]);
}
printf("\n");
return 0;
}
Python Complete Implementation
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None

def insert(node, word, index=0):
    """Insert a word into the TST."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    else:
        if index < len(word) - 1:
            node.middle = insert(node.middle, word, index + 1)
        else:
            node.is_end_of_word = True
    return node

def search(node, word):
    """Search for a word in the TST."""
    curr = node
    index = 0
    while curr is not None and index < len(word):
        ch = word[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        else:
            if index == len(word) - 1:
                return curr.is_end_of_word
            curr = curr.middle
            index += 1
    return False

def _collect(node, prefix, results):
    """Collect all words from a subtree."""
    if node is None:
        return
    _collect(node.left, prefix, results)
    if node.is_end_of_word:
        results.append(prefix + node.data)
    _collect(node.middle, prefix + node.data, results)
    _collect(node.right, prefix, results)

def prefix_search(node, prefix):
    """Prefix search: find all words starting with given prefix."""
    results = []
    curr = node
    index = 0
    # Navigate to the last character of the prefix
    while curr is not None and index < len(prefix):
        ch = prefix[index]
        if ch < curr.data:
            curr = curr.left
        elif ch > curr.data:
            curr = curr.right
        else:
            if index == len(prefix) - 1:
                if curr.is_end_of_word:
                    results.append(prefix)
                _collect(curr.middle, prefix, results)
                return results
            curr = curr.middle
            index += 1
    return results

def traverse(node, buffer="", results=None):
    """Traverse and collect all words."""
    if results is None:
        results = []
    if node is None:
        return results
    traverse(node.left, buffer, results)
    if node.is_end_of_word:
        results.append(buffer + node.data)
    traverse(node.middle, buffer + node.data, results)
    traverse(node.right, buffer, results)
    return results

if __name__ == "__main__":
    root = None
    # Insert words
    words = ["cat", "car", "care", "dog", "do", "bat"]
    for w in words:
        root = insert(root, w)
    # Search test
    print("=== Search ===")
    for q in ["cat", "ca", "cow", "dog", "do", "bat"]:
        status = "Found" if search(root, q) else "Not Found"
        print(f'search("{q}"): {status}')
    # Prefix search test
    print("\n=== Prefix Search ===")
    for p in ["ca", "do", "b", "z"]:
        matches = prefix_search(root, p)
        print(f'prefixSearch("{p}"): {", ".join(matches) if matches else "(none)"}')
    # Traverse all words
    print("\n=== All Words ===")
    all_words = traverse(root)
    print(f"All words: {', '.join(all_words)}")
Running the program prints:
=== Search ===
search("cat"): Found
search("ca"): Not Found
search("cow"): Not Found
search("dog"): Found
search("do"): Found
search("bat"): Found
=== Prefix Search ===
prefixSearch("ca"): car, care, cat
prefixSearch("do"): do, dog
prefixSearch("b"): bat
prefixSearch("z"): (none)
=== All Words ===
All words: bat, car, care, cat, do, dog
Go Complete Implementation
package main
import "fmt"
type TSTNode struct {
data byte
isEnd bool
left *TSTNode
middle *TSTNode
right *TSTNode
}
// Insert a word into the TST
func insert(node *TSTNode, word string, index int) *TSTNode {
ch := word[index]
if node == nil {
node = &TSTNode{data: ch}
}
if ch < node.data {
node.left = insert(node.left, word, index)
} else if ch > node.data {
node.right = insert(node.right, word, index)
} else {
if index < len(word)-1 {
node.middle = insert(node.middle, word, index+1)
} else {
node.isEnd = true
}
}
return node
}
// Search for a word in the TST
func search(root *TSTNode, word string) bool {
curr := root
index := 0
for curr != nil && index < len(word) {
ch := word[index]
if ch < curr.data {
curr = curr.left
} else if ch > curr.data {
curr = curr.right
} else {
if index == len(word)-1 {
return curr.isEnd
}
curr = curr.middle
index++
}
}
return false
}
// Collect all words from a subtree
func collect(node *TSTNode, prefix string, results *[]string) {
if node == nil {
return
}
collect(node.left, prefix, results)
if node.isEnd {
*results = append(*results, prefix+string(node.data))
}
collect(node.middle, prefix+string(node.data), results)
collect(node.right, prefix, results)
}
// Prefix search: find all words starting with given prefix
func prefixSearch(root *TSTNode, prefix string) []string {
var results []string
curr := root
index := 0
// Navigate to the last character of the prefix
for curr != nil && index < len(prefix) {
ch := prefix[index]
if ch < curr.data {
curr = curr.left
} else if ch > curr.data {
curr = curr.right
} else {
if index == len(prefix)-1 {
if curr.isEnd {
results = append(results, prefix)
}
collect(curr.middle, prefix, &results)
return results
}
curr = curr.middle
index++
}
}
return results
}
// Traverse and collect all words
func traverse(node *TSTNode, buffer string, results *[]string) {
if node == nil {
return
}
traverse(node.left, buffer, results)
if node.isEnd {
*results = append(*results, buffer+string(node.data))
}
traverse(node.middle, buffer+string(node.data), results)
traverse(node.right, buffer, results)
}
func main() {
var root *TSTNode
// Insert words
words := []string{"cat", "car", "care", "dog", "do", "bat"}
for _, w := range words {
root = insert(root, w, 0)
}
// Search test
fmt.Println("=== Search ===")
queries := []string{"cat", "ca", "cow", "dog", "do", "bat"}
for _, q := range queries {
status := "Not Found"
if search(root, q) {
status = "Found"
}
fmt.Printf("search(\"%s\"): %s\n", q, status)
}
// Prefix search test
fmt.Println("\n=== Prefix Search ===")
prefixes := []string{"ca", "do", "b", "z"}
for _, p := range prefixes {
matches := prefixSearch(root, p)
if len(matches) == 0 {
fmt.Printf("prefixSearch(\"%s\"): (none)\n", p)
} else {
fmt.Printf("prefixSearch(\"%s\"): %s\n", p, joinStrings(matches))
}
}
// Traverse all words
fmt.Println("\n=== All Words ===")
var allWords []string
traverse(root, "", &allWords)
fmt.Printf("All words: %s\n", joinStrings(allWords))
}
func joinStrings(ss []string) string {
result := ""
for i, s := range ss {
if i > 0 {
result += ", "
}
result += s
}
return result
}
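The Go implementation covers insertion, search, prefix search, and traversal of all words. The joinStrings helper joins a slice of strings with commas; the standard library's strings.Join does the same job. The core logic matches the C++ and Python versions and manipulates the tree through pointers.
Running the program prints: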
=== Search ===
search("cat"): Found
search("ca"): Not Found
search("cow"): Not Found
search("dog"): Found
search("do"): Found
search("bat"): Found
=== Prefix Search ===
prefixSearch("ca"): car, care, cat
prefixSearch("do"): do, dog
prefixSearch("b"): bat
prefixSearch("z"): (none)
=== All Words ===
All words: bat, car, care, cat, do, dog
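Properties of Ternary Search Trees
Space Efficiency
The most notable advantage of a ternary search tree is space efficiency. Every node of a standard trie keeps 26 child pointers (for lowercase letters) and allocates them even when most are null. A TST node has only 3 pointers (left, middle, right), so the per-node overhead is far lower:
| Data structure | Pointers per node | Unused pointers |
|---|---|---|
| Standard trie | 26 | Many (most slots are null) |
| Ternary search tree | 3 | Few (at most 3 per node) |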
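The difference can be made concrete by counting pointer slots for the six sample words. The sketch below assumes an array-style trie with 26 slots per node (root included) and models it with nested dicts only to count nodes; the helper names count_tst_nodes and count_trie_nodes are illustrative.

```python
class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Insert a non-empty word and return the (possibly new) subtree root."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def count_tst_nodes(node):
    """Number of nodes in the TST rooted at node."""
    if node is None:
        return 0
    return (1 + count_tst_nodes(node.left)
            + count_tst_nodes(node.middle)
            + count_tst_nodes(node.right))


def count_trie_nodes(words):
    """Number of nodes (root included) in a trie holding the given words."""
    root, count = {}, 1
    for word in words:
        node = root
        for ch in word:
            if ch not in node:
                node[ch] = {}
                count += 1
            node = node[ch]
    return count


words = ["cat", "car", "care", "dog", "do", "bat"]
root = None
for w in words:
    root = insert(root, w)

tst_nodes = count_tst_nodes(root)
trie_nodes = count_trie_nodes(words)
tst_slots = tst_nodes * 3      # three pointers per TST node
trie_slots = trie_nodes * 26   # assumed 26-slot child array per trie node
print(tst_nodes, tst_slots, trie_nodes, trie_slots)
```

For these words the count is 11 TST nodes (33 pointer slots) against 12 trie nodes (312 slots). A TST node also stores its character explicitly, which an array-indexed trie node does not need, so the real saving is somewhat smaller than the raw slot count suggests.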
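Time Complexity
| Operation | Balanced tree | Worst case (degenerate tree) |
|---|---|---|
| Insert | O(L + log N) | O(L + N) |
| Search | O(L + log N) | O(L + N) |
| Prefix Search | O(L + log N + M) | O(L + N + M) |
Here L is the length of the string, N is the number of strings stored in the tree, and M is the total number of characters in the matching results.
A lookup follows at most L middle links plus some number of left and right links. When the sibling binary search trees are balanced, the left and right steps total about log N at most, so lookups behave like O(L) for long strings. The shape depends on insertion order: if keys arrive in sorted order, the sibling trees degenerate into linked lists, and the left and right steps can grow to as many as N in total (and to at most the alphabet size at any one character position).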
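The effect of insertion order is easy to observe on single-character keys. The sketch below is an illustrative experiment: it inserts the 26 lowercase letters once in sorted order and once in a median-first order, then counts how many nodes a lookup touches. The helpers nodes_visited and balanced_order are assumptions, not part of the original implementations.

```python
import string


class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Insert a non-empty word and return the (possibly new) subtree root."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def nodes_visited(root, word):
    """Count the nodes examined while looking up word."""
    node, index, visited = root, 0, 0
    while node is not None and index < len(word):
        visited += 1
        ch = word[index]
        if ch < node.data:
            node = node.left
        elif ch > node.data:
            node = node.right
        else:
            node = node.middle
            index += 1
    return visited


def balanced_order(keys):
    """Reorder sorted keys so the median is inserted first, recursively."""
    if not keys:
        return []
    mid = len(keys) // 2
    return [keys[mid]] + balanced_order(keys[:mid]) + balanced_order(keys[mid + 1:])


def build(keys):
    root = None
    for k in keys:
        root = insert(root, k)
    return root


letters = list(string.ascii_lowercase)
chain = build(letters)                     # sorted input: a right-leaning list
balanced = build(balanced_order(letters))  # median-first input: balanced tree

worst_chain = max(nodes_visited(chain, c) for c in letters)
worst_balanced = max(nodes_visited(balanced, c) for c in letters)
print(worst_chain, worst_balanced)
```

With sorted input the lookup of "z" walks through all 26 nodes, while the median-first order keeps every lookup within 5 nodes.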
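Trie vs. TST vs. BST
| Property | Trie | Ternary search tree (TST) | Binary search tree (BST) |
|---|---|---|---|
| Pointers per node | 26 (fixed) | 3 | 2 (or 3 with a parent pointer) |
| Lookup complexity | O(L) | O(L + log N) balanced, O(L + N) worst case | O(L log N) |
| Prefix search | Built in | Built in | Not directly supported |
| Space overhead | High | Medium | Low |
| Shared common prefixes | Fully shared | Fully shared | Not shared |
| Suitable alphabets | Small, fixed alphabets | Any alphabet | Any data type |
A ternary search tree strikes a balance between a trie and a BST: it keeps the trie's prefix sharing and prefix search while coming close to a BST in space efficiency. The advantage grows when the alphabet is large (such as Unicode) or the strings are sparse.
Applications
Ternary search trees are used in a range of practical settings:
- Spell checking: quickly test whether a word is in the dictionary; also usable for fuzzy matching
- Autocomplete: after the user types a prefix, a prefix search returns the candidate list in real time
- IP routing tables: longest-prefix matching built on prefix search
- Genomic sequence analysis: searching for pattern strings in DNA sequences
- Full-text search: an auxiliary data structure alongside an inverted index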
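As an illustration of the autocomplete use case, the sketch below returns only the first few completions of a prefix and stops traversing once the limit is reached. It is a minimal sketch under stated assumptions: the function autocomplete, its limit parameter, and the generator iter_words are hypothetical names, and treating an empty prefix as "all words" is a design choice made here.

```python
from itertools import chain, islice


class TSTNode:
    def __init__(self, ch):
        self.data = ch
        self.is_end_of_word = False
        self.left = None
        self.middle = None
        self.right = None


def insert(node, word, index=0):
    """Insert a non-empty word and return the (possibly new) subtree root."""
    ch = word[index]
    if node is None:
        node = TSTNode(ch)
    if ch < node.data:
        node.left = insert(node.left, word, index)
    elif ch > node.data:
        node.right = insert(node.right, word, index)
    elif index < len(word) - 1:
        node.middle = insert(node.middle, word, index + 1)
    else:
        node.is_end_of_word = True
    return node


def iter_words(node, prefix):
    """Lazily yield the words of a subtree in lexicographic order."""
    if node is None:
        return
    yield from iter_words(node.left, prefix)
    if node.is_end_of_word:
        yield prefix + node.data
    yield from iter_words(node.middle, prefix + node.data)
    yield from iter_words(node.right, prefix)


def autocomplete(root, prefix, limit=5):
    """Return at most limit stored words that start with prefix."""
    if not prefix:  # assumption: an empty prefix matches every word
        return list(islice(iter_words(root, ""), limit))
    node, index = root, 0
    while node is not None:
        ch = prefix[index]
        if ch < node.data:
            node = node.left
        elif ch > node.data:
            node = node.right
        elif index == len(prefix) - 1:
            break                       # node holds the last prefix character
        else:
            node = node.middle
            index += 1
    if node is None:
        return []
    own = [prefix] if node.is_end_of_word else []
    # islice stops pulling from the generator once limit words are produced.
    return list(islice(chain(own, iter_words(node.middle, prefix)), limit))


root = None
for w in ["cat", "car", "care", "dog", "do", "bat"]:
    root = insert(root, w)

print(autocomplete(root, "ca", limit=2))
```

Because the traversal is a generator, a small limit avoids visiting the rest of a large subtree, which matters when a short prefix matches many words.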
