贴合《算法竞赛入门经典训练指南》AC 自动机完整代码及参考
一、程序
包含字典树构建、fail 指针 BFS、模式匹配 Find、递归打印匹配结果 Print全流程。
#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;
const int MAXN = 10000 + 10; // 字典树最大节点数,按需调整
const int MAXC = 26; // 字符集:小写字母a-z
// AC自动机节点结构
struct AhoCorasick {
int ch[MAXN][MAXC]; // 字典树:ch[u][c]表示u节点走c字符到的节点
int fail[MAXN]; // fail指针
vector<int> val[MAXN]; // val[u]:存储以u节点结尾的模式串编号(多模式匹配)
int sz; // 节点总数,初始为0
// 初始化AC自动机
void init() {
sz = 0;
memset(ch[0], 0, sizeof(ch[0]));
val[0].clear();
}
// 插入模式串s,id为模式串编号(用于区分匹配到的是哪个模式串)
void insert(const char *s, int id) {
int u = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (!ch[u][c]) {
sz++;
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz].clear();
ch[u][c] = sz;
}
u = ch[u][c];
}
val[u].push_back(id); // 该节点记录模式串id
}
// BFS构建fail指针
void build() {
queue<int> q;
fail[0] = 0;
for (int c = 0; c < MAXC; c++) {
int v = ch[0][c];
if (v) {
fail[v] = 0;
q.push(v);
}
}
while (!q.empty()) {
int u = q.front();
q.pop();
for (int c = 0; c < MAXC; c++) {
int v = ch[u][c];
if (v) {
// 计算fail[v]:回溯u的fail指针,找到能走c的节点
int f = fail[u];
while (f && !ch[f][c]) f = fail[f];
fail[v] = ch[f][c];
q.push(v);
} else {
// 路径压缩:直接指向fail后的节点,优化查询速度
ch[u][c] = ch[fail[u]][c];
}
}
}
}
// 递归打印:以j节点结尾的所有匹配模式串(配合Find函数,i是文本串当前结束位置)
// 书中Print(i,j)是解释性写法,实际仅需j,i在调用时传入计算位置
void Print(int i, int j, const vector<char*>& patterns) {
if (j == 0) return; // 根节点,无匹配
// 打印当前节点的所有匹配结果:i是文本串结束位置,patterns[id]是模式串内容
for (int id : val[j]) {
int len = strlen(patterns[id]);
printf("匹配到模式串:%s,结束位置:%d,起始位置:%d\n",
patterns[id], i, i - len + 1);
}
// 递归回溯fail指针,打印所有间接匹配(AC自动机核心:fail链匹配)
Print(i, fail[j], patterns);
}
// 核心匹配函数Find:在文本串T中找所有模式串,patterns是模式串数组
void Find(const char *T, const vector<char*>& patterns) {
int u = 0;
for (int i = 0; T[i]; i++) { // i是文本串当前位置(从0开始)
int c = T[i] - 'a';
u = ch[u][c]; // 按字典树+路径压缩走
if (val[u].size() || fail[u]) { // 有匹配结果,调用Print
Print(i + 1, u, patterns); // 文本串位置转成从1开始,更符合阅读习惯
}
}
}
} ac;
// 测试主函数
int main() {
// 步骤1:初始化AC自动机
ac.init();
// 步骤2:插入模式串(可自定义,示例:3个模式串)
vector<char*> patterns; // 存储模式串,与id一一对应
char s1[] = "she";
char s2[] = "he";
char s3[] = "his";
patterns.push_back(s1);
patterns.push_back(s2);
patterns.push_back(s3);
ac.insert(s1, 0);
ac.insert(s2, 1);
ac.insert(s3, 2);
// 步骤3:构建fail指针
ac.build();
// 步骤4:文本串匹配
char T[] = "hishers"; // 测试文本串
printf("文本串:%s\n", T);
printf("匹配结果:\n");
ac.Find(T, patterns);
return 0;
}
二、细节说明
1、Print 函数参数:书中Print(i,j)是读者理解用的简化写法,实际代码中Print需要i(文本串结束位置)+j(AC 节点)+patterns(模式串数组),其中i仅用于计算并打印匹配的起始 / 结束位置,j是核心的 AC 节点,用于递归遍历fail链找所有匹配。
2、与书中的对应关系:
- 书中ch数组 → 代码中ch[MAXN][MAXC](字典树核心);
- 书中fail数组 → 代码中fail[MAXN](BFS 构建);
- 书中val数组 → 代码中val[MAXN](存储模式串编号,支持多模式匹配);
- 书中Find函数逻辑完全一致 → 文本串遍历 + 节点跳转 +Print调用;
- 书中Print递归逻辑 → 代码中严格实现节点 j 的直接匹配 + fail 链的间接匹配。
3、路径压缩优化:代码中ch[u][c] = ch[fail[u]][c]是书中的优化点,避免每次匹配都回溯fail指针,提升运行效率。
三、编译运行结果(示例)
上述测试代码的运行输出(文本串hishers,模式串she/he/his):
文本串:hishers
匹配结果:
匹配到模式串:his,结束位置:3,起始位置:1
匹配到模式串:he,结束位置:5,起始位置:4
匹配到模式串:she,结束位置:6,起始位置:4
匹配到模式串:he,结束位置:6,起始位置:5
四、自定义修改说明
1、字符集调整:若需支持大写字母 / 数字,修改MAXC和c = s[i] - 'a'的偏移量即可;
2、模式串数量:直接在main函数中添加insert调用,patterns数组同步添加即可;
3、文本串输入:可将固定文本串改为scanf/gets读取外部输入,适配算法竞赛场景。
五、扩展
核心步骤 1:AC 自动机 fail 指针的 BFS 构建(逐行注释 + 解析)
先给出完整的带逐行注释的代码,再补充关键逻辑的设计思路,代码基于 AC 自动机的经典数组实现(适合新手理解,避免指针的复杂度)。
前置准备:AC 自动机的节点结构体(先定义,否则 BFS 代码无意义)
#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;
const int MAXN = 10000 + 10; // 字典树最大节点数,按需调整
const int MAXC = 26; // 字符集:小写字母a-z
// AC自动机节点结构
struct AhoCorasick {
int ch[MAXN][MAXC]; // 字典树:ch[u][c]表示u节点走c字符到的节点
int fail[MAXN]; // fail指针
vector<int> val[MAXN]; // val[u]:存储以u节点结尾的模式串编号(多模式匹配)
int sz; // 节点总数,初始为0
// 初始化AC自动机
void init() {
sz = 0;
memset(ch[0], 0, sizeof(ch[0]));
val[0].clear();
}
// 插入模式串s,id为模式串编号(用于区分匹配到的是哪个模式串)
void insert(const char *s, int id) {
int u = 0;
for (int i = 0; s[i]; i++) {
int c = s[i] - 'a';
if (!ch[u][c]) {
sz++;
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz].clear();
ch[u][c] = sz;
}
u = ch[u][c];
}
val[u].push_back(id); // 该节点记录模式串id
}
// BFS构建fail指针
void build() {
queue<int> q;
fail[0] = 0;
for (int c = 0; c < MAXC; c++) {
int v = ch[0][c];
if (v) {
fail[v] = 0;
q.push(v);
}
}
while (!q.empty()) {
int u = q.front();
q.pop();
for (int c = 0; c < MAXC; c++) {
int v = ch[u][c];
if (v) {
// 计算fail[v]:回溯u的fail指针,找到能走c的节点
int f = fail[u];
while (f && !ch[f][c]) f = fail[f];
fail[v] = ch[f][c];
q.push(v);
} else {
// 路径压缩:直接指向fail后的节点,优化查询速度
ch[u][c] = ch[fail[u]][c];
}
}
}
}
// 递归打印:以j节点结尾的所有匹配模式串(配合Find函数,i是文本串当前结束位置)
// 书中Print(i,j)是解释性写法,实际仅需j,i在调用时传入计算位置
void Print(int i, int j, const vector<char*>& patterns) {
if (j == 0) return; // 根节点,无匹配
// 打印当前节点的所有匹配结果:i是文本串结束位置,patterns[id]是模式串内容
for (int id : val[j]) {
int len = strlen(patterns[id]);
printf("匹配到模式串:%s,结束位置:%d,起始位置:%d\n",
patterns[id], i, i - len + 1);
}
// 递归回溯fail指针,打印所有间接匹配(AC自动机核心:fail链匹配)
Print(i, fail[j], patterns);
}
// 核心匹配函数Find:在文本串T中找所有模式串,patterns是模式串数组
void Find(const char *T, const vector<char*>& patterns) {
int u = 0;
for (int i = 0; T[i]; i++) { // i是文本串当前位置(从0开始)
int c = T[i] - 'a';
u = ch[u][c]; // 按字典树+路径压缩走
if (val[u].size() || fail[u]) { // 有匹配结果,调用Print
Print(i + 1, u, patterns); // 文本串位置转成从1开始,更符合阅读习惯
}
}
}
} ac;
// 测试主函数
int main() {
// 步骤1:初始化AC自动机
ac.init();
// 步骤2:插入模式串(可自定义,示例:3个模式串)
vector<char*> patterns; // 存储模式串,与id一一对应
char s1[] = "she";
char s2[] = "he";
char s3[] = "his";
patterns.push_back(s1);
patterns.push_back(s2);
patterns.push_back(s3);
ac.insert(s1, 0);
ac.insert(s2, 1);
ac.insert(s3, 2);
// 步骤3:构建fail指针
ac.build();
// 步骤4:文本串匹配
char T[] = "hishers"; // 测试文本串
printf("文本串:%s\n", T);
printf("匹配结果:\n");
ac.Find(T, patterns);
return 0;
}
编译运行结果(示例)
上述测试代码的运行输出(文本串hishers,模式串she/he/his):
plaintext
文本串:hishers
匹配结果:
匹配到模式串:his,结束位置:3,起始位置:1
匹配到模式串:he,结束位置:5,起始位置:4
匹配到模式串:she,结束位置:6,起始位置:4
匹配到模式串:he,结束位置:6,起始位置:5
前置准备:AC 自动机的节点结构体(先定义,否则 BFS 代码无意义)
// AC自动机节点结构体,数组实现(size根据模式串数量/长度定,一般开1e5+10足够)
const int MAXN = 100010;
struct Node {
int next[26]; // 每个节点的26个小写字母子节点,存子节点的下标
int fail; // 该节点的失败指针,存指向的节点下标
int end; // end>0表示该节点是某个模式串的结尾,end的值可记录模式串编号/长度
Node() { // 构造函数,初始化节点
memset(next, -1, sizeof(next)); // 子节点初始化为-1(表示无此子节点)
fail = 0; // 失败指针默认指向根节点(根节点下标为0)
end = 0; // 初始不是任何模式串的结尾
}
} tree[MAXN];
int cnt = 0; // 节点计数器,根节点为0,新节点从1开始创建
queue<int> q; // BFS队列,存节点下标,用于构建fail指针
核心代码:fail 指针的 BFS 构建(逐行注释 + 解析)
// 构建AC自动机的fail指针,根节点下标为0
void build_fail() {
// 第一步:初始化根节点的直接子节点(第一层节点)的fail指针,并加入队列
// 根节点的next[0~25]遍历,i对应a~z的偏移(0=a,1=b...25=z)
for (int i = 0; i < 26; i++) {
// 如果根节点有i这个子节点(下标不为-1)
if (tree[0].next[i] != -1) {
tree[tree[0].next[i]].fail = 0; // 第一层节点的fail指针直接指向根节点
q.push(tree[0].next[i]); // 该节点入队,作为BFS的起始层
} else {
tree[0].next[i] = 0; // 根节点无此子节点,直接指向自己(优化后续匹配)
}
}
// 第二步:BFS遍历所有节点,逐层设置fail指针(核心循环)
while (!q.empty()) {
int u = q.front(); // 取出队首节点u(当前处理的父节点)
q.pop(); // 队首出队
// 遍历u的所有26个子节点,i对应a~z的偏移
for (int i = 0; i < 26; i++) {
// 情况1:u有i这个子节点(子节点下标为v)
if (tree[u].next[i] != -1) {
int v = tree[u].next[i]; // 记录u的i子节点的下标v
// 关键:v的fail指针 = u的fail指针 所指向节点 的 i子节点
// 理解:u的失败指针是f,那么v的失败指针就是f沿着i走的节点(最接近的后缀匹配)
tree[v].fail = tree[tree[u].fail].next[i];
q.push(v); // 子节点v入队,后续处理它的子节点
}
// 情况2:u无i这个子节点(优化:路径压缩,直接指向fail指针的i子节点)
else {
// 核心优化:u没有i子节点,就把u的next[i]指向 u->fail->next[i]
// 匹配时无需回溯fail指针,直接通过next[i]跳转,提升匹配效率
tree[u].next[i] = tree[tree[u].fail].next[i];
}
}
}
}
核心步骤 2:Print 递归(匹配成功后递归输出所有匹配的模式串,逐行注释)
AC 自动机匹配到某个节点时,该节点可能是多个模式串的结尾(比如模式串abc和bc,匹配到abc的 c 节点时,也能匹配到bc的 c 节点),因此需要从当前节点递归向上遍历 fail 指针,输出所有匹配的模式串,这就是 Print 递归的核心作用。
前置:假设模式串存入数组string pattern[1005](pattern [1] 是第一个模式串,pattern [2] 第二个...),tree[node].end记录模式串的编号(如tree[v].end=3表示该节点是第 3 个模式串的结尾)。
核心代码:Print 递归函数(逐行注释 + 解析)
// 递归打印:从node节点开始,向上遍历fail指针,输出所有匹配的模式串
// pos:当前匹配到的文本串的位置(用于输出匹配的结束位置,可选)
void Print(int node, int pos) {
// 递归终止条件:node为根节点(0),无更多后缀需要匹配
if (node == 0) {
return;
}
// 如果当前节点是某个模式串的结尾(end>0,end是模式串编号)
if (tree[node].end != 0) {
// 输出匹配信息:模式串编号、模式串内容、匹配的结束位置(pos)、匹配的起始位置(pos - 模式串长度 + 1)
// pattern[tree[node].end].size()是当前匹配的模式串的长度
cout << "匹配到模式串" << tree[node].end << ":" << pattern[tree[node].end]
<< ",结束位置:" << pos
<< ",起始位置:" << pos - pattern[tree[node].end].size() + 1 << endl;
}
// 递归核心:继续遍历当前节点的fail指针指向的节点,查找更长的后缀(或更短的)匹配
// 因为fail指针指向的是当前节点的最长有效后缀,其可能也是某个模式串的结尾
Print(tree[node].fail, pos);
}
递归调用场景(匹配时的调用)
在 AC 自动机的匹配函数中,遍历文本串的每个字符,跳转到对应节点now后,直接调用Print(now, i)(i是文本串当前字符的下标,从 1 开始更符合习惯),示例如下:
// 匹配函数:text是待匹配的文本串
void match(string text) {
int now = 0; // 从根节点开始匹配
for (int i = 0; i < text.size(); i++) {
// 跳转到对应节点:text[i]-'a'是字符偏移,now始终是当前匹配的节点
now = tree[now].next[text[i] - 'a'];
// 调用Print递归,输出所有匹配的模式串,i+1是文本串的实际位置(从1开始)
Print(now, i + 1);
}
}
循环版 Print(替代递归,避免栈溢出)
如果模式串很长,fail 指针链会很长,递归可能导致栈溢出,补充循环版的 Print(逐行注释,功能和递归完全一致),供你选择:
// 循环版Print:避免递归栈溢出,功能和递归版一致
void Print_Cycle(int node, int pos) {
// 循环遍历fail指针链,直到根节点(0)
while (node != 0) {
// 如果是模式串结尾,输出匹配信息
if (tree[node].end != 0) {
cout << "匹配到模式串" << tree[node].end << ":" << pattern[tree[node].end]
<< ",结束位置:" << pos
<< ",起始位置:" << pos - pattern[tree[node].end].size() + 1 << endl;
}
// 跳转到fail指针指向的节点,继续遍历
node = tree[node].fail;
}
}
最简单的字符串匹配模拟书上那个 “模板覆盖” 逻辑
#include <iostream>
#include <map>
#include <string>
using namespace std;
int main() {
map<string, int> pat; // 模板串 → 编号
// 插入第一个模板:abc → 编号 1
pat["abc"] = 1;
cout << "插入模板1后:abc -> " << pat["abc"] << endl;
// 插入第二个**相同**模板:abc → 编号 2
pat["abc"] = 2;
cout << "插入模板2后:abc -> " << pat["abc"] << endl;
// 匹配时,只会看到最后一次的值
cout << "最终匹配到的模板编号:" << pat["abc"] << endl;
return 0;
}
// 插入模板 s,编号 id
void insert(string &s, int id) {
int u = 0;
for (char ch : s) {
int c = ch - 'a';
if (!tr[u][c]) tr[u][c] = ++tot;
u = tr[u][c];
}
// 关键:同一个节点,重复赋值!
end[u] = id;
}
#include <iostream>
#include <cstring>
#include <queue>
#include <vector>
#include <algorithm>
using namespace std;
const int MAXN = 10005; // 模板总长度
const int MAXC = 26;
int tr[MAXN][MAXC];
int fail[MAXN];
int end_id[MAXN]; // 每个节点对应的模板编号
int cnt[MAXN]; // 统计每个模板出现次数
int tot;
// 初始化
void init() {
tot = 0;
memset(tr, 0, sizeof(tr));
memset(fail, 0, sizeof(fail));
memset(end_id, -1, sizeof(end_id));
memset(cnt, 0, sizeof(cnt));
}
// 插入模板 s,编号 id
void insert(string &s, int id) {
int u = 0;
for (char ch : s) {
int c = ch - 'a';
if (!tr[u][c]) tr[u][c] = ++tot;
u = tr[u][c];
}
// ======================
// 这里就是「覆盖」发生点!
// 如果重复串,end_id[u] 会被覆盖
// ======================
end_id[u] = id;
}
// 构建 fail 指针
void build() {
queue<int> q;
for (int i = 0; i < MAXC; i++)
if (tr[0][i]) q.push(tr[0][i]);
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < MAXC; i++) {
int v = tr[u][i];
if (v) {
fail[v] = tr[fail[u]][i];
q.push(v);
} else {
tr[u][i] = tr[fail[u]][i];
}
}
}
}
// 匹配文本 t
void query(string &t) {
int u = 0;
for (char ch : t) {
int c = ch - 'a';
u = tr[u][c];
// 沿 fail 跳,统计所有匹配
for (int p = u; p; p = fail[p]) {
if (end_id[p] != -1)
cnt[end_id[p]]++;
}
}
}
// ==============================
// 去重:相同模板只保留第一个
// ==============================
vector<string> unique_patterns(vector<string> &pats) {
sort(pats.begin(), pats.end());
pats.erase(unique(pats.begin(), pats.end()), pats.end());
return pats;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n;
string text;
// 示例输入
cin >> n;
vector<string> pats(n);
for (int i = 0; i < n; i++) cin >> pats[i];
cin >> text;
init();
// 关键:先去重,避免覆盖!
auto uniq_pats = unique_patterns(pats);
// 插入去重后的模板
for (int i = 0; i < uniq_pats.size(); i++)
insert(uniq_pats[i], i);
build();
query(text);
// 输出每个模板出现次数
for (int i = 0; i < uniq_pats.size(); i++)
cout << uniq_pats[i] << ": " << cnt[i] << endl;
return 0;
}

浙公网安备 33010602011771号