【BZOJ4861】[Beijing2017]魔法咒语 矩阵乘法+AC自动机+DP

【BZOJ4861】[Beijing2017]魔法咒语

题意:别看BZ的题面了,去看LOJ的题面吧~

题解:显然,数据范围明显的分成了两部分:一个是L很小,每个基本词汇长度未知;一个是L很大,每个基本词汇的长度是1或2。看来只能写两份代码了。

对于L很小的,我们先将禁忌串建成一个AC自动机,然后预处理出to[i][j]表示AC自动机中的第i个节点在加入基本词汇j后会到达的节点。然后设f[i][j]表示总长度为i,匹配到第j个节点的方案数。然后DP一下就好了。

对于L很大的,我们想到矩乘,设ans[i][j]表示总长度为i,匹配到第j个节点的方案数。但是ans[i]这个矩阵由ans[i-1]和ans[i-2]两个矩阵转移过来,所以我们直接用分块矩阵的乘法,即:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
typedef long long ll;
const ll mod=1000000007;
int n,m,N,M,L,mx,sum;
struct mat
{
	ll v[210][210];
	mat (){memset(v,0,sizeof(v));}
	ll* operator [](int a){return v[a];}
	mat operator * (mat a)
	{
		mat ret;
		int i,j,k;
		for(i=1;i<=2*M;i++)	for(j=1;j<=2*M;j++)	for(k=1;k<=2*M;k++)	(ret[i][j]+=v[i][k]*a[k][j])%=mod;
		return ret;
	}
}ans,x;
int l1[60],to[110][60];
ll f[110][110];
queue<int> q;
struct node
{
	int ch[26],fail,cnt;
}p[110];
char s1[60][110],s2[60][110];
void build()
{
	q.push(1);
	int i,j,k,a,u;
	while(!q.empty())
	{
		u=q.front(),q.pop();
		for(i=0;i<26;i++)
		{
			if(!p[u].ch[i])
			{
				if(u==1)	p[u].ch[i]=1;
				else	p[u].ch[i]=p[p[u].fail].ch[i];
				continue;
			}
			q.push(p[u].ch[i]);
			if(u==1)
			{
				p[p[u].ch[i]].fail=1;
				continue;
			}
			p[p[u].ch[i]].fail=p[p[u].fail].ch[i];
			p[p[u].ch[i]].cnt|=p[p[p[u].fail].ch[i]].cnt;
		}
	}
	for(i=1;i<=M;i++)	for(j=1;j<=n;j++)
	{
		u=i,a=strlen(s1[j]);
		if(p[u].cnt)	to[i][j]=-1;
		for(k=0;k<a;k++)
		{
			u=p[u].ch[s1[j][k]-'a'];
			if(p[u].cnt)	break;
		}
		if(k==a)	to[i][j]=u;
		else	to[i][j]=-1;
	}
}
void DP()
{
	int i,j,k,a;
	f[0][1]=1;
	for(i=0;i<L;i++)	for(j=1;j<=M;j++)	for(k=1;k<=n;k++)
	{
		if(to[j][k]==-1)	continue;
		a=strlen(s1[k]);
		if(a+i<=L)	(f[a+i][to[j][k]]+=f[i][j])%=mod;
	}
	for(i=1;i<=M;i++)	sum=(sum+f[L][i])%mod;
	printf("%d",sum);	
}
void pm(int y)
{
	while(y)
	{
		if(y&1)	ans=ans*x;
		x=x*x,y>>=1;
	}
}
void MM()
{
	int i,j;
	for(i=1;i<=M;i++)
	{
		for(j=1;j<=n;j++)
		{
			if(to[i][j]==-1)	continue;
			if(strlen(s1[j])==1)	x[i][to[i][j]]++;
			else	x[i+M][to[i][j]]++;
		}
		x[i][i+M]++;
	}
	ans[1][1]=1;
	pm(L);
	for(i=1;i<=M;i++)	sum=(sum+ans[1][i])%mod;
	printf("%d",sum);
}
int main()
{
	scanf("%d%d%d",&n,&m,&L);
	int i,j,a,b,u;
	N=1,M=1;
	for(i=1;i<=n;i++)	scanf("%s",s1[i]),a=strlen(s1[i]),mx=max(mx,a);
	for(i=1;i<=m;i++)
	{
		scanf("%s",s2[i]),a=strlen(s2[i]);
		for(u=1,j=0;j<a;j++)
		{
			b=s2[i][j]-'a';
			if(!p[u].ch[b])	p[u].ch[b]=++M;
			u=p[u].ch[b];
		}
		p[u].cnt=1;
	}
	build();
	if(mx<=2)	MM();
	else	DP();
	return 0;
}
posted @ 2017-07-11 14:55  CQzhangyu  阅读(368)  评论(0编辑  收藏  举报