BZOJ5304: [Haoi2018]字串覆盖

BZOJ5304: [Haoi2018]字串覆盖

https://lydsy.com/JudgeOnline/problem.php?id=5304

分析:

  • \(L=r-l+1\)
  • 建出\(sam\),倍增+线段树合并求出每个询问对应原串的\(right\)集合。
  • 可以知道
  • 如果\(L>50\),则每次在线段树上二分找到第一个\(1\),最多找\(\frac{n}{L}\)次。
  • 否则就比较麻烦了,我们对于每个位置,维护\(F[L][i]\)表示长度为\(L\),最后一个字符是\(i\)的子串向后找能找到谁,然后倍增求这个就完事了。
  • 由于空间可能开不下,我一开始的做法是对每个状态开\(vector\)来减少存储状态,不过还是会\(mle\),改成将询问离线,那么\(L\)这一维就可以一起处理了。
  • 其中求\(F\)我使用了字符串哈希和\(map\)
  • 如果哈希对\(unsigned long long\)自然溢出并且\(base=998244353\)会只有\(40\)分,别问我是怎么知道的。

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <tr1/unordered_map>
using namespace std;
using namespace std::tr1;
#define N 200050
#define M 4000050
#define base 131
#define db(x) cerr<<#x<<" = "<<x<<endl
typedef long long ll;
typedef unsigned long long ull;
char ss[N],tt[N];
int n,K,ch[N][26],fa[N],lst=1,cnt=1,len[N];
int ls[M],rs[M],tot,siz[M],f[20][N],ke[N],ro[N],root[N];
int tl[N],tq[N],Lg[N];
ull h[N],mi[N];
vector<int>F[N>>1],G[N>>1];
vector<ll>H[N>>1];
int pl[N];
ll ans[N];
ull gh(int l,int r) {return h[r]-h[l-1]*mi[r-l+1];}
unordered_map<ull,int>mp;
void update(int l,int r,int x,int &p) {
	if(!p) p=++tot;
	siz[p]++;
	if(l==r) return ;
	int mid=(l+r)>>1;
	if(x<=mid) update(l,mid,x,ls[p]);
	else update(mid+1,r,x,rs[p]);
}
int merge(int x,int y) {
	if(!x||!y) return x+y;
	int p=++tot;
	ls[p]=merge(ls[x],ls[y]);
	rs[p]=merge(rs[x],rs[y]);
	siz[p]=siz[ls[p]]+siz[rs[p]];
	return p;
}
void insert(int x) {
	int p=lst,np=++cnt,q,nq;
	lst=np; len[np]=len[p]+1;
	for(;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
	if(!p) fa[np]=1;
	else {
		q=ch[p][x];
		if(len[q]==len[p]+1) fa[np]=q;
		else {
			nq=++cnt;
			len[nq]=len[p]+1;
			fa[nq]=fa[q];
			fa[q]=fa[np]=nq;
			memcpy(ch[nq],ch[q],sizeof(ch[q]));
			for(;p&&ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
		}
	}
}
int Ql(int l,int r,int p) {
	if(l==r) return l;
	int mid=(l+r)>>1;
	if(siz[ls[p]]) return Ql(l,mid,ls[p]);
	else return Ql(mid+1,r,rs[p]);
}
int query(int l,int r,int x,int y,int p) {
	if(!p||!siz[p]) return -1;
	if(x<=l&&y>=r) {
		return Ql(l,r,p);
	}
	int mid=(l+r)>>1;
	if(y<=mid) return query(l,mid,x,y,ls[p]);
	else if(x>mid) return query(mid+1,r,x,y,rs[p]);
	else {
		int q=query(l,mid,x,y,ls[p]);
		if(q!=-1) return q;
		return query(mid+1,r,x,y,rs[p]);
	}
}
ll solve(int l,int r,int p,int L) {
	ll c1=0,c2=0;
	if(L<=50) {
		int x=query(1,n,l,r,root[p]);
		if(x==-1) return 0;
		int i;
		for(i=18;i>=0;i--) {
			int lim=F[x].size();
			if(i>=lim) continue;
			if(F[x][i]&&F[x][i]<=r) {
				c1+=G[x][i];
				c2+=H[x][i];
				x=F[x][i];
			}
		}
		c1++; c2+=x;
		return c1*K-c2+(L-1)*c1;
	}
	while(l<=r) {
		int x=query(1,n,l,r,root[p]);
		if(x==-1) break;
		c1++; c2+=x;
		l=x+L;
	}
	return c1*K-c2+(L-1)*c1;
}
struct A {
	int x,y,l,r,L,id;
	bool operator < (const A &u) const {return L<u.L;}
}qq[N];
int main() {
	scanf("%d%d%s%s",&n,&K,ss+1,tt+1);
	int i,j,p;
	for(i=1;i<=n;i++) insert(ss[i]-'a'),update(1,n,i,root[lst]);
	for(i=1;i<=cnt;i++) f[0][i]=fa[i];
	for(i=1;(1<<i)<=cnt;i++) for(j=1;j<=cnt;j++) f[i][j]=f[i-1][f[i-1][j]];
	for(i=1;i<=cnt;i++) ke[len[i]]++;
	for(i=1;i<=cnt;i++) ke[i]+=ke[i-1];
	for(i=cnt;i;i--) ro[ke[len[i]]--]=i;
	for(i=cnt;i>1;i--) {
		p=ro[i]; root[fa[p]]=merge(root[fa[p]],root[p]);
	}
	p=1; int now=0;
	for(i=1;i<=n;i++) {
		int x=tt[i]-'a';
		if(ch[p][x]) {
			now++; p=ch[p][x];
		}else {
			for(;p&&!ch[p][x];p=fa[p]) ;
			if(!p) now=0,p=1;
			else now=len[p]+1,p=ch[p][x];
		}
		tl[i]=now; tq[i]=p;
	}
	int L;
	for(mi[0]=i=1;i<=n;i++) mi[i]=mi[i-1]*base,h[i]=h[i-1]*base+ss[i];
	for(Lg[0]=-1,i=1;i<=n;i++) Lg[i]=Lg[i>>1]+1;
	int cas;scanf("%d",&cas);
	int x,y,l,r;
	for(i=1;i<=cas;i++) scanf("%d%d%d%d",&qq[i].x,&qq[i].y,&qq[i].l,&qq[i].r),qq[i].id=i,qq[i].L=qq[i].r-qq[i].l+1;
	sort(qq+1,qq+cas+1);
	int lf=1;
	for(L=1;L<=50;L++) {
		if(qq[lf].L!=L) continue;
		for(i=n;i>=L;i--) {
			ull tmp=gh(i-L+1,i);
			if(mp[tmp]) {
				int x=mp[tmp];
				int sz=Lg[pl[x]+1];
				F[i].resize(sz+1);
				G[i].resize(sz+1);
				H[i].resize(sz+1);
				F[i][0]=x;
				G[i][0]=1;
				H[i][0]=i;
				for(j=1;j<=sz;j++) {
					int t=F[i][j-1];
					F[i][j]=F[t][j-1];
					G[i][j]=G[i][j-1]+G[t][j-1];
					H[i][j]=H[i][j-1]+H[t][j-1];
				}
				pl[i]=pl[x];
			}else {
				pl[i]=-1;
			}
			pl[i]++;
			if(i+L-1<=n) {
				mp[gh(i,i+L-1)]=i+L-1;
			}	
		}
		mp.clear();
		memset(pl,0,sizeof(pl));
		for(;lf<=n&&qq[lf].L==L;lf++) {
			x=qq[lf].x;
			y=qq[lf].y;
			l=qq[lf].l;
			r=qq[lf].r;
			int id=qq[lf].id;
			if(tl[r]<L) {ans[id]=0; continue;}
			p=tq[r];
			for(j=18;j>=0;j--) {
				if(f[j][p]&&len[f[j][p]]>=L) p=f[j][p];
			}
			ans[id]=solve(x+L-1,y,p,L);
		}
		for(i=1;i<=n;i++) F[i].clear(),G[i].clear(),H[i].clear();
	}
	for(j=1;j<=cas;j++) if(qq[j].L>50) {
		x=qq[j].x;
		y=qq[j].y;
		l=qq[j].l;
		r=qq[j].r;
		int id=qq[j].id;
		L=r-l+1;
		if(tl[r]<L) {
			ans[id]=0; continue;
		}
		p=tq[r];
		for(i=18;i>=0;i--) {
			if(f[i][p]&&len[f[i][p]]>=L) p=f[i][p];
		}
		ans[id]=solve(x+L-1,y,p,L);
	}
	for(i=1;i<=cas;i++) printf("%lld\n",ans[i]);
}
posted @ 2019-01-06 20:08  fcwww  阅读(481)  评论(0编辑  收藏  举报