贴合《算法竞赛入门经典训练指南》AC 自动机完整代码及参考

参考:详解AC自动机的原理和代码实现

一、程序

包含字典树构建、fail 指针 BFS、模式匹配 Find、递归打印匹配结果 Print全流程。

#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;

const int MAXN = 10000 + 10;  // 字典树最大节点数,按需调整
const int MAXC = 26;          // 字符集:小写字母a-z

// AC自动机节点结构
struct AhoCorasick {
    int ch[MAXN][MAXC];       // 字典树:ch[u][c]表示u节点走c字符到的节点
    int fail[MAXN];           // fail指针
    vector<int> val[MAXN];    // val[u]:存储以u节点结尾的模式串编号(多模式匹配)
    int sz;                   // 节点总数,初始为0

    // 初始化AC自动机
    void init() {
        sz = 0;
        memset(ch[0], 0, sizeof(ch[0]));
        val[0].clear();
    }

    // 插入模式串s,id为模式串编号(用于区分匹配到的是哪个模式串)
    void insert(const char *s, int id) {
        int u = 0;
        for (int i = 0; s[i]; i++) {
            int c = s[i] - 'a';
            if (!ch[u][c]) {
                sz++;
                memset(ch[sz], 0, sizeof(ch[sz]));
                val[sz].clear();
                ch[u][c] = sz;
            }
            u = ch[u][c];
        }
        val[u].push_back(id);  // 该节点记录模式串id
    }

    // BFS构建fail指针
    void build() {
        queue<int> q;
        fail[0] = 0;
        for (int c = 0; c < MAXC; c++) {
            int v = ch[0][c];
            if (v) {
                fail[v] = 0;
                q.push(v);
            }
        }
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int c = 0; c < MAXC; c++) {
                int v = ch[u][c];
                if (v) {
                    // 计算fail[v]:回溯u的fail指针,找到能走c的节点
                    int f = fail[u];
                    while (f && !ch[f][c]) f = fail[f];
                    fail[v] = ch[f][c];
                    q.push(v);
                } else {
                    // 路径压缩:直接指向fail后的节点,优化查询速度
                    ch[u][c] = ch[fail[u]][c];
                }
            }
        }
    }

    // 递归打印:以j节点结尾的所有匹配模式串(配合Find函数,i是文本串当前结束位置)
    // 书中Print(i,j)是解释性写法,实际仅需j,i在调用时传入计算位置
    void Print(int i, int j, const vector<char*>& patterns) {
        if (j == 0) return;  // 根节点,无匹配
        // 打印当前节点的所有匹配结果:i是文本串结束位置,patterns[id]是模式串内容
        for (int id : val[j]) {
            int len = strlen(patterns[id]);
            printf("匹配到模式串:%s,结束位置:%d,起始位置:%d\n", 
                   patterns[id], i, i - len + 1);
        }
        // 递归回溯fail指针,打印所有间接匹配(AC自动机核心:fail链匹配)
        Print(i, fail[j], patterns);
    }

    // 核心匹配函数Find:在文本串T中找所有模式串,patterns是模式串数组
    void Find(const char *T, const vector<char*>& patterns) {
        int u = 0;
        for (int i = 0; T[i]; i++) {  // i是文本串当前位置(从0开始)
            int c = T[i] - 'a';
            u = ch[u][c];  // 按字典树+路径压缩走
            if (val[u].size() || fail[u]) {  // 有匹配结果,调用Print
                Print(i + 1, u, patterns);  // 文本串位置转成从1开始,更符合阅读习惯
            }
        }
    }
} ac;

// 测试主函数
int main() {
    // 步骤1:初始化AC自动机
    ac.init();

    // 步骤2:插入模式串(可自定义,示例:3个模式串)
    vector<char*> patterns;  // 存储模式串,与id一一对应
    char s1[] = "she";
    char s2[] = "he";
    char s3[] = "his";
    patterns.push_back(s1);
    patterns.push_back(s2);
    patterns.push_back(s3);
    ac.insert(s1, 0);
    ac.insert(s2, 1);
    ac.insert(s3, 2);

    // 步骤3:构建fail指针
    ac.build();

    // 步骤4:文本串匹配
    char T[] = "hishers";  // 测试文本串
    printf("文本串:%s\n", T);
    printf("匹配结果:\n");
    ac.Find(T, patterns);

    return 0;
}

