【算法学习】AC自动机

P5357 【模板】AC 自动机

该题解极度敷衍,新手勿入。

正确的位置

字典树和 kmp 思想的结合。

先把模式串建成一个字符串,利于我们进行各种操作。

47101eb406112c372a3a6be82c5828c6

接下来就算 fail 指针了,fail 指针主要做的是在失配时智能回退,并不会推太多节点导致复杂度很大。

f5cd6313d2574635f8cc6397be70d920

构建时遵循:

  1. 逐层构建,且第二层节点的fail都指向根节点

  2. 这个节点的子结点的fail为该节点的fail指向的节点的同样字母的节点。

查询时要回跳,我们先找到了编号4这个点,编号4的fail连向编号7这个点,编号7的fail连向编号9这个点。那么我们要更新编号4这个点的值,同时也要更新编号7和编号9,这就是暴力跳fail的过程。

但是如果统计出现次数,每次查询多次跳复杂度很高,那么我们可不可以在找到的点打一个标记,最后再一次性将标记全部上传 来 更新其他点的ans。

拓扑排序!

我们使每一个点向它的fail指针连一条边,明显,每一个点的出度为1(fail只有一个),入度可能很多,所以我们就不需要像拓扑排序那样先建个图了,直接往fail指针跳就可以了。但入度数组in还是要存的。

最后我们根据fail指针建好图后(想象一下,程序里不用实现的),一定是一个DAG,具体原因不解释(很简单的),那么我们就直接在上面跑拓扑排序,然后更新ans就可以了。

#include<bits/stdc++.h>
#define ll long long
//#define int ll
#define ls t[p].l
#define rs t[p].r
#define re register 
#define pb push_back
#define pir pair<int,int> 
#define f(a,x,i) for(int i=a;i<=x;i++)
#define fr(a,x,i) for(int i=a;i>=x;i--)
#define lowbit(x) x&-x;
using namespace std;
const int N=2e5+5;
const int M=2e6+5;
const int mod=1e9+7;
//const int INF=1e17+7;
mt19937 rnd(251);

string str[N];

int sum[N];

struct tree{
	int fail;
	int node[28];
	vector<int> end;
}a[N];

int cnt;
int ans[N];
int in[N];

void build(string s,int id){
	int now=0;
	int l=s.size();
	for(int i=0;i<l;i++){
		if(a[now].node[s[i]-'a']==0){
			a[now].node[s[i]-'a']=++cnt;
		}
		now=a[now].node[s[i]-'a'];
	}
	a[now].end.push_back(id);
}

void get_fail(){
	queue<int> q;
	for(int i=0;i<26;i++){
		if(a[0].node[i]!=0){
			a[a[0].node[i]].fail=0;
			q.push(a[0].node[i]);
		}
	}
	while(!q.empty()){
		int x=q.front();
		q.pop();
		for(int i=0;i<26;i++){
			if(a[x].node[i]!=0){
				a[a[x].node[i]].fail=a[a[x].fail].node[i];//跳一次可以跳多次fail 
				q.push(a[x].node[i]);
			}
			else{
				a[x].node[i]=a[a[x].fail].node[i];//重复跳fail 
			}
		}	
	}
}

void query(string t){
	int n=t.size();
	int now=0;
	
	for(int i=0;i<n;i++){
		now=a[now].node[t[i]-'a'];
		ans[now]++;
	}
	
	for(int i=1;i<=cnt;i++){
		in[a[i].fail]++;
	}
	
	queue<int> q;
	for(int i=1;i<=cnt;i++){
		if(in[i]==0){
			q.push(i);
		}
	}
	
	while(!q.empty()){
		int x=q.front();
		q.pop();
		int y=a[x].fail;
		if(y!=0){
			ans[y]+=ans[x];
			if(--in[y]==0){
				q.push(y);
			}	
		}	
	}
	
	for(int i=0;i<=cnt;i++){
		for(int j=0;j<a[i].end.size();j++){
			sum[a[i].end[j]]+=ans[i];
		}
	}
}

