【BZOJ3881】[Coci2015]Divljak fail树+树链的并

【BZOJ3881】[Coci2015]Divljak

Description

Alice有n个字符串S_1,S_2...S_n,Bob有一个字符串集合T,一开始集合是空的。
接下来会发生q个操作,操作有两种形式:
“1 P”,Bob往自己的集合里添加了一个字符串P。
“2 x”,Alice询问Bob,集合T中有多少个字符串包含串S_x。(我们称串A包含串B,当且仅当B是A的子串)
Bob遇到了困难,需要你的帮助。

Input

第1行,一个数n;
接下来n行,每行一个字符串表示S_i;
下一行,一个数q;
接下来q行,每行一个操作,格式见题目描述。

Output

对于每一个Alice的询问,帮Bob输出答案。

Sample Input

3
a
bc
abc
5
1 abca
2 1
1 bca
2 2
2 3

Sample Output

1
2
1

HINT

【数据范围】
1 <= n,q <= 100000;
Alice和Bob拥有的字符串长度之和各自都不会超过 2000000;
字符串都由小写英文字母组成。

题解:我们先求出fail树,很容易发现当我们在fail树上遍历T_i字符串时,所经过的节点和它的祖先都会被T_i包含,我们只需要将这些点的权值全部+1,但我们并不能只是将每个点到根的路径上的点权全都+1,因为这样可能导致重复计算

所以,我们要解决的问题是如何将这些点到根节点的路径的并集+1,就是求树链的并,具体方法:

将所有经过的点按照DFS序排序,然后求出相邻两点间的LCA

将所有点到根的路径上的点权+1,再讲所有LCA到根的路径上的点权-1

这个可以用树剖+树状数组搞定

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using namespace std;
int n,m,tot,cnt;
const int maxn=2000010;
struct node
{
	int ch[30],fail;
}p[maxn];
int P[maxn],Q[maxn],s[maxn],son[maxn],dep[maxn],to[maxn],head[maxn],next[maxn],top[maxn],siz[maxn];
int v[maxn],vis[maxn],pos[maxn];
queue<int> q;
char str[maxn];
bool cmp(int a,int b)
{
	return P[a]<P[b];
}
void add(int a,int b)
{
	to[cnt]=b,next[cnt]=head[a],head[a]=cnt++;
}
void build()
{
	q.push(1);
	int i,t,u;
	while(!q.empty())
	{
		u=q.front(),q.pop();
		for(i=0;i<26;i++)
		{
			if(!p[u].ch[i])	continue;
			q.push(p[u].ch[i]);
			if(u==1)
			{
				p[p[u].ch[i]].fail=1;
				continue;
			}
			int t=p[u].fail;
			while(!p[t].ch[i]&&t)	t=p[t].fail;
			if(t)	p[p[u].ch[i]].fail=p[t].ch[i];
			else	p[p[u].ch[i]].fail=1;
		}
	}
}
void dfs1(int x)
{
	siz[x]=1;
	for(int i=head[x];i!=-1;i=next[i])
	{
		dep[to[i]]=dep[x]+1;
		dfs1(to[i]);
		siz[x]+=siz[to[i]];
		if(siz[to[i]]>siz[son[x]])	son[x]=to[i];
	}
}
void dfs2(int x,int tp)
{
	top[x]=tp,P[x]=++P[0];
	if(son[x])	dfs2(son[x],tp);
	for(int i=head[x];i!=-1;i=next[i])
		if(to[i]!=son[x])
			dfs2(to[i],to[i]);
	Q[x]=P[0];
}
int lca(int x,int y)
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])	swap(x,y);
		x=p[top[x]].fail;
	}
	if(dep[x]<dep[y])	return x;
	return y;
}
void updata(int x,int val)
{
	for(int i=x;i<=tot;i+=i&-i)	s[i]+=val;
}
int query(int x)
{
	int i,ret=0;
	for(i=x;i;i-=i&-i)	ret+=s[i];
	return ret;
}
int main()
{
	scanf("%d",&n);
	int i,j,a,b,c,u;
	tot=1;
	for(i=1;i<=n;i++)
	{
		scanf("%s",str);
		u=1,a=strlen(str);
		for(j=0;j<a;j++)
		{
			b=str[j]-'a';
			if(!p[u].ch[b])	p[u].ch[b]=++tot;
			u=p[u].ch[b];
		}
		pos[i]=u;
	}
	build();
	memset(head,-1,sizeof(head));
	for(i=2;i<=tot;i++)	add(p[i].fail,i);
	dep[1]=1,dfs1(1),dfs2(1,1);
	scanf("%d",&m);
	for(i=1;i<=m;i++)
	{
		scanf("%d",&c);
		if(c==1)
		{
			scanf("%s",str);
			u=1,a=strlen(str);
			vis[1]=i,v[v[0]=1]=1;
			for(j=0;j<a;j++)
			{
				b=str[j]-'a';
				while(!p[u].ch[b]&&u!=1)	u=p[u].fail;
				u=(p[u].ch[b]>0)?p[u].ch[b]:1;
				if(vis[u]!=i)	vis[u]=i,v[++v[0]]=u;
			}
			sort(v+1,v+v[0]+1,cmp);
			for(j=1;j<=v[0];j++)	updata(P[v[j]],1);
			for(j=1;j<v[0];j++)	updata(P[lca(v[j],v[j+1])],-1);
		}
		if(c==2)
		{
			scanf("%d",&a);
			printf("%d\n",query(Q[pos[a]])-query(P[pos[a]]-1));
		}
	}
	return 0;
}
posted @ 2017-04-24 12:45  CQzhangyu  阅读(417)  评论(0编辑  收藏  举报