二、细节说明

1、Print 函数参数:书中Print(i,j)是读者理解用的简化写法,实际代码中Print需要i(文本串结束位置)+j(AC 节点)+patterns(模式串数组),其中i仅用于计算并打印匹配的起始 / 结束位置,j是核心的 AC 节点,用于递归遍历fail链找所有匹配。
2、与书中的对应关系:

  • 书中ch数组 → 代码中ch[MAXN][MAXC](字典树核心);
  • 书中fail数组 → 代码中fail[MAXN](BFS 构建);
  • 书中val数组 → 代码中val[MAXN](存储模式串编号,支持多模式匹配);
  • 书中Find函数逻辑完全一致 → 文本串遍历 + 节点跳转 +Print调用;
  • 书中Print递归逻辑 → 代码中严格实现节点 j 的直接匹配 + fail 链的间接匹配。

3、路径压缩优化:代码中ch[u][c] = ch[fail[u]][c]是书中的优化点,避免每次匹配都回溯fail指针,提升运行效率。

三、编译运行结果(示例)

上述测试代码的运行输出(文本串hishers,模式串she/he/his):

文本串:hishers
匹配结果:
匹配到模式串:his,结束位置:3,起始位置:1
匹配到模式串:he,结束位置:5,起始位置:4
匹配到模式串:she,结束位置:6,起始位置:4
匹配到模式串:he,结束位置:6,起始位置:5

四、自定义修改说明

1、字符集调整:若需支持大写字母 / 数字,修改MAXC和c = s[i] - 'a'的偏移量即可;
2、模式串数量:直接在main函数中添加insert调用,patterns数组同步添加即可;
3、文本串输入:可将固定文本串改为scanf/gets读取外部输入,适配算法竞赛场景。

五、扩展

核心步骤 1:AC 自动机 fail 指针的 BFS 构建(逐行注释 + 解析)
先给出完整的带逐行注释的代码,再补充关键逻辑的设计思路,代码基于 AC 自动机的经典数组实现(适合新手理解,避免指针的复杂度)。
前置准备:AC 自动机的节点结构体(先定义,否则 BFS 代码无意义)

#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
using namespace std;

const int MAXN = 10000 + 10;  // 字典树最大节点数,按需调整
const int MAXC = 26;          // 字符集:小写字母a-z

// AC自动机节点结构
struct AhoCorasick {
    int ch[MAXN][MAXC];       // 字典树:ch[u][c]表示u节点走c字符到的节点
    int fail[MAXN];           // fail指针
    vector<int> val[MAXN];    // val[u]:存储以u节点结尾的模式串编号(多模式匹配)
    int sz;                   // 节点总数,初始为0

    // 初始化AC自动机
    void init() {
        sz = 0;
        memset(ch[0], 0, sizeof(ch[0]));
        val[0].clear();
    }

    // 插入模式串s,id为模式串编号(用于区分匹配到的是哪个模式串)
    void insert(const char *s, int id) {
        int u = 0;
        for (int i = 0; s[i]; i++) {
            int c = s[i] - 'a';
            if (!ch[u][c]) {
                sz++;
                memset(ch[sz], 0, sizeof(ch[sz]));
                val[sz].clear();
                ch[u][c] = sz;
            }
            u = ch[u][c];
        }
        val[u].push_back(id);  // 该节点记录模式串id
    }

