【BZOJ2434】【NOI2011】阿狸的打字机(AC自动机,树状数组)

【BZOJ2434】阿狸的打字机(AC自动机,树状数组)

先写个暴力:
每次打印出字符串后,就插入到\(Trie\)树中
搞完后直接搭\(AC\)自动机
看一看匹配是怎么样的:
每次沿着\(AC\)自动机走,在每一个节点都跳\(fail\)指针
如果有\(x\)串的末节点,就给答案\(+1\)
这样的话没有必要存下每个串
只要给\(AC\)自动机存一个父亲节点
记录一下每个串的结束位置
倒着往上跳就可以了
这样能够拿到\(40\)

Update2018.1.25:这份代码对于重复串的处理会有问题,感谢 @超级大蒟蒻

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define MAX 200000
inline int read()
{
	int x=0,t=1;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=-1,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return x*t;
}
char ss[MAX];
int nd[MAX],n,tot;
struct Node
{
	int vis[26];
	int fail,fa;
	int lt;
}t[MAX];
void GetFail()
{
	queue<int> Q;
	for(int i=0;i<26;++i)
		if(t[0].vis[i])Q.push(t[0].vis[i]);
	while(!Q.empty())
	{
		int u=Q.front();Q.pop();
		for(int i=0;i<26;++i)
			if(t[u].vis[i])
				t[t[u].vis[i]].fail=t[t[u].fail].vis[i],Q.push(t[u].vis[i]);
			else t[u].vis[i]=t[t[u].fail].vis[i];
	}
}
int Query(int x,int y)
{
	int ret=0;
	int now=nd[y];
	while(now)
	{
		for(int i=now;i;i=t[i].fail)
			if(t[i].lt==x){++ret;break;}
		now=t[now].fa;
	}
	return ret;
}
int main()
{
	scanf("%s",ss+1);
	int now=0;
	for(int i=1,l=strlen(ss+1);i<=l;++i)
	{
		if(ss[i]>='a'&&ss[i]<='z')
		{
			if(!t[now].vis[ss[i]-'a'])t[now].vis[ss[i]-'a']=++tot,t[tot].fa=now;
			now=t[now].vis[ss[i]-'a'];
		}
		if(ss[i]=='B')now=t[now].fa;
		if(ss[i]=='P'){nd[++n]=now;t[now].lt=n;}
	}
	int Q=read();
	GetFail();
	while(Q--)
	{
		int x=read(),y=read();
		printf("%d\n",Query(x,y));
	}
	return 0;
}


这样子对于每一个询问都会要暴跳
如果对于某个串有重复的多次询问
那么就会多很多次没有任何意义的计算
所以,可以离线把所有询问都按照\(y\)排序
每次跳的时候开个桶一起计算
这样的话可以拿到\(70\)

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define MAX 200000
inline int read()
{
	int x=0,t=1;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=-1,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return x*t;
}
char ss[MAX];
int nd[MAX],n,tot;
int ans[MAX];
struct Node
{
	int vis[26];
	int fail,fa;
	int lt;
}t[MAX];
struct Question{int x,y,id,ans;}q[MAX];
bool operator<(Question a,Question b){return a.y<b.y;}
int sum[MAX];
void GetFail()
{
	queue<int> Q;
	for(int i=0;i<26;++i)
		if(t[0].vis[i])Q.push(t[0].vis[i]);
	while(!Q.empty())
	{
		int u=Q.front();Q.pop();
		for(int i=0;i<26;++i)
			if(t[u].vis[i])
				t[t[u].vis[i]].fail=t[t[u].fail].vis[i],Q.push(t[u].vis[i]);
			else t[u].vis[i]=t[t[u].fail].vis[i];
	}
}
int Query(int y)
{
	int ret=0;
	int now=nd[y];
	while(now)
	{
		for(int i=now;i;i=t[i].fail)
			if(t[i].lt)sum[t[i].lt]++;
		now=t[now].fa;
	}
	return ret;
}
int main()
{
	scanf("%s",ss+1);
	int now=0;
	for(int i=1,l=strlen(ss+1);i<=l;++i)
	{
		if(ss[i]>='a'&&ss[i]<='z')
		{
			if(!t[now].vis[ss[i]-'a'])t[now].vis[ss[i]-'a']=++tot,t[tot].fa=now;
			now=t[now].vis[ss[i]-'a'];
		}
		if(ss[i]=='B')now=t[now].fa;
		if(ss[i]=='P'){nd[++n]=now;t[now].lt=n;}
	}
	int Q=read();
	GetFail();
	for(int i=1;i<=Q;++i)
	{
		q[i].x=read(),q[i].y=read();
		q[i].id=i;
	}
	sort(&q[1],&q[Q+1]);
	for(int i=1,pos=1;i<=Q;i=pos)
	{
		Query(q[i].y);
		while(q[pos].y==q[i].y)q[pos].ans=sum[q[pos].x],pos++;
		memset(sum,0,sizeof(sum));
	}
	for(int i=1;i<=Q;++i)ans[q[i].id]=q[i].ans;
	for(int i=1;i<=Q;++i)
		printf("%d\n",ans[i]);
	return 0;
}