void solve(){
	int n;
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>str[i];
		build(str[i],i);
	}
	
	a[0].fail=0;

	get_fail();
	
	string t;
	cin>>t;
	query(t);
	for(int i=1;i<=n;i++){
		cout<<sum[i]<<"\n";
	}
}
signed main(){
//    freopen("a.in","r",stdin);
//    freopen("a.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(nullptr);   
    int t=1;
//    cin>>t;
    while(t--){
    	solve();
	}
	    
    return 0;
}

P3966 [TJOI2013] 单词

每个单词用特殊词隔开如{,然后直接AC自动机统计次数就行了。

SP413 WPUZZLES - Word Puzzles

AC自动机。

暴力的话枚举每个位置再向8个方向查询字符串是否出现复杂度太高,考虑从方向上下手,例如正北方向:从最下面一行开始,向上走,形成一列字符串,用 AC 自动机匹配。

P3121 [USACO15FEB] Censoring G

用AC自动机匹配串,用上来回溯,维护两个栈,一个栈记录遍历到当前字符所在的节点编号,第二个栈维护需要输出的字符编号,这两个公用一个top,回溯时直接退top-长度就行。

P4052 [JSOI2007] 文本生成器

AC自动机上dp,看到至少一个模式串,正难则反,总数量-完全没有模式串的数量,设 \(f_{i,j}\) 为长度i,在AC自动机节点编号为j的没有模式串的数量,转移肯定根据没有以这个串及字串结尾,转移过来。

f[0][0]=1;
for(int i=1;i<=m;i++){
	for(int j=0;j<=cnt;j++){
		for(int k=0;k<26;k++){
			if(!a[a[j].node[k]].end){
				f[i][a[j].node[k]]=(f[i][a[j].node[k]]+f[i-1][j])%mod;
			}
		}
	}
}

当然ap为模式串,map也不能有,所以有:

a[a[x].node[i]].end|=a[a[a[x].node[i]].fail].end;

最后计算就行:

#include<bits/stdc++.h>
#define ll long long
//#define int ll
#define ls t[p].l
#define rs t[p].r
#define re register 
#define pb push_back
#define pir pair<int,int> 
#define f(a,x,i) for(int i=a;i<=x;i++)
#define fr(a,x,i) for(int i=a;i>=x;i--)
#define lowbit(x) x&-x;
using namespace std;
const int N=1e4+5;
const int M=2e6+5;
const int mod=1e4+7;
//const int INF=1e17+7;
mt19937 rnd(251);

struct ss{
	int fail;
	int node[28];
	int end;
}a[N];

string str[N];

int cnt=0;

void build(string s,int id){
	int now=0;
	int l=s.size();
	for(int i=0;i<l;i++){
		if(a[now].node[s[i]-'A']==0){
			a[now].node[s[i]-'A']=++cnt;
		}
		now=a[now].node[s[i]-'A'];
	}
	a[now].end=1;
}

void get_fail(){
	int now=0;
	
	queue<int> q;
	for(int i=0;i<26;i++){
		if(a[0].node[i]!=0){
			a[a[0].node[i]].fail=0;
			q.push(a[0].node[i]);
		}
	}
	
	while(!q.empty()){
		int x=q.front();
		q.pop();
		
		for(int i=0;i<26;i++){
			if(a[x].node[i]!=0){
				a[a[x].node[i]].fail=a[a[x].fail].node[i];
				a[a[x].node[i]].end|=a[a[a[x].node[i]].fail].end;
				q.push(a[x].node[i]);
			}
			else{
				a[x].node[i]=a[a[x].fail].node[i];
			}
		}	
	}
}


int n,m;
int f[105][N];



void query(){
	f[0][0]=1;
	for(int i=1;i<=m;i++){
		for(int j=0;j<=cnt;j++){
			for(int k=0;k<26;k++){
				if(!a[a[j].node[k]].end){
					f[i][a[j].node[k]]=(f[i][a[j].node[k]]+f[i-1][j])%mod;
				}
			}
		}
	}
	int ans=1;
	for(int i=1;i<=m;i++){
		ans*=26;
		ans%=mod;
	}
	for(int i=0;i<=cnt;i++){
		ans=(ans-f[m][i]+mod)%mod;
	}
	cout<<ans;
}

void solve(){
	cin>>n>>m;
	
	for(int i=1;i<=n;i++){
		cin>>str[i];
		build(str[i],i);
	}
	
	get_fail();
	
	query();	
}
signed main(){
//    freopen("a.in","r",stdin);
//    freopen("a.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(nullptr);   
    int t=1;
//    cin>>t;
    while(t--){
    	solve();
//    	cout<<"\n";
	}
	    
    return 0;
}

P3041 [USACO12JAN] Video Game G

上一题是不可转移,这题要转移,结尾时end++,以map结尾ap也算一个,所以有a[x].end+=a[a[x].fail].end;

然后转移就行,注意初始化,因为有些不可能的状态会导致他错误转移。

#include<bits/stdc++.h>
#define ll long long
//#define int ll
#define ls t[p].l
#define rs t[p].r
#define re register 
#define pb push_back
#define pir pair<int,int> 
#define f(a,x,i) for(int i=a;i<=x;i++)
#define fr(a,x,i) for(int i=a;i>=x;i--)
#define lowbit(x) x&-x;
using namespace std;
const int N=1e4+5;
const int M=2e6+5;
const int mod=1e4+7;
//const int INF=1e17+7;
mt19937 rnd(251);

struct ss{
	int fail;
	int node[5];
	int end;
}a[N];

string str[N];

int cnt=0;

void build(string s,int id){
	int now=0;
	int l=s.size();
	for(int i=0;i<l;i++){		
		if(a[now].node[s[i]-'A']==0){
			a[now].node[s[i]-'A']=++cnt;
		}
		now=a[now].node[s[i]-'A'];
	}
	a[now].end++;
}

void get_fail(){
	int now=0;
	
	queue<int> q;
	for(int i=0;i<3;i++){
		if(a[0].node[i]!=0){
			a[a[0].node[i]].fail=0;
			q.push(a[0].node[i]);
		}
	}
	
	while(!q.empty()){
		int x=q.front();
		q.pop();
		
		a[x].end+=a[a[x].fail].end; 
		
		for(int i=0;i<3;i++){
			if(a[x].node[i]!=0){
				a[a[x].node[i]].fail=a[a[x].fail].node[i];
				q.push(a[x].node[i]);
			}
			else{
				a[x].node[i]=a[a[x].fail].node[i];
			}
		}	
	}
}


int n,m;
int f[1005][1005];



void query(){
	//其他状态不可达!!! 
	memset(f,-0x3f,sizeof f);
	f[0][0]=0; 
	for(int i=1;i<=m;i++){
		for(int j=0;j<=cnt;j++){
			for(int k=0;k<3;k++){
				int y=a[j].node[k];
				f[i][y]=max(f[i][y],f[i-1][j]+a[y].end);
			}
		}
	}
	int ans=0; 
	for(int i=0;i<=cnt;i++){
		ans=max(ans,f[m][i]);
	}
	cout<<ans;
}

void solve(){
	cin>>n>>m;
	
	for(int i=1;i<=n;i++){
		cin>>str[i];
		build(str[i],i);
	}
	
	get_fail();
	
	query();	
}
signed main(){
//    freopen("a.in","r",stdin);
//    freopen("a.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(nullptr);   
    int t=1;
//    cin>>t;
    while(t--){
    	solve();
//    	cout<<"\n";
	}
    return 0;
}

P3311 [SDOI2014] 数数

类似于文本编辑器,有了限制不能超过该数字,所以在AC自动机上数位dp,我这会知道前导0是干嘛的了,他是要解决不定长的数位dp,所以加上,统计答案时只记录实际数字就行。

#include<bits/stdc++.h>
#define ll long long
//#define int ll
#define ls t[p].l
#define rs t[p].r
#define re register 
#define pb push_back
#define pir pair<int,int> 
#define f(a,x,i) for(int i=a;i<=x;i++)
#define fr(a,x,i) for(int i=a;i>=x;i--)
#define lowbit(x) x&-x;
using namespace std;
const int N=1e4+5;
const int M=2e6+5;
const int mod=1e9+7;
//const int INF=1e17+7;
mt19937 rnd(251);

string str[N];
int cnt;

struct ss{
	int fail;
	int node[27];
	int end;
	int op;
}a[N]; 

void build(string s){
	int now=0;
	for(int i=0;i<s.size();i++){
		int c=s[i]-'0';
		if(a[now].node[c]==0){
			a[now].node[c]=++cnt;
			a[a[now].node[c]].op=c;
		}
		now=a[now].node[c];
	}
	a[now].end=1;
}

void get_fail(){
	
	queue<int> q;
	for(int i=0;i<10;i++){
		if(a[0].node[i]!=0){
			a[a[0].node[i]].fail=0;
			q.push(a[0].node[i]);
		}
	}
	
	while(!q.empty()){
		int x=q.front();
		q.pop();
		
		for(int i=0;i<10;i++){
			if(a[x].node[i]){
				a[a[x].node[i]].fail=a[a[x].fail].node[i];
				q.push(a[x].node[i]); 
			}	
			else{
				a[x].node[i]=a[a[x].fail].node[i]; 
			}
		}
	}
	
	for(int i=0;i<=cnt;i++){
		a[i].end|=a[a[i].fail].end;
	}	
}

int f[1250][1250][3][3];
//第几位,这个数字,上界 
void query(string t){
	
	f[0][0][1][1]=1;
	
	for(int i=0;i<t.size();i++){//第几位 
		for(int p=0;p<=cnt;p++){
			for(int lead=0;lead<=1;lead++){
				for(int lim=0;lim<=1;lim++){
					int up=lim?(t[i]-'0'):9;
					for(int k=0;k<=up;k++){
						int lim_=lim&&(k==up);
						int lead_=lead&&(k==0);
						if(lead_){
							f[i+1][0][lim_][1]=(f[i+1][0][lim_][1]+f[i][p][lim][lead])%mod;
						}	
						else{
							int y=a[p].node[k];
							if(a[y].end) continue;
							f[i+1][y][lim_][0]=(f[i+1][y][lim_][0]+f[i][p][lim][lead])%mod;
						}
					}
				}	
			}
		}
	}
	
	int ans=0;
	for(int i=0;i<=cnt;i++){
		for(int lim=0;lim<=1;lim++){
			ans=(ans+f[t.size()][i][lim][0])%mod;
		}
	}
	cout<<ans;	
}

string t;

void solve(){
	a[0].op=10;
	cin>>t;
	int n;
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>str[i];
		build(str[i]);
	}
	
	get_fail();
	
	query(t);
}


