2022牛客多校 补赛 C Cmostp(区间结尾本质不同子串)

多次询问求一个串的结尾在\([l,r]\)之间的本质不同子串个数。

此题是求一个区间的不同元素的问题,使用扫描线的方法解决,即每次加入一个元素就将这个位置\(+1\),这个元素上一次出现的位置\(-1\)

考虑使用\(SAM\)解决。

其实就是将所有结尾在\([l,r]\)的前缀代表的节点与parent树的根的路径上的点打上标记。
答案就是所有打上标记点\(u\)\(len[u]-len[fa[u]]\)之和。

当时认为这个东西不可做,但是却听别人说这是经典的链并问题。

考虑将所有询问离线。

将字符串从前向后处理,每次插入一个字符后将其对应节点到根的路径打通,同时维护所有路径的\(len[u]-len[fa[u]]\)之和\(sum[i]\)。答案就是一个[l,r]区间和。这其实就是在维护以i为结尾的最后出现的串的个数。

考虑怎么维护这个东西。

可以发现这个路径和与\(LCT\)\(access\)操作相匹配:本质是维护\(LCT\)上每一个\(splay\)\(len[u]-len[fa[u]]\)的和。

每次跳虚链时将修改对应两个\(splay\)对应\(sum[i]\)的大小即可。为了得到\(splay\)对应的\(sum[i]\)还要对每一个\(splay\)维护一个\(col\),用\(lazy\)下推。

最后需要对\(sum[i]\)单点修改区间查询,使用树状数组解决。

强制在线的话,树状数组改成可持久化线段树就行。

考虑如何求\([l,r]\)中本质不同子串个数。

同样考虑将所有询问离线。将字符串从前向后处理。

考虑以r为结尾的子串上一次出现的位置,发现就是\(r\)\(parent\)树上的所有祖先\(u\)的子树中除了\(r\)之外的最新节点\(pre\)。长度就是\(fa[len[u]]\~len[u]\).

所以可以直接维护每个位置为起点的最后出现子串个数\(sum[i]\)。维护时每次跳到一个祖先\(u\)就将\(sum[pre-len[u]+1,pre-fa[len[u]]+1]-1\),将\(sum[1,r]+1\)

显然这样会\(T\),考虑加速这个过程,发现多个祖先的\(pre\)是相同的,仔细想想这个过程其实就是打通这个点到根的路径,每一个祖先的\(pre\)就是\(splay\)\(col\),所以同样使用\(lct\)维护。

\(sum[i]\)需要区间修改区间查询,使用线段树解决。在线就可持久化。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<vector> 
using namespace std;
#define int long long
const int N=1e6+10;
int n,q;
char s[N];
struct SAM{
	int trans[N][30],fa[N],tot,u,len[N];
	void init(){tot=u=1;}
	void ins(int c){
		int x=++tot;
		len[x]=len[u]+1;
		for(;u&&trans[u][c]==0;u=fa[u])trans[u][c]=x;
		if(u==0)fa[x]=1;
		else{
			int v=trans[u][c];
			if(len[v]==len[u]+1)fa[x]=v;
			else{
				int w=++tot;
				len[w]=len[u]+1;
				fa[w]=fa[v];
				fa[v]=fa[x]=w;
				memcpy(trans[w],trans[v],sizeof(trans[v]));
				for(;u&&trans[u][c]==v;u=fa[u])trans[u][c]=w;
			}
		}
		u=x;
	}
	void clear(){
		for(int i=1;i<=tot;i++){
			for(int j=1;j<=26;j++)trans[i][j]=0;
			len[i]=0,fa[i]=0;
		}
		tot=u=0;
	}
}sam;
int sum[N],fa[N],ch[N][2],w[N],stack[N],col[N],lazy_col[N];
int ans[N],pos[N];
struct query{
	int id,l;
	query(int L,int ID){
		l=L,id=ID;
	}
};
vector<query> qu[N];
int read(){
	int sum=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
	return sum*f;
}
void update(int now){
	sum[now]=sum[ch[now][0]]+sum[ch[now][1]]+w[now];
}
bool son(int now){
	return ch[fa[now]][1]==now;
}
bool isroot(int now){
	return ch[fa[now]][0]!=now&&ch[fa[now]][1]!=now;
}
void rotate(int x){
	int y=fa[x],z=fa[y],a=son(x),b=son(y),s=ch[x][!a];
	if(!isroot(y))ch[z][b]=x;fa[x]=z;
	if(s)fa[s]=y;ch[y][a]=s;
	fa[y]=x;ch[x][!a]=y;
	update(y),update(x);
} 
void pushdown(int now){
	if(lazy_col[now]==0)return;
	if(ch[now][0]){
		col[ch[now][0]]=lazy_col[now];
		lazy_col[ch[now][0]]=lazy_col[now];
	}
	if(ch[now][1]){
		col[ch[now][1]]=lazy_col[now];
		lazy_col[ch[now][1]]=lazy_col[now];
	}
	lazy_col[now]=0;
}
void splay(int x){
	int now=x,top=0;
	while(!isroot(now))stack[++top]=now,now=fa[now];
	stack[++top]=now;
	while(top)pushdown(stack[top--]);
	while(!isroot(x)){
		int y=fa[x];
		if(isroot(y))rotate(x);
		else{
			rotate(son(x)==son(y)?y:x);
			rotate(x);
		}
	}
}
int tr[N];
int lowbit(int x){
	return x&-x;
} 
void add(int x,int w){
	for(int i=x;i<=n;i+=lowbit(i))tr[i]+=w; 
}
int get_sum(int x){
	int ans=0;
	for(int i=x;i;i-=lowbit(i)){
		ans+=tr[i];
	}
	return ans;
}
void access(int now,int new_col){
	for(int x=now,y=0;x;y=x,x=fa[x]){
		splay(x);
		if(col[x])add(col[x],-(sum[x]-sum[ch[x][1]]));
		add(new_col,(sum[x]-sum[ch[x][1]]));
		ch[x][1]=y;
		update(x);
		lazy_col[x]=new_col;
		col[x]=new_col;
	}
}
signed main(){
	n=read(),q=read();
	scanf("%s",s+1);
	sam.init();
	for(int i=1;i<=n;i++)sam.ins(s[i]-'a'+1);
	int now=1;
	for(int i=1;i<=n;i++){
		now=sam.trans[now][s[i]-'a'+1];
		pos[i]=now;
	} 
	for(int i=1;i<=sam.tot;i++){
		w[i]=sum[i]=sam.len[i]-sam.len[sam.fa[i]];
		if(sam.fa[i])fa[i]=sam.fa[i];
	}
	for(int i=1;i<=q;i++){
		int l=read(),r=read();
		qu[r].push_back(query(l,i));
	}
	for(int r=1;r<=n;r++){
		access(pos[r],r);
		if(qu[r].size())
			for(int i=0;i<qu[r].size();i++)ans[qu[r][i].id]=get_sum(r)-get_sum(qu[r][i].l-1);
	}
	for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
	return 0;
} 
posted @ 2022-08-21 20:37  Xu-daxia  阅读(43)  评论(0编辑  收藏  举报