再来想想我们每次在干什么??
\(fail\)
显然每个节点有且仅有一个\(fail\)指针
所以,这就是一棵树??
把这个\(fail\)反过来看
现在的问题是什么?
原来是\(y\)的某个节点往上跳能不能到达\(x\)
现在反过来:
\(x\)往下跳能够到达几个\(y\)的节点
那,不就是求子树和???
如果把所有\(y\)的节点全部打上一个\(1\)的标记
那么,每次就变成了求\(x\)末节点的子树和
而一个点的子树在\(dfs\)序上一定是连续的一段
这样还是可以拿到\(70\)

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define MAX 200000
inline int read()
{
	int x=0,t=1;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=-1,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return x*t;
}
char ss[MAX];
int nd[MAX],n,tot;
int ans[MAX];
int c[MAX];
int dfn[MAX],low[MAX],tim;
inline int lowbit(int x){return x&(-x);}
void Modify(int x,int w){while(x<=tim)c[x]+=w,x+=lowbit(x);}
int getsum(int x){int ret=0;while(x)ret+=c[x],x-=lowbit(x);return ret;}
struct Node
{
	int vis[26];
	int fail,fa;
	int lt;
}t[MAX];
struct Question{int x,y,id,ans;}q[MAX];
bool operator<(Question a,Question b){return a.y<b.y;}
void GetFail()
{
	queue<int> Q;
	for(int i=0;i<26;++i)
		if(t[0].vis[i])Q.push(t[0].vis[i]);
	while(!Q.empty())
	{
		int u=Q.front();Q.pop();
		for(int i=0;i<26;++i)
			if(t[u].vis[i])
				t[t[u].vis[i]].fail=t[t[u].fail].vis[i],Q.push(t[u].vis[i]);
			else t[u].vis[i]=t[t[u].fail].vis[i];
	}
}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
void dfs(int u)
{
	dfn[u]=++tim;
	for(int i=h[u];i;i=e[i].next)dfs(e[i].v);
	low[u]=tim;
}
int main()
{
	scanf("%s",ss+1);
	int now=0;
	for(int i=1,l=strlen(ss+1);i<=l;++i)
	{
		if(ss[i]>='a'&&ss[i]<='z')
		{
			if(!t[now].vis[ss[i]-'a'])t[now].vis[ss[i]-'a']=++tot,t[tot].fa=now;
			now=t[now].vis[ss[i]-'a'];
		}
		if(ss[i]=='B')now=t[now].fa;
		if(ss[i]=='P'){nd[++n]=now;t[now].lt=n;}
	}
	int Q=read();
	GetFail();
	for(int i=1;i<=tot;++i)Add(t[i].fail,i);
	dfs(0);
	for(int i=1;i<=Q;++i)
	{
		q[i].x=read(),q[i].y=read();
		q[i].id=i;
	}
	sort(&q[1],&q[Q+1]);
	for(int i=1,pos=1;i<=Q;i=pos)
	{
		for(int now=nd[q[i].y];now;now=t[now].fa)
			Modify(dfn[now],1);
		while(q[pos].y==q[i].y)
		{
			int v=nd[q[pos].x];
			q[pos].ans=getsum(low[v])-getsum(dfn[v]-1);
			pos++;
		}
		memset(c,0,sizeof(c));
	}
	for(int i=1;i<=Q;++i)ans[q[i].id]=q[i].ans;
	for(int i=1;i<=Q;++i)
		printf("%d\n",ans[i]);
	return 0;
}


