AC自动机学习笔记
前置知识1:KMP
设 \(a\) 是长度为 \(n\) 的原串, \(b\) 是长度为 \(m\) 匹配串。
假设现在 \(a\) 为 ABCABCDDD \(b\) 为 ABCABB 。
匹配过程如下:
ABCABCDDD
ABCABB
前四位相同,但是第五位不同
ABCABCDDD
ABCABB
这样跳的规律如下(假设下一个要跳到k)
当a[1~k-1]=a[j-k~j-1]时,则可以跳
求 \(nxt\) 数组:
nxt[0]=nxt[1]=0;
for(register int i=2,j=0;i<=n;i++)
{
while(j&&b[i]!=b[j+1])j=nxt[j];
if(b[i]==b[j+1])j++;
nxt[i]=j;
}
前置知识2:Trie
其实好像比 KMP 还要简单一些。
直接弄一个 \(t\) 数组, \(t_{i,j}\) 表示编号为 \(i\) 的点的字母为 \(j\) 出边的编号。
第一个操作是插入操作。
直接沿着每一个字母往下走,遇到没有的点新建一个即可。
inline void insert(string a)
{
int p=0;
for(register int i=0;i<a.size();i++)
{
int c=a[i]-'a';
if(!t[p][c])t[p][c]=++cnt;
p=t[p][c];
}
v[p]=1;
}
现在是查询操作。
同样的向下走,如果当前的字母根本没有边或者走到的最后一个点不是某一个串的末尾,说明没有找到,反之则找到。
inline bool find(string a)
{
int p=0;
for(register int i=0;i<a.size();i++)
{
int c=a[i]-'a';
if(!t[p][c])return false;
p=t[p][c];
}
return v[p];
}
下面是一个具体的实现例子(P2580)。
#include<bits/stdc++.h>
#define WRONG 0
#define OK 1
#define REPEAT 2
using namespace std;
const int MAXN=4e5+5;
int n,m;
int t[MAXN][26],cnt;
bitset<MAXN>v;
bitset<MAXN>used;
inline void insert(string a)
{
int p=0;
for(register int i=0;i<a.size();i++)
{
int c=a[i]-'a';
if(!t[p][c])t[p][c]=++cnt;
p=t[p][c];
}
v[p]=1;
}
inline int find(string a)
{
int p=0;
for(register int i=0;i<a.size();i++)
{
int c=a[i]-'a';
if(!t[p][c])return WRONG;
p=t[p][c];
}
if(v[p])
{
if(used[p])return REPEAT;
else
{
used[p]=1;
return OK;
}
}
return WRONG;
}
int main()
{
scanf("%d",&n);
for(register int i=1;i<=n;i++)
{
string op;
cin>>op;
insert(op);
}
scanf("%d",&m);
for(register int i=1;i<=m;i++)
{
string op;
cin>>op;
int ans=find(op);
if(ans==WRONG)puts("WRONG");
if(ans==OK)puts("OK");
if(ans==REPEAT)puts("REPEAT");
}
return 0;
}
现在将两个结合起来,就是AC自动机。
AC自动机
现在 \(nxt\) 是每一个后缀和每一个前缀的最大长度。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=500005;
int T,n;
int t[MAXN][26],cnt;
int num[MAXN];
int nxt[MAXN];
inline void insert(string a)
{
int p=0;
for(register int i=0;i<a.size();i++)
{
int c=a[i]-'a';
if(!t[p][c])t[p][c]=++cnt;
p=t[p][c];
}
num[p]++;
}
inline void build()
{
queue<int>q;
for(register int i=0;i<26;i++)
if(t[0][i])q.push(t[0][i]);
while(!q.empty())
{
int x=q.front();
q.pop();
for(register int i=0;i<26;i++)
{
int y=t[x][i];
if(!y)t[x][i]=t[nxt[x]][i];
else
{
nxt[y]=t[nxt[x]][i];
q.push(y);
}
}
}
}
int main()
{
scanf("%d",&T);
while(T--)
{
memset(t,0,sizeof t);
memset(num,0,sizeof num);
memset(nxt,0,sizeof nxt);
cnt=0;
scanf("%d",&n);
for(register int i=1;i<=n;i++)
{
string op;
cin>>op;
insert(op);
}
build();
string a;
cin>>a;
int ans=0;
for(register int i=0,j=0;i<a.size();i++)
{
int c=a[i]-'a';
j=t[j][c];
int p=j;
while(p)
{
ans+=num[p];
num[p]=0;
p=nxt[p];
}
}
printf("%d\n",ans);
}
return 0;
}

浙公网安备 33010602011771号