AC自动机

AC 自动机用于解决多模式串匹配的问题,本质上是 trie + Kmp 思想优化匹配过程,例如:文本串aaasssddvbasd ,两个模式串:aass,ddssa。虽然可以设置多个指针同时匹配,但是时间也没快多少,我们形象化的感知时间复杂度:在每个位置的末尾试图往前匹配每个目标串,时间复杂度显然为O(mn)。

AC 自动机的构建过程:

  1. 建立模式串的字典树
int insert(string& t)
{
    int cur = 0;
    for (auto& ch : t)
    {
        int path = ch - 'a';
        if (!tr[cur][path]) tr[cur][path] = ++idx;
        cur = tr[cur][path];
    }
    return cur;
}
  1. 设置 fail 指针。fail 指针的含义为当前状态最长后缀在模式串中匹配的最长前缀。fail 指针与 kmp 里面的 next 数组有着异曲同工之妙,它们都是失配后跳转的指示,next 数组代表最长的 border,这个 border 在原串中的前缀对应这AC自动机中的模式串的前缀。

设置的过程为,遍历整个字典树(宽搜),在当前结点中去设置孩子结点的 fail 指针,fail 指针和 border 一样,不能是本身,这里及不能指向自己。

image

void bfs()
{
    queue<int> q;
    q.push(0);
    while (q.size())
    {
        int x = q.front(); q.pop();
        // 设置fail指针
        for (int i = 0; i < 26; i++)
        {
            if (tr[x][i])
            {
                q.push(tr[x][i]);
                int cur = fail[x]; // fail[fa]
                // 尝试从 cur 位置匹配 i 字符
                while (cur && !tr[cur][i]) cur = fail[cur];
                if (tr[cur][i] && tr[cur][i] != tr[x][i]) fail[tr[x][i]] = tr[cur][i];

            }
        }
    }
}

  1. 匹配模式串,从头到尾遍历文本串,cur 指针指向字典树中的最大匹配位置,当失配时,cur 指针跳转,留取当前匹配串的最长后缀,这个后缀在模式串中作为一个次长的前缀出现,然后接着往下匹配。然后在匹配途中记录答案,方法是让 j 从 cur 开始不断通过fail指针跳转,经过的结点词频加一,直到 j 跳到 0 为止。
    int cur = 0;
    for(int i = 0; i < s.size(); i++)
    {
        int path = s[i] - 'a';
        while(cur && !tr[cur][path]) cur = fail[cur];
        if(tr[cur][path]) cur = tr[cur][path];
        // cout << cur << endl;    
        int j = cur;
        while(j) 
        {
            cnt[j]++;
            j = fail[j];
        }
    }
  1. 统计答案,一个模式串出现的次数为 cnt[j] 数组中,j 为该模式串最后一个字符在字典树中对应的编号。
    for (int i = 1; i <= n; i++)
    {
        // 构建字典树
        cin >> t;
        ans[i] = insert(t);
    }
// ...
for(int i = 1; i <= n; i++) cout << cnt[ans[i]] << endl;

完整 code:

#include <iostream>
#include <queue>
using namespace std;
const int N = 2e5 + 10;

// 当然这个代码通过不了(至少不会wa),原因如下:
// 1. fail 指针在构建的时候转圈
// 2. fail 指针在遍历字符的时候充当 l,转圈
// 3. fail 指针在++路径的时候转圈

int tr[N][26], idx;
int fail[N], cnt[N], ans[N];
int insert(string& t)
{
    int cur = 0;
    for (auto& ch : t)
    {
        int path = ch - 'a';
        if (!tr[cur][path]) tr[cur][path] = ++idx;
        cur = tr[cur][path];
    }
    return cur;
}

void bfs()
{
    queue<int> q;
    q.push(0);
    while (q.size())
    {
        int x = q.front(); q.pop();
        // 设置fail指针
        for (int i = 0; i < 26; i++)
        {
            if (tr[x][i])
            {
                q.push(tr[x][i]);
                int cur = fail[x]; // fail[fa]
                // 尝试从 cur 位置匹配 i 字符
                while (cur && !tr[cur][i]) cur = fail[cur];
                if (tr[cur][i] && tr[cur][i] != tr[x][i]) fail[tr[x][i]] = tr[cur][i];

            }
        }
    }
}

