BZOJ 3530 [Sdoi2014]数数

题解:建立AC自动机,然后Dp

考虑长度与n相等时

f[i][j][2]表示第i位匹配到AC自动机第j号节点,是否顶着上界的方案数

转移枚举这一位填什么

注意,如果当前节点沿Fail树能走到单词节点就不能转移到他

长度<lenn不用考虑顶上界

问题:不明白最后统计答案的方式

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int maxn=2009;
const int mm=1000000007;

int n,m;
char s[maxn];
int a[maxn];

int nn;
int ch[maxn][10];
int val[maxn];
void Ins() {
	int n=strlen(s);
	int u=0;
	for(int i=0; i<n; ++i) {
		int c=s[i]-'0';
		if(!ch[u][c])ch[u][c]=++nn;
		u=ch[u][c];
	}
	val[u]=1;
}

queue<int>q;
int fai[maxn];
void Getfail(){
	for(int c=0;c<=9;++c){
		int v=ch[0][c];
		if(v)q.push(v);
	}
	while(!q.empty()){
		int u=q.front();q.pop();
		for(int c=0;c<=9;++c){
			int v=ch[u][c];
			if(!v)continue;
			q.push(v);
			val[v]|=val[u];
			int j=fai[u];
			while((j)&&(!ch[j][c]))j=fai[j];
			fai[v]=ch[j][c];
			val[v]|=val[fai[v]];
		}
	}
}

long long f[maxn][maxn][2];
long long ans;

int main() {
	scanf("%s",s);
	n=strlen(s);
	for(int i=1; i<=n; ++i)a[i]=s[i-1]-'0';
	scanf("%d",&m);
	while(m--) {
		scanf("%s",s);
		Ins();
	}
	Getfail();
	
	memset(f,0,sizeof(f));
	for(int c=1;c<=9;++c){
		int v=ch[0][c];
		if(val[v])continue;
		if(c<a[1])f[1][v][0]++;
		if(c==a[1])f[1][v][1]++;
	}
	for(int i=1;i<n;++i){
		for(int j=0;j<=nn;++j){
			for(int c=0;c<=9;++c){
				int u=j;
				while((u)&&(!ch[u][c]))u=fai[u];
				int v=ch[u][c];
				if(val[v])continue;
				if(c<a[i+1]){
					f[i+1][v][0]=(f[i+1][v][0]+f[i][j][0]+f[i][j][1])%mm;
				}
				if(c==a[i+1]){
					f[i+1][v][0]=(f[i+1][v][0]+f[i][j][0])%mm;
					f[i+1][v][1]=(f[i+1][v][1]+f[i][j][1])%mm;
				}
				if(c>a[i+1]){
					f[i+1][v][0]=(f[i+1][v][0]+f[i][j][0])%mm;
				}
			}
		}
	}
	for(int j=0;j<=nn;++j){
		ans=(ans+f[n][j][0]+f[n][j][1])%mm;
	}
//	ans=(f[n][0][0]+f[n][0][1])%mm;
//	for(int c=0;c<=9;++c){
//		int v=ch[0][c];
//		if(v){
//			ans=(ans+f[n][v][0]+f[n][v][1])%mm;
//		}
//	}
//	cout<<ans<<endl;
	--n;
	memset(f,0,sizeof(f));
	for(int c=1;c<=9;++c){
		int v=ch[0][c];
		if(val[v])continue;
		f[1][v][0]++;
	}
	for(int i=1;i<n;++i){
		for(int j=0;j<=nn;++j){
			for(int c=0;c<=9;++c){
				int u=j;
				while((u)&&(!ch[u][c]))u=fai[u];
				int v=ch[u][c];
				if(val[v])continue;
				f[i+1][v][0]=(f[i+1][v][0]+f[i][j][0])%mm;
			}
		}
//		ans=(ans+f[i+1][0][0])%mm;
//		for(int c=0;c<=9;++c){
//			int v=ch[0][c];
//			if(!v)continue;
//			ans=(ans+f[i+1][v][0])%mm;
//		}
	}
	for(int i=1;i<=n;++i){
		for(int j=0;j<=nn;++j){
			ans=(ans+f[i][j][0])%mm;
		}
	}
	cout<<ans<<endl;
	return 0;
}

  

posted @ 2018-02-20 18:19  ws_zzy  阅读(137)  评论(0编辑  收藏  举报