NOI 2016 优秀的拆分
这题\(95\)分的算法很好想。但难点在于剩下\(5\)分。
我们将AABB分成前后两个部分,AA和BB。那么我们只需计算\(f1(i),f2(i)\),分别表示以\(i\)结束&开头的形如AA的子串个数。答案就是\(\sum_{i=1}^{n-1} f1_{i}\times f2_{i+1}\)。
计算\(f1,f2\)可以枚举+hash,具体不多说了。时间复杂度为\(O(n^2)\)。
那么这题就转化成如何快速求得\(f1(i)\)和\(f2(i)\)。
我们可以枚举最终AA子串的A的长度\(L\),然后在\(L,2L,3L,...\)处打上标记。
可以发现一个有趣的性质——长为\(2L\)的AA子串,一定经过恰好\(2\)个相邻的打上标记的下标。(抽屉原理)
(至于为什么要往这方面思考,我不知道:(
那么,我们枚举后一个下标\(j\),前一个下标就是\(j-L\),我们求它们的longest common prefix和longest common suffix。设其长度为\(lcp,lcs\)。
如果\(j-L+lcp<j-lcs\),那么两个区间不相交,不对任何\(f1,f2\)造成贡献。
否则,我们可以得到一段区间,可以作为AA的开头;还可以得到长度相同的区间作为结尾,那么我们对这两个区间\(+1\),这可以使用差分维护。
时间复杂度为\(O(Tn\log n)\)或\(O(Tn\log ^2 n)\),看实现lcp和lcs时用的是SA还是二分+hash。
下面是二分+hash的\(O(Tn\log ^2 n)\)的解法:
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
const int maxn=30005;
int n;
const ll mod1=998244353,mod2=1000000007;
const ll b1=200061018,b2=20061009;
struct stringhash {
ll h1[maxn],h2[maxn];
ll p1[maxn],p2[maxn];
void init(char *S) {
int len=strlen(S+1);
p1[0]=p2[0]=1; h1[0]=h2[0]=0;
for(int i=1;i<=len;i++) p1[i]=p1[i-1]*b1%mod1;
for(int i=1;i<=len;i++) p2[i]=p2[i-1]*b2%mod2;
for(int i=1;i<=len;i++) h1[i]=(h1[i-1]*b1+S[i])%mod1;
for(int i=1;i<=len;i++) h2[i]=(h2[i-1]*b2+S[i])%mod2;
}
std::pair<ll,ll> gethash(int l,int r) {
ll ret1=h1[r]-h1[l-1]*p1[r-l+1]%mod1+mod1;
ll ret2=h2[r]-h2[l-1]*p2[r-l+1]%mod2+mod2;
return std::make_pair(ret1%mod1,ret2%mod2);
}
void clearall() {
memset(h1,0,sizeof h1); memset(h2,0,sizeof h2);
memset(p1,0,sizeof p1); memset(p2,0,sizeof p2);
}
}solver;
int Q(int l,int r,char ch) {
int lef=1,rig=n,ret=0;
while(lef<=rig) {
int mid=lef+rig>>1; bool judge;
if(ch=='p') judge=solver.gethash(l,l+mid-1)==solver.gethash(r,r+mid-1);
else judge=solver.gethash(l-mid+1,l)==solver.gethash(r-mid+1,r);
if(judge) ret=mid,lef=mid+1;
else rig=mid-1;
}
return ret;
}
long long f1[maxn],f2[maxn];
char buf[maxn];
void solve() {
memset(f1,0,sizeof f1);
memset(f2,0,sizeof f2);
solver.clearall();
scanf("%s",buf+1);
n=strlen(buf+1);
solver.init(buf);
for(int len=1;len<=n;len++) {
for(int j=len+len;j<=n;j+=len) {
int lcs=Q(j-len,j,'s'),lcp=Q(j-len,j,'p'),l,r;
l=std::max(j,j-len-lcs+len+len),r=std::min({n,j+len-1,j+lcp-1});//特别要注意,这里合法的区间不能包含2个以上标记的点,也不能超过n
if(l>r) continue;//无解
f1[l]++; f1[r+1]--;//区间[l,r]加上1
int delta=r-l; l=std::max(j-len-len+1,j-len-lcs+1),r=l+delta;//这里直接加上一次区间的长度即可,省去一次麻烦的计算
f2[l]++; f2[r+1]--;
}
}
long long ans=0;
for(int i=1;i<=n;i++) f1[i]+=f1[i-1],f2[i]+=f2[i-1];
for(int i=1;i<n;i++) ans+=f1[i]*f2[i+1];
printf("%lld\n",ans);
}
int main() {
int T; scanf("%d",&T);
while(T--) solve();
return 0;
}
浙公网安备 33010602011771号