signed main(){
//    freopen("a.in","r",stdin);
//    freopen("a.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(nullptr);   
    int t=1;
//    cin>>t;
    while(t--){
    	solve();
//    	cout<<"\n";
	}
    return 0;
}

P2444 [POI 2000] 病毒

说白了还是文本编辑器,只不过要一直没有模式串,那可以想到在不经过危险节点的情况下还能回到根节点就肯定可以无限长了,那又等于有环,直接dfs判环就行了。

get_fail就相当于建了完全图了。

#include<bits/stdc++.h>
#define ll long long
//#define int ll
#define ls t[p].l
#define rs t[p].r
#define re register 
#define pb push_back
#define pir pair<int,int> 
#define f(a,x,i) for(int i=a;i<=x;i++)
#define fr(a,x,i) for(int i=a;i>=x;i--)
#define lowbit(x) x&-x;
using namespace std;
const int N=1e6+5;
const int M=2e6+5;
const int mod=1e9+7;
//const int INF=1e17+7;
mt19937 rnd(251);

string str[N];
int cnt;

struct ss{
	int fail;
	int node[3];
	int end;
}a[N]; 

void build(string s){
	int now=0;
	for(int i=0;i<s.size();i++){
		int c=s[i]-'0';
		if(a[now].node[c]==0){
			a[now].node[c]=++cnt;
		}
		now=a[now].node[c];
	}
	a[now].end=1;
}

