[提高组集训2021] 校游戏

一、题目

有两个字符串 \(A,B\),你需要对于所有 \(k\) 求出:从 \(A\) 中随机选一个长度为 \(k\) 的子串比从 \(B\) 中随机选一个长度为 \(k\) 的子串字典序小的概率、字典序相等的概率、字典序大的概率。

\(|A|,|B|\leq 2\cdot 10^5\)

二、解法

还是不要迷信后缀自动机,后缀数组解决字典序问题有些时候还是简单多了。

考虑计算 \(A<B\) 的概率(第一个子串更大的概率同理)

我们把两个串中间插入一个分隔符一起做一次后缀数组,对于某个属于 \(A\) 的后缀,我们找到排名比它大的 \(B\) 的后缀,考虑它们对 \(A\leq B\) 的贡献,是在 \([1,\min(|A|,|B|)]\) 都加上了 \(1\)

但是这样做要减去 \(A=B\) 的贡献,发现这就是 \(\tt lcp\) 问题,扫描的时候我们对 \(\tt height\) 数组做单调栈,单调栈中的每个元素都有一个负贡献,那么我们在栈顶打一个整体标记即可。

对于 \(A\leq B\) 贡献的计算可以使用权值线段树维护差分标记,也就是对所有已插入的属于 \(B\) 的后缀增加 \(1\) 的标记即可,时间复杂度 \(O(n\log n)\),然后再来一次即可!

三、总结

子串问题转化成后缀问题,用后缀数组来算贡献即可。

