【洛谷&ybtoj】【AC自动机】文本生成器

传送门

洛谷P4052 [JSOI2007]文本生成器

题解

正难则反。求不可读文本更加好求。
按照 AC 自动机先设计一个套路 DP:\(dp(i,j)\) 表示生成长度为 \(i\) ,在 Trie 图上结点 \(j\) 的不可读文本数。
BFS 求 \(fail\) 数组的过程中记录哪些节点不能走(因为走了就会产生可读文本),转移式:\(vis(trie_{p,i})|=vis_p\) 。其实这也是 AC 自动机的一种套路记法。

为什么总是说套路?因为做了几道类似的题,是分别做上面的套路操作,而这道题结合到一起了。(我甚至可以自己做出来了=w=)

这样转移就比较显然了,下面是转移的代码:

int find(int x)
{
	if(!x) return x;
	if(vis[x]) return x;
	return fail[x]=find(fail[x]);	
}
inl int DP()
{
	dp[0][1]=dp[0][0]=1;
	for(int i=1;i<=m;i++)	
		for(int p=1;p<=tot;p++)
			for(int j=0;j<26;j++)
			{	
				if(find(trie[p][j])) continue;
				dp[i][trie[p][j]]+=dp[i-1][p];
				dp[i][trie[p][j]]%=mod;
			}
	int ans=0;
	for(int i=1;i<=tot;i++)
		ans=(ans+dp[m][i])%mod;
	return ans;
}

一个小错误

但是转移的时候有个判断写的不对劲(还是理解不够深入),就是那个判断find()的地方,原来写成了下面这样:

	if(vis[trie[p][j]]||vis[p]) continue;

对比一下发现问题很明显吧,只判断当前位置和上一个转移来的位置并不全面,还要考虑到之前的所有位置。(其实这个也是套路,不过我由于不熟练没想到)

关于 find 函数

find函数相当于不断跳 \(fail\) ,看过程中有没有不能走的点,甚至可以像并查集一样路径压缩,这种思路很值得学习。

完整代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define inl inline 
const int INF = 0x3f3f3f3f,N = 1e6+10,mod = 1e4+7;
inline ll read()
{
	ll ret=0;char ch=' ',c=getchar();
	while(!(c>='0'&&c<='9')) ch=c,c=getchar();
	while(c>='0'&&c<='9') ret=(ret<<1)+(ret<<3)+c-'0',c=getchar();
	return ch=='-'?-ret:ret;
}
int n,m;
struct AC
{
	int trie[N][27],fail[N],tot=1,dp[105][105*60];
	int cnt[N]; bool vis[N];
	queue<int> q;
	inl void insert(char s[])
	{
		int p=1,len=strlen(s+1);
		for(int i=1;i<=len;i++)
		{
			int ch=s[i]-'A';
			if(!trie[p][ch]) trie[p][ch]=++tot;
			p=trie[p][ch];
		}
		vis[p]=1;
	}
	inl void build()
	{
		for(int i=0;i<26;i++) trie[0][i]=1;
		q.push(1);
		while(!q.empty())
		{
			int p=q.front(); q.pop();
			for(int i=0;i<26;i++)
			{
				if(trie[p][i]) 
					fail[trie[p][i]]=trie[fail[p]][i],
					vis[trie[p][i]]|=vis[p],
					q.push(trie[p][i]);
				else trie[p][i]=trie[fail[p]][i];
			}
		}
	}
	inl int qpow(int x,int y)
	{
		int ret=1;
		while(y)
		{
			if(y&1) ret=ret*x%mod;
			x=x*x%mod;
			y>>=1;
		}
		return ret;
	}
	int find(int x)
	{
		if(!x) return x;
		if(vis[x]) return x;
		return fail[x]=find(fail[x]);	
	}
	inl int DP()
	{
		dp[0][1]=dp[0][0]=1;
		for(int i=1;i<=m;i++)	
			for(int p=1;p<=tot;p++)
				for(int j=0;j<26;j++)
				{	
					if(find(trie[p][j])) continue;
					dp[i][trie[p][j]]+=dp[i-1][p];
					dp[i][trie[p][j]]%=mod;
				}
		int ans=0;
		for(int i=1;i<=tot;i++)
			ans=(ans+dp[m][i])%mod;
		return ans;
	}
}ac;

char s[N];
int main()
{
	n=read(),m=read();
	for(int i=1;i<=n;i++) scanf("%s",s+1),ac.insert(s);
	ac.build();
	printf("%d\n",((ac.qpow(26,m)-ac.DP())%mod+mod)%mod);
	return 0;
}
posted @ 2021-10-09 00:11  conprour  阅读(135)  评论(0编辑  收藏  举报