void get_fail(){
	
	queue<int> q;
	for(int i=0;i<=1;i++){
		if(a[0].node[i]!=0){
			a[a[0].node[i]].fail=0;
			q.push(a[0].node[i]);
		}
	}
	
	while(!q.empty()){
		int x=q.front();
		q.pop();
		a[x].end|=a[a[x].fail].end;
		for(int i=0;i<=1;i++){
			if(a[x].node[i]){
				a[a[x].node[i]].fail=a[a[x].fail].node[i];
				q.push(a[x].node[i]); 
			}	
			else{
				a[x].node[i]=a[a[x].fail].node[i]; 
			}
		}
	}

}

bool vis[N],f[N]; 

void query(int x){
	vis[x]=1;
	for(int i=0;i<=1;i++){
		int y=a[x].node[i];
		if(vis[y]){
			cout<<"TAK";
			exit(0);	
		}
		else if(!a[y].end&&!f[y]){
			f[y]=1;
			query(y);
		}
	}
	vis[x]=0;
}

string t;

void solve(){
	int n;
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>str[i];
		build(str[i]);
	}
	
	get_fail();
	
	query(0);
	
	cout<<"NIE";
}


signed main(){
//    freopen("a.in","r",stdin);
//    freopen("a.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(nullptr);   
    int t=1;
//    cin>>t;
    while(t--){
    	solve();
//    	cout<<"\n";
	}
	    
    return 0;
}

