luogu P3181 [HAOI2016]找相同字符

首先把两个字符串拼在一起,中间夹一个不可能出现的字符。

然后就是一个简单容斥,我们假设给的字符串为 \(S_1\)\(S_2\),新拼成的字符串为 \(S\),那么答案就是求 \(same(S)-same(S_1)-same(S_2)\),其中 \(same(s)\) 表示 \(s\) 这个字符串中位置不同大小相同的子串的个数。因为容易看出 \(same(S)\) 中统计的要么就是两个字符串都在 \(S_1\),要么就是两个字符串都在 \(S_2\),要么分别在 \(S_1\)\(S_2\) 中(而后者就是本题的答案)。

现在就是考虑给你一个字符串,如何求出它的 \(same(s)\)

\[same(s)=\sum\limits_{i=1}^{n}\sum\limits_{j=i+1}^{n}lcp(Suf_i,Suf_j) \]

原理还是那句话:后缀的前缀可以不重不漏的表示每一个子串。继续:

\[\begin{aligned} same(s)&=\sum\limits_{i=1}^{n}\sum\limits_{j=i+1}^{n}lcp(Suf_i,Suf_j) \\ &=\sum\limits_{i=1}^{n}\sum\limits_{j=i+1}^{n}lcp(Suf_{sa[i]},Suf_{sa[j]}) \\ &=\sum\limits_{i=2}^{n}\sum\limits_{j=i}^{n}\min\limits_{k=i}^{j}height_k \end{aligned} \]

第二行相当于换了个顺序,没什么好解释的,第三行是根据 \(height\) 数组的性质决定的:

\[lcp(i,j)\leq lcp(i,k) \]

其中 \(x_i<x_k<x_j\)\(x\) 数组表示后缀排名)

这也是我们求 \(lcp\) 是可以用 \(ST\) 表的原理。

回到最后那个式子:可以用单调栈维护左边第一个比当前 \(height_j\) 小的 \(height_k\),然后直接转移就好(具体看代码非常好理解)。

代码:

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>

using namespace std;

typedef long long LL;
const int N=1000000;
char s1[N],s2[N];
int stk[N];
struct Suffix_Array
{
	int n,m,h[N],height[N],c[N],x[N],y[N],sa[N];
	LL Ans[N];
	char s[N];
	
	void clear()
	{
		memset(h,0,sizeof(h)),memset(c,0,sizeof(c)),memset(x,0,sizeof(x));
		memset(y,0,sizeof(y)),memset(sa,0,sizeof(sa)),memset(stk,0,sizeof(stk));
	}
	
	void Rsort()
	{
		for (int i=1;i<=m;i++) c[i]=0;
		for (int i=1;i<=n;i++) c[x[y[i]]]++;
		for (int i=1;i<=m;i++) c[i]+=c[i-1];
		for (int i=n;i>=1;i--) sa[c[x[y[i]]]--]=y[i];
	}
	
	void Get_SA()
	{
		clear();
		m=122;
		for (int i=1;i<=n;i++)
			x[i]=s[i],y[i]=i;
		Rsort();
		for (int k=1;k<=n;k<<=1)
		{
			int num=0;
			for (int i=n-k+1;i<=n;i++)
				y[++num]=i;
			for (int i=1;i<=n;i++)
				if(sa[i]>k)
					y[++num]=sa[i]-k;
			Rsort(),swap(x,y);
			x[sa[1]]=num=1;
			for (int i=2;i<=n;i++)
				x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
			if(num==n) break;
			m=num;
		}
	}
	
	void Get_Height()
	{
		for (int i=1;i<=n;i++)
		{
			int tmp=max(0,h[i-1]-1),j,k;
			for (j=i+tmp,k=sa[x[i]-1]+tmp;j<=n&&k<=n&&s[j]==s[k];j++,k++);
			h[i]=j-i;
		}
		for (int i=1;i<=n;i++)
			height[i]=h[sa[i]];
	}
	
	void init()
	{
		scanf("%s",s+1);
		n=strlen(s+1);
	}
	
	void work()
	{
		Get_SA();
		Get_Height();
	}
	
	LL Get_Ans()
	{
		memset(Ans,0,sizeof(Ans));
		LL ans=0;
		int l=1,r=0;
		for (int i=2;i<=n;i++)
		{
			while(l<=r&&height[i]<=height[stk[r]])
				r--;
			if(r==0) Ans[i]=height[i]*(i-1);
			else Ans[i]=height[i]*(i-stk[r])+Ans[stk[r]];
			ans+=Ans[i];
			stk[++r]=i;
		}
		return ans;
	}
}A,B,C;

void init()
{
	A.init();
	B.init();
	C.n=A.n+B.n+1;
	for (int i=1;i<=A.n;i++)
		C.s[i]=A.s[i];
	C.s[A.n+1]='?';
	for (int i=1,j=A.n+2;i<=B.n;i++,j++)
		C.s[j]=B.s[i];
}

void work()
{
	A.work(),B.work(),C.work();
//	printf("%lld %lld %lld\n",A.Get_Ans(),B.Get_Ans(),C.Get_Ans());
	printf("%lld\n",C.Get_Ans()-A.Get_Ans()-B.Get_Ans());
}

int main()
{
	init();
	work();
	return 0;
}
posted @ 2020-06-25 11:43  With_penguin  阅读(151)  评论(0编辑  收藏  举报