    // BFS构建fail指针
    void build() {
        queue<int> q;
        fail[0] = 0;
        for (int c = 0; c < MAXC; c++) {
            int v = ch[0][c];
            if (v) {
                fail[v] = 0;
                q.push(v);
            }
        }
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            for (int c = 0; c < MAXC; c++) {
                int v = ch[u][c];
                if (v) {
                    // 计算fail[v]:回溯u的fail指针,找到能走c的节点
                    int f = fail[u];
                    while (f && !ch[f][c]) f = fail[f];
                    fail[v] = ch[f][c];
                    q.push(v);
                } else {
                    // 路径压缩:直接指向fail后的节点,优化查询速度
                    ch[u][c] = ch[fail[u]][c];
                }
            }
        }
    }

    // 递归打印:以j节点结尾的所有匹配模式串(配合Find函数,i是文本串当前结束位置)
    // 书中Print(i,j)是解释性写法,实际仅需j,i在调用时传入计算位置
    void Print(int i, int j, const vector<char*>& patterns) {
        if (j == 0) return;  // 根节点,无匹配
        // 打印当前节点的所有匹配结果:i是文本串结束位置,patterns[id]是模式串内容
        for (int id : val[j]) {
            int len = strlen(patterns[id]);
            printf("匹配到模式串:%s,结束位置:%d,起始位置:%d\n", 
                   patterns[id], i, i - len + 1);
        }
        // 递归回溯fail指针,打印所有间接匹配(AC自动机核心:fail链匹配)
        Print(i, fail[j], patterns);
    }

    // 核心匹配函数Find:在文本串T中找所有模式串,patterns是模式串数组
    void Find(const char *T, const vector<char*>& patterns) {
        int u = 0;
        for (int i = 0; T[i]; i++) {  // i是文本串当前位置(从0开始)
            int c = T[i] - 'a';
            u = ch[u][c];  // 按字典树+路径压缩走
            if (val[u].size() || fail[u]) {  // 有匹配结果,调用Print
                Print(i + 1, u, patterns);  // 文本串位置转成从1开始,更符合阅读习惯
            }
        }
    }
} ac;

// 测试主函数
int main() {
    // 步骤1:初始化AC自动机
    ac.init();

    // 步骤2:插入模式串(可自定义,示例:3个模式串)
    vector<char*> patterns;  // 存储模式串,与id一一对应
    char s1[] = "she";
    char s2[] = "he";
    char s3[] = "his";
    patterns.push_back(s1);
    patterns.push_back(s2);
    patterns.push_back(s3);
    ac.insert(s1, 0);
    ac.insert(s2, 1);
    ac.insert(s3, 2);

    // 步骤3:构建fail指针
    ac.build();

    // 步骤4:文本串匹配
    char T[] = "hishers";  // 测试文本串
    printf("文本串:%s\n", T);
    printf("匹配结果:\n");
    ac.Find(T, patterns);

    return 0;
}

编译运行结果(示例)
上述测试代码的运行输出(文本串hishers,模式串she/he/his):
plaintext
文本串:hishers
匹配结果:
匹配到模式串:his,结束位置:3,起始位置:1
匹配到模式串:he,结束位置:5,起始位置:4
匹配到模式串:she,结束位置:6,起始位置:4
匹配到模式串:he,结束位置:6,起始位置:5

前置准备:AC 自动机的节点结构体(先定义,否则 BFS 代码无意义)

// AC自动机节点结构体,数组实现(size根据模式串数量/长度定,一般开1e5+10足够)
const int MAXN = 100010;
struct Node {
    int next[26];  // 每个节点的26个小写字母子节点,存子节点的下标
    int fail;      // 该节点的失败指针,存指向的节点下标
    int end;       // end>0表示该节点是某个模式串的结尾,end的值可记录模式串编号/长度
    Node() {       // 构造函数,初始化节点
        memset(next, -1, sizeof(next));  // 子节点初始化为-1(表示无此子节点)
        fail = 0;                        // 失败指针默认指向根节点(根节点下标为0)
        end = 0;                         // 初始不是任何模式串的结尾
    }
} tree[MAXN];
int cnt = 0;  // 节点计数器,根节点为0,新节点从1开始创建
queue<int> q; // BFS队列,存节点下标,用于构建fail指针

核心代码:fail 指针的 BFS 构建(逐行注释 + 解析)

