Codeforces 1090J $kmp+hash+$二分

题意

给出两个字符串\(s\)\(t\),设\(S\)\(s\)的任意一个非空前缀,\(T\)\(t\)的任意一个非空前缀,问\(S+T\)有多少种不同的可能。

Solution

看了一圈,感觉好像就我一个人写的\(kmp+hash+\)二分。

直接算好像不是很好算?先容斥一下,不同\(=\)总方案\(-\)相同。

显然总方案为两个字符串的长度的乘积,考虑相同的情况怎么算。

相同即两组\(S\)\(T\)不同,但\(S+T\)本质相同的情况.

这个东西怎么算呢。。。。

(感觉看图会好理解一点

不难想到当上图框出来的地方相同,则两者同质。

先来看右边那个框,显然这个东西就是一个字符串里两个子串\([1,i],[j,k]\)相同。

左边这个框就是\(s\)的某个子串和\(t\)的前缀相同。

具体怎么算?

根据上图,设\(a_i\)\(t\)的前缀\([1,i]\)\(s\)里出现了几次,这个可以\(hash+\)二分算。

\(b_i\)为符合\([1,j]=[i-j+1,i]\)\(j\)的最大值,这个可以\(kmp\)一波。

那么最终同质的个数就是\(\sum_{i=2}^{|t|}a_{i-b_i}\)

#include<bits/stdc++.h>
#define For(i,x,y) for (register int i=(x);i<=(y);i++)
#define Dow(i,x,y) for (register int i=(x);i>=(y);i--)
#define cross(i,u) for (register int i=first[u];i;i=last[i])
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline ll read(){
    ll x=0;int ch=getchar(),f=1;
    while (!isdigit(ch)&&(ch!='-')&&(ch!=EOF)) ch=getchar();
    if (ch=='-'){f=-1;ch=getchar();}
    while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    return x*f;
}
const int N = 1e5+10;
int n,m;
char a[N],b[N];
const ull base = 233;
ull pre[N],Pre[N],p[N];
const ll Base = 23, mod = 1e9+7;
ll pre2[N],Pre2[N],p2[N];
inline void GetPre(){
	p[0]=1;For(i,1,n) p[i]=p[i-1]*base;
	For(i,1,n) pre[i]=pre[i-1]*base+a[i];
	For(i,1,m) Pre[i]=Pre[i-1]*base+b[i];
	p2[0]=1;For(i,1,n) p2[i]=p2[i-1]*Base%mod;
	For(i,1,n) (pre2[i]=pre2[i-1]*Base%mod+a[i])%=mod;
	For(i,1,m) (Pre2[i]=Pre2[i-1]*Base%mod+b[i])%=mod;
} 
inline ull query(int l,int r){return pre[r]-pre[l-1]*p[r-l+1];}
inline ll query2(int l,int r){return (pre2[r]-pre2[l-1]*p2[r-l+1]%mod+mod)%mod;}
int now,fail[N];
inline void GetKmp(){
	now=0;
	For(i,2,m){
		while (now&&b[now+1]!=b[i]) now=fail[now];
		fail[i]=(b[now+1]==b[i]?++now:now);
	}
}
int sum[N];
inline void Get(){
	For(i,2,n){
		int l=1,r=min(m,n-i+1),mid,ans=0;
		while (l<=r){
			mid=l+r>>1;
			if (query(i,i+mid-1)==Pre[mid]&&query2(i,i+mid-1)==Pre2[mid]) l=mid+1,ans=mid;
				else r=mid-1;
		}
		sum[ans]++;
	}
	sum[0]=0;
	Dow(i,m,1) sum[i]+=sum[i+1];
}
inline void calc(){
	ll ans=1ll*n*m;
	For(i,2,m) if (fail[i]) ans-=sum[i-fail[i]];
	printf("%lld\n",ans);
}
int main(){
	scanf("%s",a+1),scanf("%s",b+1),n=strlen(a+1),m=strlen(b+1);
	GetPre(),GetKmp(),Get(),calc();
}
posted @ 2018-12-21 18:08  zykykyk  阅读(331)  评论(0编辑  收藏  举报