bzoj 3879: SvT

题目

首先\(SAM\)上两个节点的\(lca\)表示的子串就是这两个节点表示的前缀的最长公共后缀

而我们想求后缀的\(lcp\)只需要把\(SAM\)反过来建就好了

而这道题一次要求很多后缀的\(lcp\)显然可以考虑一个树形\(dp\),就是考虑每个节点作为\(lca\)的贡献

这个非常简单,一边\(dfs\)一边求子树和统计答案就好了

因为一次查多个后缀,所以我们需要建出一棵虚树

代码

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define maxn 1000005
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
struct E{int v,nxt;}e[maxn];
int n,num,m,lst=1,cnt=1,top,__,t,a[maxn];
LL ans=0;
char S[maxn>>1];
int head[maxn],dfn[maxn],sum[maxn],Top[maxn],Son[maxn],deep[maxn];
int len[maxn],fa[maxn],son[maxn][26],pos[maxn],st[maxn],vis[maxn];
inline int cmp(int A,int B) {return dfn[A]<dfn[B];}
inline void add(int x,int y) {e[++num].v=y;e[num].nxt=head[x];head[x]=num;}
inline int read() {
	char c=getchar();int x=0;
	while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
void dfs(int x) {
	for(re int i=head[x];i;i=e[i].nxt)
	 	dfs(e[i].v),ans+=len[x]*vis[x]*vis[e[i].v],vis[x]+=vis[e[i].v];
}
void clear(int x) {
	vis[x]=0;for(re int i=head[x];i;i=e[i].nxt) clear(e[i].v);head[x]=0; 
}
void dfs1(int x) {
	sum[x]=1;int maxx=-1;
	for(re int i=head[x];i;i=e[i].nxt)
	{
		deep[e[i].v]=deep[x]+1;
		dfs1(e[i].v);sum[x]+=sum[e[i].v];
		if(sum[e[i].v]>maxx) maxx=sum[e[i].v],Son[x]=e[i].v;
	}
}
void dfs2(int x,int topf) {
	dfn[x]=++__;Top[x]=topf;
	if(!Son[x]) return;
	dfs2(Son[x],topf);
	for(re int i=head[x];i;i=e[i].nxt) if(!Top[e[i].v]) dfs2(e[i].v,e[i].v);
}
inline int LCA(int x,int y) {
	while(Top[x]!=Top[y]){if(deep[Top[x]]<deep[Top[y]]) std::swap(x,y);x=fa[Top[x]];}
	if(deep[x]<deep[y]) return x;return y;
}
inline void ins(int c,int o) {
	int p=++cnt,f=lst; lst=p;
	len[p]=len[f]+1,pos[o]=p;
	while(f&&!son[f][c]) son[f][c]=p,f=fa[f];
	if(!f) {fa[p]=1;return;}
	int x=son[f][c];
	if(len[f]+1==len[x]) {fa[p]=x;return;}
	int y=++cnt;
	len[y]=len[f]+1,fa[y]=fa[x],fa[x]=fa[p]=y;
	for(re int i=0;i<26;i++) son[y][i]=son[x][i];
	while(f&&son[f][c]==x) son[f][c]=y,f=fa[f];
}
inline void insert(int x) {
	if(top<=1) {st[++top]=x;return;}
	int lca=LCA(x,st[top]);
	if(lca==st[top]) {st[++top]=x;return;}
	while(top>1&&dfn[st[top-1]]>=dfn[lca]) add(st[top-1],st[top]),top--;
	if(lca!=st[top]) add(lca,st[top]),st[top]=lca; 
	st[++top]=x;
}
int main()
{
	n=read(),m=read();
	scanf("%s",S+1);
	for(re int i=n;i;--i) ins(S[i]-'a',i);
	for(re int i=2;i<=cnt;i++) add(fa[i],i); dfs1(1),dfs2(1,1);
	num=0;memset(head,0,sizeof(head));
	while(m--)
	{
		t=read();top=0;ans=0;num=0;
		for(re int i=1;i<=t;i++) a[i]=read(),a[i]=pos[a[i]];
		if(t==1) {puts("0");continue;}
		std::sort(a+1,a+t+1,cmp);t=std::unique(a+1,a+t+1)-a-1;
		int root=LCA(a[1],a[2]);
		for(re int i=3;i<=t;i++) root=LCA(root,a[i]);insert(root);
		for(re int i=1;i<=t;i++) insert(a[i]),vis[a[i]]=1;
		while(top) add(st[top-1],st[top]),top--;
		dfs(1);clear(1);
		printf("%lld\n",ans);
	}
	return 0;
}
posted @ 2019-01-22 12:31  asuldb  阅读(191)  评论(0编辑  收藏  举报