// 构建AC自动机的fail指针,根节点下标为0
void build_fail() {
    // 第一步:初始化根节点的直接子节点(第一层节点)的fail指针,并加入队列
    // 根节点的next[0~25]遍历,i对应a~z的偏移(0=a,1=b...25=z)
    for (int i = 0; i < 26; i++) {
        // 如果根节点有i这个子节点(下标不为-1)
        if (tree[0].next[i] != -1) {
            tree[tree[0].next[i]].fail = 0;  // 第一层节点的fail指针直接指向根节点
            q.push(tree[0].next[i]);         // 该节点入队,作为BFS的起始层
        } else {
            tree[0].next[i] = 0;  // 根节点无此子节点,直接指向自己(优化后续匹配)
        }
    }

    // 第二步:BFS遍历所有节点,逐层设置fail指针(核心循环)
    while (!q.empty()) {
        int u = q.front();  // 取出队首节点u(当前处理的父节点)
        q.pop();            // 队首出队

        // 遍历u的所有26个子节点,i对应a~z的偏移
        for (int i = 0; i < 26; i++) {
            // 情况1:u有i这个子节点(子节点下标为v)
            if (tree[u].next[i] != -1) {
                int v = tree[u].next[i];  // 记录u的i子节点的下标v
                // 关键:v的fail指针 = u的fail指针 所指向节点 的 i子节点
                // 理解:u的失败指针是f,那么v的失败指针就是f沿着i走的节点(最接近的后缀匹配)
                tree[v].fail = tree[tree[u].fail].next[i];
                q.push(v);  // 子节点v入队,后续处理它的子节点
            } 
            // 情况2:u无i这个子节点(优化:路径压缩,直接指向fail指针的i子节点)
            else {
                // 核心优化:u没有i子节点,就把u的next[i]指向 u->fail->next[i]
                // 匹配时无需回溯fail指针,直接通过next[i]跳转,提升匹配效率
                tree[u].next[i] = tree[tree[u].fail].next[i];
            }
        }
    }
}

核心步骤 2:Print 递归(匹配成功后递归输出所有匹配的模式串,逐行注释)

AC 自动机匹配到某个节点时,该节点可能是多个模式串的结尾(比如模式串abc和bc,匹配到abc的 c 节点时,也能匹配到bc的 c 节点),因此需要从当前节点递归向上遍历 fail 指针,输出所有匹配的模式串,这就是 Print 递归的核心作用。
前置:假设模式串存入数组string pattern[1005](pattern [1] 是第一个模式串,pattern [2] 第二个...),tree[node].end记录模式串的编号(如tree[v].end=3表示该节点是第 3 个模式串的结尾)。
核心代码:Print 递归函数(逐行注释 + 解析)

// 递归打印:从node节点开始,向上遍历fail指针,输出所有匹配的模式串
// pos:当前匹配到的文本串的位置(用于输出匹配的结束位置,可选)
void Print(int node, int pos) {
    // 递归终止条件:node为根节点(0),无更多后缀需要匹配
    if (node == 0) {
        return;
    }
    // 如果当前节点是某个模式串的结尾(end>0,end是模式串编号)
    if (tree[node].end != 0) {
        // 输出匹配信息:模式串编号、模式串内容、匹配的结束位置(pos)、匹配的起始位置(pos - 模式串长度 + 1)
        // pattern[tree[node].end].size()是当前匹配的模式串的长度
        cout << "匹配到模式串" << tree[node].end << ":" << pattern[tree[node].end] 
             << ",结束位置:" << pos 
             << ",起始位置:" << pos - pattern[tree[node].end].size() + 1 << endl;
    }
    // 递归核心:继续遍历当前节点的fail指针指向的节点,查找更长的后缀(或更短的)匹配
    // 因为fail指针指向的是当前节点的最长有效后缀,其可能也是某个模式串的结尾
    Print(tree[node].fail, pos);
}

递归调用场景(匹配时的调用)
在 AC 自动机的匹配函数中,遍历文本串的每个字符,跳转到对应节点now后,直接调用Print(now, i)(i是文本串当前字符的下标,从 1 开始更符合习惯),示例如下:

// 匹配函数:text是待匹配的文本串
void match(string text) {
    int now = 0;  // 从根节点开始匹配
    for (int i = 0; i < text.size(); i++) {
        // 跳转到对应节点:text[i]-'a'是字符偏移,now始终是当前匹配的节点
        now = tree[now].next[text[i] - 'a'];
        // 调用Print递归,输出所有匹配的模式串,i+1是文本串的实际位置(从1开始)
        Print(now, i + 1);
    }
}

循环版 Print(替代递归,避免栈溢出)

如果模式串很长,fail 指针链会很长,递归可能导致栈溢出,补充循环版的 Print(逐行注释,功能和递归完全一致),供你选择:

