学习笔记《AC 自动机》
作者是个 fw,有些话不是很标准,还请见谅。
为了方便,接下来的 acam,没有特殊表明,均表示 AC自动机。
我们直接引入一道题目 P5357。
这题就是标准的模板,从中,我们可以得到 AC 的作用:统计文本串内各个模式串的个数。
我们回忆一下 trie 的作用:判断一个字符串在不在一堆字符串里。
这和题目很像,我们想想怎么建这棵 trie。
我们要找模式串的个数,那么这个 trie 肯定就是模式串建起来的了。
接下来,我们会遇到问题:文本串很长,我们不可能把所有子串遍历一遍。
在 acam 中,我么只需要遍历文本串一次。可只遍历一遍,那么到一个节点没路了,怎么办?
那么我们需要一个新的路,这就是失配指针。
既然接下来没有路了,那么已经匹配好的所有前缀已经没用了。
所以,每个节点失配指针会指向该节点字符串的最长后缀。
为了代码好写,如果节点没有某个儿子,我们会把这个儿子设成它失配指针的节点的该儿子。
说一下怎么算答案。
我们遍历文本串,设一个变量像遍历 trie 一样一直走。
然后,每到一个节点,说明这个节点的字符串出现了一次,相应的,它的所有后缀也出现了一次。
所以我们从这个遍历失配指针增加答案,直到到根节点。
但是这样每次最少减一次深度,复杂度错误。
如果我们将 \(fail_x \rightarrow x\) 连有向边,显然最后是一颗树(\(0\) 节点处有自环)
我们可以离线,最后拓扑一下即可。
这里给出例题代码。
代码
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define FL(a,b,c) for(int a=(b),a##end=(c);a<=a##end;++a)
#define FR(a,b,c) for(int a=(b),a##end=(c);a>=a##end;--a)
#define lowbit(x) ((x)&-(x))
#define eb emplace_back
#define SZ(x) (int)((x).size())
#define ll long long
#define vt vector
#define ar(x) array<int,x>
#define PII pair<int, int>
#define max(a, b)({auto f7r=(a);auto j3h=(b);f7r<j3h?j3h:f7r;})
#define cmax(a, b)({auto j3h=(b);(j3h>a)&&(a=j3h);})
#define min(a, b)({auto f7r=(a);auto j3h=(b);f7r>j3h?j3h:f7r;})
#define cmin(a, b)({auto j3h=(b);(j3h<a)&&(a=j3h);})
constexpr int N = 2e6 + 10;
struct node{//fail 失配指针,end 是哪个模式串的终点,ans 表示文本串里该节点字符串的个数。
int fail, end, son[27], ans;
void clear(){fail = end = ans = 0, memset(son, 0, sizeof son);}
}AC[N];
char a[N];
int tot, Map[N], ans[N];//模式串可能重复。
void insert(char*a, int num){
int u = 0;
FL(i, 1, strlen(a + 1)){
int &z = AC[u].son[a[i] - 'a'];
if(!z)z = ++tot, AC[tot].clear();
u = z;
}
if(!AC[u].end)AC[u].end = num;
Map[num] = AC[u].end;
}
void get_fail(){//建立失配指针,特别妙。
queue<int>q;
FL(i, 0, 25)if(AC[0].son[i])q.push(AC[0].son[i]);
while(!q.empty()){
int u = q.front(), fail = AC[u].fail, *z;
q.pop();
FL(i, 0, 25)
if(!(z = &AC[u].son[i], *z))*z = AC[fail].son[i];
else AC[*z].fail = AC[fail].son[i], q.push(*z), in[AC[*z].fail]++;
}
}
void query(char*a){//遍历 trie
int u = 0;
FL(i, 1, strlen(a + 1))
u = AC[u].son[a[i] - 'a'], AC[u].ans++;
}
void bfs(){//拓扑
queue<int>q;
FL(i, 1, tot)if(!in[i])q.push(i);
while(!q.empty()){
int u = q.front(), v;
q.pop(), ans[AC[u].end] = AC[u].ans;
AC[v = AC[u].fail].ans += AC[u].ans;
if(!--in[v])q.push(v);
}
}
int32_t main(){
cin.tie(0)->sync_with_stdio(0);
int n;
cin >> n;
FL(i, 1, n)cin >> a + 1, insert(a, i);
get_fail(), cin >> a + 1, query(a), bfs();
FL(i, 1, n)cout << ans[Map[i]] << endl;
return 0;
}
例 1
2025CSP-S 谐音替换
将一个变换 \(S_1 \rightarrow S_2\) 表示为 \(ABC\rightarrow ADC\)。
其中 \(A,C\) 是最长公共前后缀,那么将其转换为 \(A?BD?C\),其中 \(?\) 为特殊字符。
\(T_1 \rightarrow T_2\) 同理, 直接用 AC 自动机解决。
由于只要知道匹配到的总数,将 fail 树每个节点都记录其到根路径上有多少个模式串的结尾即可。
代码
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define FL(a,b,c) for(int a=(b),a##end=(c);a<=a##end;++a)
#define FR(a,b,c) for(int a=(b),a##end=(c);a>=a##end;--a)
#define lowbit(x) ((x)&-(x))
#define eb emplace_back
#define sz(x) (int)((x).size())
#define vt vector
#define fr first
#define se second
bool IOS=(cin.tie(0)->sync_with_stdio(0),0);
// #define LOCAL
#ifdef LOCAL
bool IOS1=(freopen(LOCAL".in", "r", stdin),1);
bool IOS2=(freopen(LOCAL".out", "w", stdout),1);
#endif
#define mmt(x, y) memset(x, y, sizeof x)
#define PII pair<int, int>
#define max(a, b)({auto f7r=(a);auto j3h=(b);f7r<j3h?j3h:f7r;})
#define cmax(a, b)({auto j3h=(b);(j3h>a)&&(a=j3h);})
#define min(a, b)({auto f7r=(a);auto j3h=(b);f7r>j3h?j3h:f7r;})
#define cmin(a, b)({auto j3h=(b);(j3h<a)&&(a=j3h);})
constexpr int N = 1e6 + 10, M = 5e6 + 10;
char a[M], b[M];
int len, cnt[M], ch[M][27], tot, g[M], fail[M];
inline void insert(){
int p = 0;
FL(i, 1, len)
(!ch[p][g[i]]) && (ch[p][g[i]] = ++tot), p = ch[p][g[i]];
cnt[p]++;
}
inline void build(){
queue<int>q;
FL(i, 0, 26)if(ch[0][i])q.emplace(ch[0][i]);
while(!q.empty()){
int x = q.front(), z;q.pop();
FL(i, 0, 26)
if(!(z = ch[x][i]))ch[x][i] = ch[fail[x]][i];
else fail[z] = ch[fail[x]][i], q.emplace(z), cnt[z] += cnt[fail[z]];//*****
}
}
int query(){
int p = 0, ans = 0;
FL(i, 1, len)p = ch[p][g[i]], ans += cnt[p];
return ans;
}
int32_t main(){
int n, q;
cin >> n >> q;
FL(i, 1, n){
cin >> a + 1 >> b + 1, len = 0;
int n = strlen(a + 1), l = n + 1, r;
FL(i, 1, n)if(a[i] != b[i])r = i, cmin(l, i);
if(len = 0, l > n)continue;
FL(i, 1, l - 1)g[++len] = a[i] - 'a';g[++len] = 26;
FL(i, l, r)g[++len] = a[i] - 'a', g[++len] = b[i] - 'a';g[++len] = 26;
FL(i, r + 1, n)g[++len] = a[i] - 'a';insert();
}
build();
while(q--){
cin >> a + 1 >> b + 1, len = 0;
if(strlen(a + 1) != strlen(b + 1)){cout << 0 << endl;continue;}
int n = strlen(a + 1), l = n + 1, r;
FL(i, 1, n)if(a[i] != b[i])r = i, cmin(l, i);
FL(i, 1, l - 1)g[++len] = a[i] - 'a';g[++len] = 26;
FL(i, l, r)g[++len] = a[i] - 'a', g[++len] = b[i] - 'a';g[++len] = 26;
FL(i, r + 1, n)g[++len] = a[i] - 'a';
cout << query() << endl;
}
return 0;
}
例 2
阿狸的打字机。
求字符串 \(x\) 在 \(y\) 中的子串数量,相当于在 fail 树上将遍历 \(y\) 经过的节点标记,求 \(x\) 末尾的子树内有多少个标记点。
将询问离线,然后在 trie 上做 dfs 即可。
代码
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define FL(a,b,c) for(int a=(b),a##end=(c);a<=a##end;++a)
#define FR(a,b,c) for(int a=(b),a##end=(c);a>=a##end;--a)
#define lowbit(x) ((x)&-(x))
#define eb emplace_back
#define sz(x) (int)((x).size())
#define int long long
#define vt vector
#define fr first
#define se second
bool IOS=(cin.tie(0)->sync_with_stdio(0),0);
// #define LOCAL
#ifdef LOCAL
bool IOS1=(freopen(LOCAL".in", "r", stdin),1);
bool IOS2=(freopen(LOCAL".out", "w", stdout),1);
#endif
#define mmt(x, y) memset(x, y, sizeof x)
#define PII pair<int, int>
#define max(a, b)({auto f7r=(a);auto j3h=(b);f7r<j3h?j3h:f7r;})
#define cmax(a, b)({auto j3h=(b);(j3h>a)&&(a=j3h);})
#define min(a, b)({auto f7r=(a);auto j3h=(b);f7r>j3h?j3h:f7r;})
#define cmin(a, b)({auto j3h=(b);(j3h<a)&&(a=j3h);})
constexpr int N = 1e6 + 10;
char a[N];
int ch[N][26], tot, fail[N], fa[N], T[N], cnt;
int m, ans[N], w[N], L[N], R[N], n;
vt<int>e[N];vt<PII>Q[N];
void insert(){
n = strlen(a + 1);int p = 0;
FL(i, 1, n){
if(a[i] >= 'a'){
if(!ch[p][a[i] - 'a'])ch[p][a[i] - 'a'] = ++tot, fa[tot] = p;
p = ch[p][a[i] - 'a'];
}
else if(a[i] == 'B')p = fa[p];
else T[++cnt] = p;
}
}
void dfs(int x){L[x] = ++tot;for(auto v : e[x])dfs(v);R[x] = tot;}
void build(){
queue<int>q;FL(i, 0, 25)if(ch[0][i])q.emplace(ch[0][i]), e[0].eb(ch[0][i]);
while(!q.empty()){
int u = q.front(), z;q.pop();
FL(i, 0, 25)
if(z = ch[fail[u]][i], !ch[u][i])ch[u][i] = z;
else fail[ch[u][i]] = z, q.emplace(ch[u][i]), e[z].eb(ch[u][i]);
}
tot = 0, dfs(0);
}
void add(int x, int v){while(x <= tot)w[x] += v, x += lowbit(x);}
int query(int x, int s = 0){while(x)s += w[x], x -= lowbit(x);return s;}
void get(){
int p = 0;
FL(i, 1, n){
if(a[i] >= 'a'){
p = ch[p][a[i] - 'a'], add(L[p], 1);
if(!(T[p]++))for(auto z : Q[p])ans[z.se] = query(R[z.fr]) - query(L[z.fr] - 1);
}
else if(a[i] == 'B')add(L[p], -1), p = fa[p];
}
}
int32_t main(){
cin >> a + 1, insert(), build(), cin >> m;
int x, y;
FL(i, 1, m)cin >> x >> y, Q[T[y]].eb(T[x], i);
mmt(T, 0), get();
FL(i, 1, m)cout << ans[i] << endl;
return 0;
}

浙公网安备 33010602011771号