P4770 [NOI2018]你的名字

蒟蒻表示不会sam凉凉了,所以只能提高SA技巧?

题意:有一个串\(A\),每次选择一个\(A\)的子串\(A'\),以及串\(B\),问\(B\)的所有本质不同子串中不在\(A'\)中的串的数量。

(定义\(A_i\)表示以字符\(A_i\)开头的后缀,\(B_i\)同理)

\(B\)的本质不同字串显然是\(|B|*(|B|+1)/2\)了,然后要减去本质不同的在\(A'\)中的串

首先把所有串拼一起,把SA建出来,辣么就可以在SA中找到\(A'\)所对应的所有后缀,对于\(B\)对应的每个后缀\(B_i\),计算\(\max_{j=l}^r LCP(A_j,B_i)\),就是在\(A'\)中的前缀数量,加起来就是这个东西了,由于还要判重,所以计入的其实是\(\max(0,\max_{j=l}^r LCP(A_j,B_i)-LCP(B_i,next(B_{i})))\),其中\(next(B_i)\)意思是在\(B\)串的SA上\(B_i\)的后继

\(LCP(B_i,next(B_{i}))\)随便算是吧,现在要算的是\(\max_{j=l}^rLCP(A_j,B_i)\)

要算LCP的max值,显然只要在SA上求出前驱和后继计算就行了,那么要算区间前驱后继,就是二逼平衡树了

#include<bits/stdc++.h>
#define il inline
#define vd void
typedef long long ll;
#define Log(x) (31-__builtin_clz(x))
il ll gi(){
	ll x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch)){
		if(ch=='-')f=-1;
		ch=getchar();
	}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return x*f;
}
char S[1600037];
int t[100010],lent[100010];
int n,N;
int ql[100010],qr[100010];
namespace SA{
	int x[1600037],y[1600037],_[1600037],SA[1600037],rk[1600037],ht[1600037],t[1600037];
	int st[21][1600037];
	il int LCP(int x,int y){
		if(!x||!y)return 0;
		if(x==y)return 1e9;
		int l=Log(y-x);
		return std::min(st[l][x],st[l][y-(1<<l)]);
	}
	il int gety(int x){return x<=N?y[x]:-1;}
	il vd getSA(){
		int set=128;
		for(int i=1;i<=N;++i)++t[x[i]=S[i]];
		for(int i=1;i<=set;++i)t[i]+=t[i-1];
		for(int i=N;i;--i)SA[t[x[i]]--]=i;
		for(int k=1;k<=N;k<<=1){
			int p=0;
			for(int i=N-k+1;i<=N;++i)y[++p]=i;
			for(int i=1;i<=N;++i)if(SA[i]>k)y[++p]=SA[i]-k;
			for(int i=0;i<=set;++i)t[i]=0;
			for(int i=1;i<=N;++i)++t[x[y[i]]];
			for(int i=1;i<=set;++i)t[i]+=t[i-1];
			for(int i=N;i;--i)SA[t[x[y[i]]]--]=y[i];
			memcpy(_,x,sizeof _);
			memcpy(x,y,sizeof _);
			memcpy(y,_,sizeof _);
			x[SA[1]]=p=1;
			for(int i=2;i<=N;++i){
				if(gety(SA[i])!=gety(SA[i-1])||gety(SA[i]+k)!=gety(SA[i-1]+k))++p;
				x[SA[i]]=p;
			}
			if(p>=N)break;set=p;
		}
		for(int i=1;i<=N;++i)rk[SA[i]]=i;
		for(int i=1,j,k=0;i<=N;++i){
			if(rk[i]==N)continue;
			if(k)--k;
			j=SA[rk[i]+1];
			while(S[i+k]==S[j+k])++k;
			ht[rk[i]]=k;
		}
		for(int i=1;i<=N;++i)st[0][i]=ht[i];
		for(int i=1;i<=Log(N);++i)
			for(int j=1;j+(1<<i)-1<=N;++j)
				st[i][j]=std::min(st[i-1][j],st[i-1][j+(1<<i-1)]);
		//for(int i=1;i<=N;++i)printf("%d %s\n",ht[i],S+SA[i]);
	}
}
#define mid ((l+r)>>1)
int rt[500010],ls[20000010],rs[20000010],sum[20000010],cnt;
il vd build(int&x,int l,int r){
	x=++cnt;if(l==r)return;
	build(ls[x],l,mid),build(rs[x],mid+1,r);
}
il vd update(int&x,int l,int r,const int&p){
	++cnt;ls[cnt]=ls[x],rs[cnt]=rs[x],sum[cnt]=sum[x];x=cnt;
	++sum[x];if(l==r)return;
	if(p<=mid)update(ls[x],l,mid,p);
	else update(rs[x],mid+1,r,p);
}
int pp;
il int query_nxt(int x,int y,int l,int r){
	if(!(sum[y]-sum[x]))return 0;
	if(l==r)return l;
	if(mid<pp)return query_nxt(rs[x],rs[y],mid+1,r);
	if(pp<l){
		if(sum[ls[y]]-sum[ls[x]])return query_nxt(ls[x],ls[y],l,mid);
		else return query_nxt(rs[x],rs[y],mid+1,r);
	}else{
		int t=query_nxt(ls[x],ls[y],l,mid);
		return t?t:query_nxt(rs[x],rs[y],mid+1,r);
	}
}
il int query_pre(int x,int y,int l,int r){
	if(!(sum[y]-sum[x]))return 0;
	if(l==r)return l;
	if(pp<=mid)return query_pre(ls[x],ls[y],l,mid);
	if(r<pp){
		if(sum[rs[y]]-sum[rs[x]])return query_pre(rs[x],rs[y],mid+1,r);
		else return query_pre(ls[x],ls[y],l,mid);
	}else{
		int t=query_pre(rs[x],rs[y],mid+1,r);
		return t?t:query_pre(ls[x],ls[y],l,mid);
	}
}
#undef mid
il bool check(int l,int r,int p,int k){
	if(l>r)return 0;
	pp=p;
	if(SA::LCP(query_pre(rt[l-1],rt[r],1,N),p)>=k)return 1;
	if(SA::LCP(p,query_nxt(rt[l-1],rt[r],1,N))>=k)return 1;
	return 0;
}
int lcp[1600037];
int main(){
#ifdef XZZSB
	freopen("in.in","r",stdin);
	freopen("out.out","w",stdout);
#endif
	scanf("%s",S+1);n=strlen(S+1);N=n+1;
	int Q=gi();t[1]=n+1;S[n+1]='~';
	int sumt=0;
	for(int i=1;i<=Q;++i){
		scanf("%s",S+t[i]+1);ql[i]=gi(),qr[i]=gi();
		lent[i]=strlen(S+t[i]+1);t[i+1]=t[i]+lent[i]+1;N+=lent[i]+1;
		sumt+=lent[i];
		S[t[i]+lent[i]+1]='|';
	}
	SA::getSA();
	build(rt[0],1,N);
	for(int i=1;i<=n;++i)rt[i]=rt[i-1],update(rt[i],1,N,SA::rk[i]);
	int l,r;
	for(int o=1;o<=Q;++o){
		l=ql[o],r=qr[o];
		std::vector<int>sufs;
		for(int i=t[o]+1;i<=t[o]+lent[o];++i)sufs.push_back(SA::rk[i]);
		std::sort(sufs.begin(),sufs.end());
		ll res=1ll*lent[o]*(lent[o]+1)/2;
		for(int i=1;i<sufs.size();++i)res-=(lcp[SA::SA[sufs[i-1]]]=SA::LCP(sufs[i-1],sufs[i]));
		for(int i=t[o]+1,j=0;i<=t[o]+lent[o];++i){
			if(j)--j;
			while(check(l,r-j,SA::rk[i],j+1))++j;
			res-=std::max(0,j-lcp[i]);
		}
		printf("%lld\n",res);
	}
	return 0;
}
posted @ 2019-04-03 22:14  菜狗xzz  阅读(226)  评论(0编辑  收藏  举报