P2292 [HNOI2004] L 语言

思路好想,类似于上面用栈删除的那题,这个合法那么退这个模式串长度那个位置是否合法,优化。

  1. 发现模式串只有20,我们只会在当前位置退最多20步然后找那个位置的值是否合法,其他位置就没什么用了,用bitset存状态。

  2. i-ans<=25,不合法时退出,我们串最长只有25,差的太多了就不可能转移过来了。

#include<bits/stdc++.h>
#define ll long long
//#define int ll
#define ls t[p].l
#define rs t[p].r
#define re register 
#define pb push_back
#define pir pair<int,int> 
#define f(a,x,i) for(int i=a;i<=x;i++)
#define fr(a,x,i) for(int i=a;i>=x;i--)
#define lowbit(x) x&-x;
using namespace std;
const int N=1e4+10;
const int M=2e6+5;
const int mod=1e9+7;
//const int INF=1e17+7;
mt19937 rnd(251);

char str[30];
int cnt;

struct ss{
	int fail;
	int node[28];
	vector<short> end;
}a[N]; 

void build(char *s,int id){
	int now=0;
	for(int i=0;i<strlen(s);i++){
		int c=s[i]-'a';
		if(a[now].node[c]==0){
			a[now].node[c]=++cnt;
		}
		now=a[now].node[c];
	}
	a[now].end.push_back(strlen(s));
}

void get_fail(){
	queue<int> q;
	for(int i=0;i<26;i++){
		if(a[0].node[i]!=0){
			a[a[0].node[i]].fail=0;
			q.push(a[0].node[i]);
		}
	}
	
	while(!q.empty()){
		int x=q.front();
		q.pop();
		for(int i=0;i<a[a[x].fail].end.size();i++){
			int len=a[a[x].fail].end[i];
			a[x].end.push_back(len);
        }
		for(int i=0;i<26;i++){
			if(a[x].node[i]){
				a[a[x].node[i]].fail=a[a[x].fail].node[i];
				q.push(a[x].node[i]); 
			}	
			else{
				a[x].node[i]=a[a[x].fail].node[i]; 
			}
		}
	}
}

void query(string s){
	bitset<20> f;
	
	int ans=0,now=0;
	for(int i=0;i<s.size()&&i-ans<=25;i++){
//		cout<<f<<"\n";
		int c=s[i]-'a';
		now=a[now].node[c];
		if(!a[now].end.empty()){
			for(int j=0;j<a[now].end.size();j++){
				if(i-a[now].end[j]<0){
					f[0]=1;
				}
				else{
					f[0]=f[a[now].end[j]];
				}
				if(f[0]){
					ans=i+1;
					break;
				}
			}
		}
		f<<=1;
	}
	cout<<ans<<"\n";
}

void solve(){
	int n,m;
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>str;
		build(str,i);
	}
	
	get_fail();
	
	string t;
	for(int i=1;i<=m;i++){
		cin>>t;
		query(t);
	}
}


signed main(){
//   freopen("a.in","r",stdin);
//   freopen("a.out","w",stdout);
    ios::sync_with_stdio(0);
    cin.tie(nullptr);   
//    int t=1;
//    cin>>t;
//    while(t--){
	solve();
//    	cout<<"\n";
//	}
    return 0;
}
posted @ 2025-11-10 21:10  sad_lin  阅读(5)  评论(0)    收藏  举报