AC自动机
AC 自动机用于解决多模式串匹配的问题,本质上是 trie + Kmp 思想优化匹配过程,例如:文本串aaasssddvbasd ,两个模式串:aass,ddssa。虽然可以设置多个指针同时匹配,但是时间也没快多少,我们形象化的感知时间复杂度:在每个位置的末尾试图往前匹配每个目标串,时间复杂度显然为O(mn)。
AC 自动机的构建过程:
- 建立模式串的字典树
int insert(string& t)
{
int cur = 0;
for (auto& ch : t)
{
int path = ch - 'a';
if (!tr[cur][path]) tr[cur][path] = ++idx;
cur = tr[cur][path];
}
return cur;
}
- 设置 fail 指针。fail 指针的含义为当前状态最长后缀在模式串中匹配的最长前缀。fail 指针与 kmp 里面的 next 数组有着异曲同工之妙,它们都是失配后跳转的指示,next 数组代表最长的 border,这个 border 在原串中的前缀对应这AC自动机中的模式串的前缀。
设置的过程为,遍历整个字典树(宽搜),在当前结点中去设置孩子结点的 fail 指针,fail 指针和 border 一样,不能是本身,这里及不能指向自己。

void bfs()
{
queue<int> q;
q.push(0);
while (q.size())
{
int x = q.front(); q.pop();
// 设置fail指针
for (int i = 0; i < 26; i++)
{
if (tr[x][i])
{
q.push(tr[x][i]);
int cur = fail[x]; // fail[fa]
// 尝试从 cur 位置匹配 i 字符
while (cur && !tr[cur][i]) cur = fail[cur];
if (tr[cur][i] && tr[cur][i] != tr[x][i]) fail[tr[x][i]] = tr[cur][i];
}
}
}
}
- 匹配模式串,从头到尾遍历文本串,cur 指针指向字典树中的最大匹配位置,当失配时,cur 指针跳转,留取当前匹配串的最长后缀,这个后缀在模式串中作为一个次长的前缀出现,然后接着往下匹配。然后在匹配途中记录答案,方法是让 j 从 cur 开始不断通过fail指针跳转,经过的结点词频加一,直到 j 跳到 0 为止。
int cur = 0;
for(int i = 0; i < s.size(); i++)
{
int path = s[i] - 'a';
while(cur && !tr[cur][path]) cur = fail[cur];
if(tr[cur][path]) cur = tr[cur][path];
// cout << cur << endl;
int j = cur;
while(j)
{
cnt[j]++;
j = fail[j];
}
}
- 统计答案,一个模式串出现的次数为 cnt[j] 数组中,j 为该模式串最后一个字符在字典树中对应的编号。
for (int i = 1; i <= n; i++)
{
// 构建字典树
cin >> t;
ans[i] = insert(t);
}
// ...
for(int i = 1; i <= n; i++) cout << cnt[ans[i]] << endl;
完整 code:
#include <iostream>
#include <queue>
using namespace std;
const int N = 2e5 + 10;
// 当然这个代码通过不了(至少不会wa),原因如下:
// 1. fail 指针在构建的时候转圈
// 2. fail 指针在遍历字符的时候充当 l,转圈
// 3. fail 指针在++路径的时候转圈
int tr[N][26], idx;
int fail[N], cnt[N], ans[N];
int insert(string& t)
{
int cur = 0;
for (auto& ch : t)
{
int path = ch - 'a';
if (!tr[cur][path]) tr[cur][path] = ++idx;
cur = tr[cur][path];
}
return cur;
}
void bfs()
{
queue<int> q;
q.push(0);
while (q.size())
{
int x = q.front(); q.pop();
// 设置fail指针
for (int i = 0; i < 26; i++)
{
if (tr[x][i])
{
q.push(tr[x][i]);
int cur = fail[x]; // fail[fa]
// 尝试从 cur 位置匹配 i 字符
while (cur && !tr[cur][i]) cur = fail[cur];
if (tr[cur][i] && tr[cur][i] != tr[x][i]) fail[tr[x][i]] = tr[cur][i];
}
}
}
}
int main()
{
// n 个模式串,t,一个文本串 s。
int n; cin >> n;
string t;
for (int i = 1; i <= n; i++)
{
// 构建字典树
cin >> t;
ans[i] = insert(t);
}
// 构建fail指针数组
bfs();
for(int i = 0; i <= 8; i++) cout << fail[i] << " ";
cout << endl;
string s; cin >> s;
int cur = 0;
for(int i = 0; i < s.size(); i++)
{
int path = s[i] - 'a';
while(cur && !tr[cur][path]) cur = fail[cur];
if(tr[cur][path]) cur = tr[cur][path];
// cout << cur << endl;
int j = cur;
while(j)
{
cnt[j]++;
j = fail[j];
}
}
for(int i = 1; i <= n; i++) cout << cnt[ans[i]] << endl;
return 0;
}
此代码为弱化版的 AC 自动机,原因是存在三个fail指针转圈的地方。
下面开始对 AC 自动机优化:
优化的核心思想是通过 dp 转移,fail 指针的含义无非就是失配后去往的下一个位置,当父节点 fa 去设置子节点 i的 fail 指针时,会不断重复判断 tr[fail[fa]][i] 是否存在,不存在让 fa 跳转到 fail[fa],现在设想如果想要直接跳到想要的位置上需要什么信息,不难想到我们需要 fali[fa] 跳转 i 的信息,注意这个跳转也是直接跳转不存在转圈的现象,如果直到了这个信息,显然 i 结点的 fail 信息可以直接设置,原因不难证明,如果存在一连串的跳转操作,直属的 fail 的信息一定是最晚的,匹配长度一定是最长的,如果后面还有更长的匹配长度可以为后续的结点的 fail 做贡献,一定会被转移的过程中截断,及后续的结点用到的是这个更长的信息。这样的思想有点类似于并查集中的按秩压缩。
总结我们需要的是一个结点的直去表,具体含义为:当一个位置匹配 i 失配后,下一个直接去往的位置为直去表中的信息,显然对于一个字典树中存在的位置,下一个去往位置就是该节点,否则去设置直去表,可以看出直去表和字典树是一个表。
下面来总结一下流程:
- 如果字典树中存在路径,去设置路径到达结点的fail指针,fail[tr[x][i]] = tr[fail[x]][i]。
- 否则设置该路径的直去表,tr[x][i] = tr[fail[x]][i]
优化后遍历文本串的策略需要改变一下:cur 直接跳转 tr[cur][i],然后统计词频 cnt[cur]++。
这样记录的词频为 i 位置的后缀匹配模式串的一个最长前缀,统计答案的时候需要建立反图,通过树形dp逐渐累计答案。
code:
#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>
using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 2e6 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int id[N], n, fail[N];
string s;
int tr[N][26], idx;
int cnt[N];
vector<int> edges[N];
int insert(string& s)
{
int cur = 0;
for(auto& ch : s)
{
int path = ch - 'a';
if(!tr[cur][path]) tr[cur][path] = ++idx;
cur = tr[cur][path];
}
return cur;
}
void bfs()
{
queue<int> q;
q.push(0);
while(q.size())
{
int x = q.front(); q.pop();
for(int i = 0; i < 26; i++)
{
// 如果有这条路,设置它的fail指针
if(tr[x][i])
{
if(tr[fail[x]][i] != tr[x][i]) fail[tr[x][i]] = tr[fail[x]][i];
q.push(tr[x][i]);
}
else tr[x][i] = tr[fail[x]][i]; // 如果没有就设置直去表
}
}
}
void dfs(int x, int fa)
{
for(auto& y : edges[x])
{
if(y == fa) continue;
dfs(y, x);
cnt[x] += cnt[y];
}
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++)
{
cin >> s;
id[i] = insert(s);
}
cin >> s;
// 建立直去表并且设置fail指针
bfs();
// for(int i = 0; i <= 8; i++) cout << fail[i] << ' ';
// cout << endl;
// 遍历字符
for(int i = 0, cur = 0; i < s.size(); i++)
{
int path = s[i] - 'a';
cur = tr[cur][path]; // 往下走
// cout << cur << endl;
cnt[cur]++; // 累加计数
}
// for(int i = 1; i <= idx; i++) cout << cnt[i] << " ";
// cout << endl;
// 建反图,统计最终计数
for(int i = 1; i <= idx; i++)
{
edges[fail[i]].push_back(i);
}
dfs(0, 0);
for(int i = 1; i <= n; i++) cout << cnt[id[i]] << endl;
}
int main()
{
cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
int T = 1;
// cin >> T;
while(T--)
{
solve();
}
return 0;
}
浙公网安备 33010602011771号