贴合《算法竞赛入门经典训练指南》AC 自动机完整代码
一、程序
包含字典树构建、fail 指针 BFS、模式匹配 Find、递归打印匹配结果 Print全流程。
#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;
const int MAXN = 10000 + 10; // 字典树最大节点数,按需调整
const int MAXC = 26; // 字符集:小写字母a-z
// AC自动机节点结构
struct AhoCorasick {
int ch[MAXN][MAXC]; // 字典树:ch[u][c]表示u节点走c字符到的节点
int fail[MAXN]; // fail指针
vector<int> val[MAXN]; // val[u]:存储以u节点结尾的模式串编号(多模式匹配)
int sz; // 节点总数,初始为0
// 初始化AC自动机
void init() {
sz = 0;
memset(ch[0], 0, sizeof(ch[0]));
val[0].clear();
}
// 插入模式串s,id为模式串编号(用于区分匹配到的是哪个模式串)
void insert(const char *s, int id) {
int u = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (!ch[u][c]) {
sz++;
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz].clear();
ch[u][c] = sz;
}
u = ch[u][c];
}
val[u].push_back(id); // 该节点记录模式串id
}
// BFS构建fail指针
void build() {
queue<int> q;
fail[0] = 0;
for (int c = 0; c < MAXC; c++) {
int v = ch[0][c];
if (v) {
fail[v] = 0;
q.push(v);
}
}
while (!q.empty()) {
int u = q.front();
q.pop();
for (int c = 0; c < MAXC; c++) {
int v = ch[u][c];
if (v) {
// 计算fail[v]:回溯u的fail指针,找到能走c的节点
int f = fail[u];
while (f && !ch[f][c]) f = fail[f];
fail[v] = ch[f][c];
q.push(v);
} else {
// 路径压缩:直接指向fail后的节点,优化查询速度
ch[u][c] = ch[fail[u]][c];
}
}
}
}
// 递归打印:以j节点结尾的所有匹配模式串(配合Find函数,i是文本串当前结束位置)
// 书中Print(i,j)是解释性写法,实际仅需j,i在调用时传入计算位置
void Print(int i, int j, const vector<char*>& patterns) {
if (j == 0) return; // 根节点,无匹配
// 打印当前节点的所有匹配结果:i是文本串结束位置,patterns[id]是模式串内容
for (int id : val[j]) {
int len = strlen(patterns[id]);
printf("匹配到模式串:%s,结束位置:%d,起始位置:%d\n",
patterns[id], i, i - len + 1);
}
// 递归回溯fail指针,打印所有间接匹配(AC自动机核心:fail链匹配)
Print(i, fail[j], patterns);
}
// 核心匹配函数Find:在文本串T中找所有模式串,patterns是模式串数组
void Find(const char *T, const vector<char*>& patterns) {
int u = 0;
for (int i = 0; T[i]; i++) { // i是文本串当前位置(从0开始)
int c = T[i] - 'a';
u = ch[u][c]; // 按字典树+路径压缩走
if (val[u].size() || fail[u]) { // 有匹配结果,调用Print
Print(i + 1, u, patterns); // 文本串位置转成从1开始,更符合阅读习惯
}
}
}
} ac;
// 测试主函数
int main() {
// 步骤1:初始化AC自动机
ac.init();
// 步骤2:插入模式串(可自定义,示例:3个模式串)
vector<char*> patterns; // 存储模式串,与id一一对应
char s1[] = "she";
char s2[] = "he";
char s3[] = "his";
patterns.push_back(s1);
patterns.push_back(s2);
patterns.push_back(s3);
ac.insert(s1, 0);
ac.insert(s2, 1);
ac.insert(s3, 2);
// 步骤3:构建fail指针
ac.build();
// 步骤4:文本串匹配
char T[] = "hishers"; // 测试文本串
printf("文本串:%s\n", T);
printf("匹配结果:\n");
ac.Find(T, patterns);
return 0;
}
二、细节说明
1、Print 函数参数:书中Print(i,j)是读者理解用的简化写法,实际代码中Print需要i(文本串结束位置)+j(AC 节点)+patterns(模式串数组),其中i仅用于计算并打印匹配的起始 / 结束位置,j是核心的 AC 节点,用于递归遍历fail链找所有匹配。
2、与书中的对应关系:
- 书中ch数组 → 代码中ch[MAXN][MAXC](字典树核心);
- 书中fail数组 → 代码中fail[MAXN](BFS 构建);
- 书中val数组 → 代码中val[MAXN](存储模式串编号,支持多模式匹配);
- 书中Find函数逻辑完全一致 → 文本串遍历 + 节点跳转 +Print调用;
- 书中Print递归逻辑 → 代码中严格实现节点 j 的直接匹配 + fail 链的间接匹配。
3、路径压缩优化:代码中ch[u][c] = ch[fail[u]][c]是书中的优化点,避免每次匹配都回溯fail指针,提升运行效率。
三、编译运行结果(示例)
上述测试代码的运行输出(文本串hishers,模式串she/he/his):
文本串:hishers
匹配结果:
匹配到模式串:his,结束位置:3,起始位置:1
匹配到模式串:he,结束位置:5,起始位置:4
匹配到模式串:she,结束位置:6,起始位置:4
匹配到模式串:he,结束位置:6,起始位置:5
四、自定义修改说明
1、字符集调整:若需支持大写字母 / 数字,修改MAXC和c = s[i] - 'a'的偏移量即可;
2、模式串数量:直接在main函数中添加insert调用,patterns数组同步添加即可;
3、文本串输入:可将固定文本串改为scanf/gets读取外部输入,适配算法竞赛场景。

浙公网安备 33010602011771号