//Cause sum gonna arrive
#include <cstdio>
#include <cassert>
#include <cstring>
#include <iostream>
using namespace std;
const int M = 400005;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,la,lb,p[M],z[M],tmp[M],pre[M],xi[M],da[M];
int suf[M],c[M],x[M],y[M],sa[M],rk[M],h[M];
int tag[2*M],num[2*M];char a[M],b[M],s[M];
//suffix array
void init()
{
	int m=256;
	for(int i=0;i<=m;i++) c[i]=0;
	for(int i=1;i<=n;i++) c[x[i]=s[i]]++;
	for(int i=1;i<=m;i++) c[i]+=c[i-1];
	for(int i=n;i>=1;i--) sa[c[x[i]]--]=i;
	for(int k=1;k<=n;k<<=1)
	{
		int num=0;
		for(int i=n-k+1;i<=n;i++) y[++num]=i;
		for(int i=1;i<=n;i++)
			if(sa[i]>k) y[++num]=sa[i]-k;
		for(int i=0;i<=m;i++) c[i]=0;
		for(int i=1;i<=n;i++) c[x[i]]++;
		for(int i=1;i<=m;i++) c[i]+=c[i-1];
		for(int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i];
		swap(x,y);
		x[sa[1]]=num=1;
		for(int i=2;i<=n;i++)
			x[sa[i]]=(y[sa[i]]==y[sa[i-1]]
			&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
		if(n==num) break;
		m=num;
	}
	int k=0;
	for(int i=1;i<=n;i++) rk[sa[i]]=i;
	for(int i=1;i<=n;i++)
	{
		if(rk[i]==1) continue;
		if(k) k--;
		int x=i,y=sa[rk[i]-1];
		while(x+k<=n && y+k<=n && s[x+k]==s[y+k]) k++;
		h[rk[i]]=k;
	}
}
//segment tree maintaining tags
void down(int i)
{
	if(!tag[i]) return ;
	tag[i<<1]+=tag[i];
	tag[i<<1|1]+=tag[i];
	tag[i]=0;
}
void work(int i,int l,int r,int L,int R)
{
	if(l>R || L>r) return ;
	if(L<=l && r<=R)
	{
		tag[i]++;
		return ;
	}
	int mid=(l+r)>>1;down(i);
	work(i<<1,l,mid,L,R);
	work(i<<1|1,mid+1,r,L,R);
}
int ask(int i,int l,int r,int L,int R)
{
	if(L>r || l>R) return 0;
	if(L<=l && r<=R) return num[i];
	int mid=(l+r)>>1;
	return ask(i<<1,l,mid,L,R)
	+ask(i<<1|1,mid+1,r,L,R);
}
void ins(int i,int l,int r,int id)
{
	if(l==r)
	{
		tag[i]=0;
		num[i]=1;
		return ;
	}
	int mid=(l+r)>>1;down(i);
	if(mid>=id) ins(i<<1,l,mid,id);
	else ins(i<<1|1,mid+1,r,id);
	num[i]=num[i<<1]+num[i<<1|1];
}
void dfs(int i,int l,int r)
{
	if(l==r)
	{
		assert(num[i]);
		pre[l+1]-=tag[i];
		return ;
	}
	int mid=(l+r)>>1;down(i);
	dfs(i<<1,l,mid);
	dfs(i<<1|1,mid+1,r);
}
void work()
{
	init();int t=0;
	memset(tag,0,sizeof tag);
	memset(num,0,sizeof num);
	for(int i=1;i<=n;i++)
		suf[i]=pre[i]=z[i]=tmp[i]=0;
	for(int i=n;i>=1;i--)
	{
		int x=sa[i];suf[i]=suf[i+1];
		if(x<=la)//A
		{
			int len=la-x+1;
			int num=ask(1,1,lb,len+1,lb);//get number
			pre[1]+=suf[i];pre[len+1]-=num;
			work(1,1,lb,1,len);//give tag(-)
			z[t]++;
		}
		if(x>la+1)//B
		{
			suf[i]++;
			ins(1,1,lb,n-x+1); 
		}
		//upd the stack
		while(t && h[i]<=h[p[t]])
		{
			tmp[1]+=z[t]*(suf[p[t]]-suf[p[t-1]]);
			tmp[h[p[t]]+1]-=z[t]*(suf[p[t]]-suf[p[t-1]]);
			z[t-1]+=z[t];t--;
		}
		p[++t]=i,z[t]=0;
	}
	dfs(1,1,lb);
	for(int i=t;i>=1;i--)
	{
		tmp[1]+=z[i]*(suf[p[i]]-suf[p[i-1]]);
		tmp[h[p[i]]+1]-=z[i]*(suf[p[i]]-suf[p[i-1]]);
		z[i-1]+=z[i];
	}
	for(int i=1;i<=lb;i++)
		tmp[i]+=tmp[i-1];
	for(int i=1;i<=lb;i++)
	{
		pre[i]+=pre[i-1];
		tmp[i]=pre[i]-tmp[i];
	}
}
int gcd(int a,int b)
{
	return !b?a:gcd(b,a%b);
}
void write(int x)
{
	if(x<=9)
	{
		putchar(x+'0');
		return ;
	}
	write(x/10);
	putchar(x%10+'0');
}
void print(int x,int y)
{
	int t=gcd(x,y);
	x/=t;y/=t;
	write(x);
	putchar('/');
	write(y);
	putchar(' ');
}
signed main()
{
	freopen("game.in","r",stdin);
	freopen("game.out","w",stdout);
	scanf("%s",a+1),la=strlen(a+1);
	scanf("%s",b+1),lb=strlen(b+1);
	for(int i=1;i<=la;i++) s[++n]=a[i];
	s[++n]='z'+1;
	for(int i=1;i<=lb;i++) s[++n]=b[i];
	work();
	for(int i=1;i<=lb;i++) xi[i]=tmp[i];
	//once again
	n=0;
	for(int i=1;i<=lb;i++) s[++n]=b[i];
	s[++n]='z'+1;
	for(int i=1;i<=la;i++) s[++n]=a[i];
	swap(la,lb);
	work();
	for(int i=1;i<=lb;i++) da[i]=tmp[i];
	for(int i=1;i<=min(la,lb);i++)
	{
		int sum=(la-i+1)*(lb-i+1);
		print(xi[i],sum);
		print(sum-xi[i]-da[i],sum);
		print(da[i],sum);
		puts("");
	}
}
posted @ 2021-10-06 20:34  C202044zxy  阅读(73)  评论(0编辑  收藏  举报