int main()
{
    // n 个模式串,t,一个文本串 s。
    int n; cin >> n;
    string t;
    for (int i = 1; i <= n; i++)
    {
        // 构建字典树
        cin >> t;
        ans[i] = insert(t);
    }
    // 构建fail指针数组
    bfs();
    for(int i = 0; i <= 8; i++) cout << fail[i] << " ";
    cout << endl;
    string s; cin >> s;
    int cur = 0;
    for(int i = 0; i < s.size(); i++)
    {
        int path = s[i] - 'a';
        while(cur && !tr[cur][path]) cur = fail[cur];
        if(tr[cur][path]) cur = tr[cur][path];
        // cout << cur << endl;    
        int j = cur;
        while(j) 
        {
            cnt[j]++;
            j = fail[j];
        }
    }
    for(int i = 1; i <= n; i++) cout << cnt[ans[i]] << endl;
    return 0;
}

此代码为弱化版的 AC 自动机,原因是存在三个fail指针转圈的地方。

下面开始对 AC 自动机优化:

优化的核心思想是通过 dp 转移,fail 指针的含义无非就是失配后去往的下一个位置,当父节点 fa 去设置子节点 i的 fail 指针时,会不断重复判断 tr[fail[fa]][i] 是否存在,不存在让 fa 跳转到 fail[fa],现在设想如果想要直接跳到想要的位置上需要什么信息,不难想到我们需要 fali[fa] 跳转 i 的信息,注意这个跳转也是直接跳转不存在转圈的现象,如果直到了这个信息,显然 i 结点的 fail 信息可以直接设置,原因不难证明,如果存在一连串的跳转操作,直属的 fail 的信息一定是最晚的,匹配长度一定是最长的,如果后面还有更长的匹配长度可以为后续的结点的 fail 做贡献,一定会被转移的过程中截断,及后续的结点用到的是这个更长的信息。这样的思想有点类似于并查集中的按秩压缩。

总结我们需要的是一个结点的直去表,具体含义为:当一个位置匹配 i 失配后,下一个直接去往的位置为直去表中的信息,显然对于一个字典树中存在的位置,下一个去往位置就是该节点,否则去设置直去表,可以看出直去表和字典树是一个表。

下面来总结一下流程:

  • 如果字典树中存在路径,去设置路径到达结点的fail指针,fail[tr[x][i]] = tr[fail[x]][i]。
  • 否则设置该路径的直去表,tr[x][i] = tr[fail[x]][i]

优化后遍历文本串的策略需要改变一下:cur 直接跳转 tr[cur][i],然后统计词频 cnt[cur]++。

这样记录的词频为 i 位置的后缀匹配模式串的一个最长前缀,统计答案的时候需要建立反图,通过树形dp逐渐累计答案。

code:

#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>


using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 2e6 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int id[N], n, fail[N];
string s;
int tr[N][26], idx;
int cnt[N];
vector<int> edges[N];

int insert(string& s)
{
    int cur = 0;
    for(auto& ch : s)
    {
        int path = ch - 'a';
        if(!tr[cur][path]) tr[cur][path] = ++idx;
        cur = tr[cur][path];
    }
    return cur;
}

void bfs()
{
    queue<int> q;
    q.push(0);
    while(q.size())
    {
        int x = q.front(); q.pop();
        for(int i = 0; i < 26; i++)
        {
            // 如果有这条路,设置它的fail指针
            if(tr[x][i]) 
            {
                if(tr[fail[x]][i] != tr[x][i]) fail[tr[x][i]] = tr[fail[x]][i];
                q.push(tr[x][i]);
            }
            else tr[x][i] = tr[fail[x]][i]; // 如果没有就设置直去表
        }
    }
}

void dfs(int x, int fa)
{
    for(auto& y : edges[x])
    {
        if(y == fa) continue;
        dfs(y, x);
        cnt[x] += cnt[y];
    }
}

void solve()
{
    cin >> n;
    for(int i = 1; i <= n; i++)
    {
        cin >> s;
        id[i] = insert(s);
    }    
    cin >> s;
    // 建立直去表并且设置fail指针
    bfs();
    // for(int i = 0; i <= 8; i++) cout << fail[i] << ' ';
    // cout << endl;

    // 遍历字符
    for(int i = 0, cur = 0; i < s.size(); i++)
    {
        int path = s[i] - 'a';
        cur = tr[cur][path]; // 往下走
        // cout << cur << endl;
        cnt[cur]++; // 累加计数
    }
    // for(int i = 1; i <= idx; i++) cout << cnt[i] << " ";
    // cout << endl;
    // 建反图,统计最终计数
    for(int i = 1; i <= idx; i++)
    {
        edges[fail[i]].push_back(i);
    }
    dfs(0, 0);
    for(int i = 1; i <= n; i++) cout << cnt[id[i]] << endl;
}

int main()
{
    cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
    int T = 1;
    // cin >> T;
    while(T--)
    {
        solve();
    }
    return 0;
}

posted on 2026-06-29 10:44  我不爱吃汉堡  阅读(2)  评论(0)    收藏  举报

导航