// 循环版Print:避免递归栈溢出,功能和递归版一致
void Print_Cycle(int node, int pos) {
    // 循环遍历fail指针链,直到根节点(0)
    while (node != 0) {
        // 如果是模式串结尾,输出匹配信息
        if (tree[node].end != 0) {
            cout << "匹配到模式串" << tree[node].end << ":" << pattern[tree[node].end] 
                 << ",结束位置:" << pos 
                 << ",起始位置:" << pos - pattern[tree[node].end].size() + 1 << endl;
        }
        // 跳转到fail指针指向的节点,继续遍历
        node = tree[node].fail;
    }
}

最简单的字符串匹配模拟书上那个 “模板覆盖” 逻辑

#include <iostream>
#include <map>
#include <string>
using namespace std;

int main() {
    map<string, int> pat;  // 模板串 → 编号

    // 插入第一个模板:abc → 编号 1
    pat["abc"] = 1;
    cout << "插入模板1后:abc -> " << pat["abc"] << endl;

    // 插入第二个**相同**模板:abc → 编号 2
    pat["abc"] = 2;
    cout << "插入模板2后:abc -> " << pat["abc"] << endl;

    // 匹配时,只会看到最后一次的值
    cout << "最终匹配到的模板编号:" << pat["abc"] << endl;
    return 0;
}
// 插入模板 s,编号 id
void insert(string &s, int id) {
    int u = 0;
    for (char ch : s) {
        int c = ch - 'a';
        if (!tr[u][c]) tr[u][c] = ++tot;
        u = tr[u][c];
    }
    // 关键:同一个节点,重复赋值!
    end[u] = id;
}
#include <iostream>
#include <cstring>
#include <queue>
#include <vector>
#include <algorithm>
using namespace std;

const int MAXN = 10005;   // 模板总长度
const int MAXC = 26;

int tr[MAXN][MAXC];
int fail[MAXN];
int end_id[MAXN];   // 每个节点对应的模板编号
int cnt[MAXN];      // 统计每个模板出现次数
int tot;

// 初始化
void init() {
    tot = 0;
    memset(tr, 0, sizeof(tr));
    memset(fail, 0, sizeof(fail));
    memset(end_id, -1, sizeof(end_id));
    memset(cnt, 0, sizeof(cnt));
}

// 插入模板 s,编号 id
void insert(string &s, int id) {
    int u = 0;
    for (char ch : s) {
        int c = ch - 'a';
        if (!tr[u][c]) tr[u][c] = ++tot;
        u = tr[u][c];
    }

    // ======================
    // 这里就是「覆盖」发生点!
    // 如果重复串,end_id[u] 会被覆盖
    // ======================
    end_id[u] = id;
}

// 构建 fail 指针
void build() {
    queue<int> q;
    for (int i = 0; i < MAXC; i++)
        if (tr[0][i]) q.push(tr[0][i]);

    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int i = 0; i < MAXC; i++) {
            int v = tr[u][i];
            if (v) {
                fail[v] = tr[fail[u]][i];
                q.push(v);
            } else {
                tr[u][i] = tr[fail[u]][i];
            }
        }
    }
}

// 匹配文本 t
void query(string &t) {
    int u = 0;
    for (char ch : t) {
        int c = ch - 'a';
        u = tr[u][c];
        // 沿 fail 跳,统计所有匹配
        for (int p = u; p; p = fail[p]) {
            if (end_id[p] != -1)
                cnt[end_id[p]]++;
        }
    }
}

// ==============================
// 去重:相同模板只保留第一个
// ==============================
vector<string> unique_patterns(vector<string> &pats) {
    sort(pats.begin(), pats.end());
    pats.erase(unique(pats.begin(), pats.end()), pats.end());
    return pats;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    int n;
    string text;

    // 示例输入
    cin >> n;
    vector<string> pats(n);
    for (int i = 0; i < n; i++) cin >> pats[i];
    cin >> text;

    init();

    // 关键:先去重,避免覆盖!
    auto uniq_pats = unique_patterns(pats);

    // 插入去重后的模板
    for (int i = 0; i < uniq_pats.size(); i++)
        insert(uniq_pats[i], i);

    build();
    query(text);

    // 输出每个模板出现次数
    for (int i = 0; i < uniq_pats.size(); i++)
        cout << uniq_pats[i] << ": " << cnt[i] << endl;

    return 0;
}
posted @ 2026-02-06 10:51  gdyyx  阅读(6)  评论(0)    收藏  举报