现在大致的方向已经没有问题了
看看我们重复算在哪里?
每次把串插入进树状数组!
因为很多的串会有重复
所以会反反复复把很多东西给重复插进去
这样就很慢了

于是,我们把\(Trie\)\(dfs\)遍历一遍
我搞Fail指针的时候会把原来的Trie数给搞掉,还要备份。。
访问到的时候打一个\(+1\)
结束的时候打一个\(-1\)
每次访问到一个结束节点的时候,
一定是有且仅有这个串的节点被打了标记
这样就可以直接回答这个串的相关询问了

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define MAX 200000
inline int read()
{
	int x=0,t=1;char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
	if(ch=='-')t=-1,ch=getchar();
	while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
	return x*t;
}
char ss[MAX];
int nd[MAX],n,tot;
int ans[MAX];
int c[MAX];
int dfn[MAX],low[MAX],tim;
int ql[MAX],qr[MAX];
inline int lowbit(int x){return x&(-x);}
void Modify(int x,int w){while(x<=tim)c[x]+=w,x+=lowbit(x);}
int getsum(int x){int ret=0;while(x)ret+=c[x],x-=lowbit(x);return ret;}
struct Node
{
	int vis[26];
	int Vis[26];
	int fail,fa;
	int lt;
}t[MAX];
struct Question{int x,y,id,ans;}q[MAX];
bool operator<(Question a,Question b){return a.y<b.y;}
void GetFail()
{
	queue<int> Q;
	for(int i=0;i<26;++i)
		if(t[0].vis[i])Q.push(t[0].vis[i]);
	while(!Q.empty())
	{
		int u=Q.front();Q.pop();
		for(int i=0;i<26;++i)
			if(t[u].vis[i])
				t[t[u].vis[i]].fail=t[t[u].fail].vis[i],Q.push(t[u].vis[i]);
			else t[u].vis[i]=t[t[u].fail].vis[i];
	}
}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
void dfs(int u)
{
	dfn[u]=++tim;
	for(int i=h[u];i;i=e[i].next)dfs(e[i].v);
	low[u]=tim;
}
void DFS(int u)
{
	Modify(dfn[u],1);
	if(t[u].lt)
		for(int i=ql[t[u].lt];i<=qr[t[u].lt];++i)
			q[i].ans=getsum(low[nd[q[i].x]])-getsum(dfn[nd[q[i].x]]-1);
	for(int i=0;i<26;++i)
		if(t[u].Vis[i])
			DFS(t[u].Vis[i]);
	Modify(dfn[u],-1);
}
int main()
{
	scanf("%s",ss+1);
	int now=0;
	for(int i=1,l=strlen(ss+1);i<=l;++i)
	{
		if(ss[i]>='a'&&ss[i]<='z')
		{
			if(!t[now].vis[ss[i]-'a'])t[now].vis[ss[i]-'a']=++tot,t[tot].fa=now;
			now=t[now].vis[ss[i]-'a'];
		}
		if(ss[i]=='B')now=t[now].fa;
		if(ss[i]=='P'){nd[++n]=now;t[now].lt=n;}
	}
	for(int i=0;i<=tot;++i)
		for(int j=0;j<26;++j)
			t[i].Vis[j]=t[i].vis[j];
	int Q=read();
	GetFail();
	for(int i=1;i<=tot;++i)Add(t[i].fail,i);
	dfs(0);
	for(int i=1;i<=Q;++i)
	{
		q[i].x=read(),q[i].y=read();
		q[i].id=i;
	}
	sort(&q[1],&q[Q+1]);
	for(int i=1,pos=1;i<=Q;i=pos)
	{
		ql[q[i].y]=i;
		while(q[pos].y==q[i].y)pos++;
		qr[q[i].y]=pos-1;
	}
	DFS(0);
	for(int i=1;i<=Q;++i)ans[q[i].id]=q[i].ans;
	for(int i=1;i<=Q;++i)
		printf("%d\n",ans[i]);
	return 0;
}

posted @ 2018-01-20 15:49  小蒟蒻yyb  阅读(384)  评论(8编辑  收藏  举报