ybtAu「字符串算法」第2章 后缀数组
这是 neatisaac 的金牌导航题解!
我不要写 SA
还没写完。感觉 SA 的板子还是没有学明白。
A. 【例题1】不可重叠串
看似不是板子,和题解不一样,实际上差分一下就好了。
二分判定是否存在长度为 \(mid\) 的相似子串,如果存在后缀 \(i\) \(j\),使得 \(\min_{i<k<=j}height_k\ge mid\),那么它们的 \(LCP\ge mid\);如果它们的差大于 \(mid\),那么这两个后缀就有长度为 \(mid\) 且不重叠的相同前缀。
#include <iostream>
#include <cstring>
#define N 20005
int n,m,p,a[N],sa[N<<1],rk[N<<1],buc[N],lrk[N],id[N],hi[N];
bool check(int x)
{
int mx=0,mn=0;
for(int i=1;i<=n;i++)
{
if(hi[i]<x) mx=mn=sa[i];
if(sa[i]<mn) mn=sa[i];
if(sa[i]>mx) mx=sa[i];
if(mx-mn>x) return 1;
}
return 0;
}
int main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
for(;std::cin>>n;)
{
m=256;
if(n==0) break;
memset(buc,0,sizeof buc);
for(int i=1,x=0,y=0;i<=n;i++) std::cin>>x,a[i]=x-y+100,y=x;
for(int i=1;i<=n;i++) buc[rk[i]=a[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[i]]--]=i;
for(int w=1;;w<<=1,m=p)
{
int cur=0;
for(int i=n-w+1;i<=n;i++) id[++cur]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) id[++cur]=sa[i]-w;
memset(buc,0,sizeof buc);
for(int i=1;i<=n;i++) buc[rk[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[id[i]]]--]=id[i];
memcpy(lrk,rk,sizeof lrk),p=0;
for(int i=1;i<=n;i++)
{
if(lrk[sa[i]]==lrk[sa[i-1]]&&lrk[sa[i]+w]==lrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p==n) break;
}
for(int i=1,h=0;i<=n;i++)
{
if(!rk[i]) continue;
if(h) h--;
while(a[i+h]==a[sa[rk[i]-1]+h]) h++;
hi[rk[i]]=h;
}
int L=0,R=n>>1,ans=0;
while(L<=R)
{
int mid=L+R>>1;
if(check(mid)) L=mid+1,ans=mid;
else R=mid-1;
}
if(ans>=4) std::cout<<ans+1<<'\n';
else std::cout<<"0\n";
}
}
B. 【例题2】最长公共子串
把所有字符串正着反着都拼起来,拼成一个大串,二分长度 \(mid\),如果一段区间 \(\min height_i\ge x\) 且区间内的前缀有来自 \(1\) 到 \(n\) 的所有字符串的,那么就合法。
实际上我们不需要真的用一个 set 来存储字符串编号。
#include <iostream>
#include <algorithm>
#include <cstring>
#define N 200005
int n,len,sa[N<<1],rk[N<<1],id[N],lrk[N],buc[N],hi[N],str[N];
char S[N<<1];
struct Set
{
int vis[105],cur,pcnt;
Set() {cur=1;}
void ins(int x) {if(!x) return;if(vis[x]!=cur) pcnt++;vis[x]=cur;}
void clear() {pcnt=0,cur++;}
int siz() {return pcnt;}
} mp;
void SA()
{
int m=256,p;
memset(buc,0,sizeof buc);
for(int i=1;i<=len;i++) buc[rk[i]=S[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=len;i>=1;i--) sa[buc[rk[i]]--]=i;
for(int w=1;;w<<=1,m=p)
{
int cur=0;
for(int i=len-w+1;i<=len;i++) id[++cur]=i;
for(int i=1;i<=len;i++) if(sa[i]>w) id[++cur]=sa[i]-w;
memset(buc,0,sizeof buc);
for(int i=1;i<=len;i++) buc[rk[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=len;i>=1;i--) sa[buc[rk[id[i]]]--]=id[i];
memcpy(lrk,rk,sizeof lrk),p=0;
for(int i=1;i<=len;i++)
{
if(lrk[sa[i]]==lrk[sa[i-1]]&&lrk[sa[i]+w]==lrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p==len) break;
}
for(int i=1,h=0;i<=len;i++)
{
if(!rk[i]) continue;
if(h) h--;
while(S[i+h]==S[sa[rk[i]-1]+h]) h++;
hi[rk[i]]=h;
}
}
bool check(int x)
{
mp.clear();
for(int i=1;i<=len;i++)
{
if(hi[i]>=x) mp.ins(str[sa[i]]),mp.ins(str[sa[i-1]]);
else
{
if(mp.siz()==n) return 1;
mp.clear();
}
}
return mp.siz()==n;
}
int main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
int T;
std::cin>>T;
while(T--)
{
std::cin>>n,len=0;
int L=0,R=0,ans=0;
for(int i=1,j=0;i<=n;i++)
{
std::string s;
std::cin>>s;
R=std::max(R,(int)s.size());
for(int j=0;j<s.size();j++) S[++len]=s[j],str[len]=i;
S[++len]=++j,str[len]=i;
for(int j=s.size()-1;j>=0;j--) S[++len]=s[j],str[len]=i;
S[++len]=++j,str[len]=i;
}
SA();
while(L<=R)
{
int mid=L+R>>1;
if(check(mid)) L=mid+1,ans=mid;
else R=mid-1;
}
std::cout<<ans<<'\n';
}
}
C. 【例题3】连续重复子串
用 ST 表维护 \(height\) 来 \(O(1)\) 求任意两后缀 \(LCP\)。
枚举重复段长度 \(l\),枚举起始位置 \(j\),令 \(cnt\) 表示从 \(j\) 开始的连续重复段个数,前缀 \(j\) 和前缀 \(j+l\) 的 \(LCP\) 中含有的重复段个数即为 \(cnt-1\);如果还多出来一部分凑不满一个重复段,那么从 \(j\) 前面开始找,看看是否还能多找到一个重复段,最后处理答案。
时间复杂度 \(O(\sum\frac{n}{i})=O(n\log n)\)。
#include <iostream>
#include <cstring>
#define N 100005
std::string s;
int n,sa[N<<1],rk[N<<1],lrk[N],id[N],buc[N],hi[N],ans[N];
namespace ST
{
int f[N][20],l2[N];
void init()
{
for(int i=2;i<=n;i++) l2[i]=l2[i>>1]+1;
for(int i=1;i<=n;i++) f[i][0]=hi[i];
for(int i=1;(1<<i)<=n;i++) for(int j=1;j+(1<<i)-1<=n;j++)
f[j][i]=std::min(f[j][i-1],f[j+(1<<i-1)][i-1]);
// for(int i=0;(1<<i)<=n;i++)
// {
// for(int j=1;j+(1<<i)-1<=n;j++) printf("%d ",f[j][i]);
// printf("\n");
// }
}
int qr(int l,int r)
{
if(l>r) std::swap(l,r);
l++;
int d=l2[r-l+1];
return std::min(f[l][d],f[r-(1<<d)+1][d]);
}
};
void buildSA()
{
int m=256,p;
memset(buc,0,sizeof buc);
for(int i=1;i<=n;i++) buc[rk[i]=s[i-1]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[i]]--]=i;
for(int w=1;;w<<=1,m=p)
{
int cur=0;
for(int i=n-w+1;i<=n;i++) id[++cur]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) id[++cur]=sa[i]-w;
memset(buc,0,sizeof buc);
for(int i=1;i<=n;i++) buc[rk[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[id[i]]]--]=id[i];
memcpy(lrk,rk,sizeof lrk),p=0;
for(int i=1;i<=n;i++)
{
if(lrk[sa[i]]==lrk[sa[i-1]]&&lrk[sa[i]+w]==lrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p>=n) break;
}
for(int i=1,h=0;i<=n;i++)
{
if(!rk[i]) continue;
if(h) h--;
while(s[i+h-1]==s[sa[rk[i]-1]+h-1]) h++;
hi[rk[i]]=h;
}
/*for(int i=1;i<=n;i++) printf("%d ",sa[i]);
printf("\n");
for(int i=1;i<=n;i++) printf("%d ",hi[i]);
printf("\n");*/
}
int main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
for(int sbSA=1;;sbSA++)
{
std::cin>>s;
if(s[0]=='#') return 0;
n=s.size(),buildSA(),ST::init();
int mx=0,lans=0;
for(int l=1;l<=n;l++) for(int j=1;j+l<=n;j+=l)
{
int k=ST::qr(rk[j],rk[j+l]),res=k/l+1,pos=j-(l-(k%l));
if(pos>0&&k%l&&ST::qr(rk[pos],rk[pos+l])) res++;
if(res>mx) mx=res,lans=0;
if(res==mx) ans[++lans]=l;
}
int mn=0,mp=0;
for(int i=1;i<=n&&!mn;i++) for(int j=1;j<=lans;j++)
if(ST::qr(i,rk[sa[i]+ans[j]])>=(mx-1)*ans[j])
{mn=ans[j],mp=sa[i];break;}
std::cout<<"Case "<<sbSA<<": ";
for(int i=mp;i<mp+mn*mx;i++) std::cout<<s[i-1];
std::cout<<'\n';
}
}
D. 后缀求和
直接求显然是不行的。
将所给的后缀按 \(rk\) 排序,求出相邻元素两两 \(LCP\),考虑每个 \(LCP\) 对答案的贡献,用单调栈可求出每个 \(LCP\) 贡献区间的左右端点,答案即为每个 \(LCP\) 的长度与其贡献区间长度的乘积。
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
#define N 1000005
#define int long long
int n,q,sa[N<<1],rk[N<<1],lrk[N],id[N],buc[N],hi[N],lcp[N<<3],L[N<<3],R[N<<3],st[N<<3];
std::string s;
std::vector<int> mp;
namespace ST
{
int f[N][20],l2[N];
void init()
{
for(int i=2;i<=n;i++) l2[i]=l2[i>>1]+1;
for(int i=1;i<=n;i++) f[i][0]=hi[i];
for(int i=1;(1<<i)<=n;i++) for(int j=1;j+(1<<i)-1<=n;j++)
f[j][i]=std::min(f[j][i-1],f[j+(1<<i-1)][i-1]);
}
int qr(int l,int r) {if(l>r) std::swap(l,r);int d=l2[r-l+1];return std::min(f[l][d],f[r-(1<<d)+1][d]);}
};
void buildSA()
{
int m=256,p;
for(int i=1;i<=n;i++) buc[rk[i]=s[i-1]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[i]]--]=i;
for(int w=1;;w<<=1,m=p)
{
int cur=0;
for(int i=n-w+1;i<=n;i++) id[++cur]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) id[++cur]=sa[i]-w;
memset(buc,0,sizeof buc);
for(int i=1;i<=n;i++) buc[rk[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[id[i]]]--]=id[i];
memcpy(lrk,rk,sizeof lrk),p=0;
for(int i=1;i<=n;i++)
{
if(lrk[sa[i]]==lrk[sa[i-1]]&&lrk[sa[i]+w]==lrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p>=n) break;
}
for(int i=1,h=0;i<=n;i++)
{
if(!rk[i]) continue;
if(h) h--;
while(s[i+h-1]==s[sa[rk[i]-1]+h-1]) h++;
hi[rk[i]]=h;
}
}
signed main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
std::cin>>n>>q>>s;
buildSA();
ST::init();
for(int i=1,t;i<=q;i++)
{
std::cin>>t;
mp.clear();
for(int j=1,x;j<=t;j++) std::cin>>x,mp.push_back(rk[x]);
if(t==1) {std::cout<<"0\n";continue;}
std::sort(mp.begin(),mp.end());
mp.erase(std::unique(mp.begin(),mp.end()),mp.end());
int len=0,tp=0,ans=0;
for(int i=1;i<mp.size();i++) lcp[++len]=ST::qr(mp[i-1]+1,mp[i]);
for(int j=1;j<=len;j++)
{
L[j]=R[j]=j;
while(tp&&lcp[st[tp]]>lcp[j])
{
L[j]=L[st[tp]],R[st[tp-1]]=R[st[tp]];
tp--;
}
st[++tp]=j;
}
while(tp) R[st[tp-1]]=R[st[tp]],tp--;
for(int j=1;j<=len;j++) ans+=(j-L[j]+1)*(R[j]-j+1)*lcp[j];
std::cout<<ans<<'\n';
}
}
E. 可重叠子串
维护 \(height\) 数组上长度为 \(k\) 的滑动窗口最小值的最大值即可。
#include <iostream>
#include <cstring>
#define N 1000005
int n,m,K,a[N],sa[N<<1],rk[N<<1],id[N],lrk[N],buc[N],hi[N],q[N];
void buildSA()
{
for(int i=1;i<=n;i++) buc[rk[i]=a[i]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[i]]--]=i;
int p=0;
for(int w=1;;w<<=1,m=p)
{
int cur=0;
for(int i=n-w+1;i<=n;i++) id[++cur]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) id[++cur]=sa[i]-w;
memset(buc,0,sizeof buc);
for(int i=1;i<=n;i++) buc[rk[id[i]]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[id[i]]]--]=id[i];
memcpy(lrk,rk,sizeof lrk),p=0;
for(int i=1;i<=n;i++)
{
if(lrk[sa[i]]==lrk[sa[i-1]]&&lrk[sa[i]+w]==lrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p>=n) break;
}
for(int i=1,h=0;i<=n;i++)
{
if(!rk[i]) continue;
if(h) h--;
while(a[i+h]==a[sa[rk[i]-1]+h]) h++;
hi[rk[i]]=h;
}
}
int main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
std::cin>>n>>K;
for(int i=1;i<=n;i++) std::cin>>a[i],m=std::max(m,a[i]);
buildSA();
int hd=1,tl=0,ans=0;
K--;
for(int i=1;i<=n;i++)
{
while(hd<=tl&&q[hd]<=i-K) hd++;
while(hd<=tl&&hi[q[tl]]>=hi[i]) tl--;
q[++tl]=i;
if(i>=K) ans=std::max(ans,hi[q[hd]]);
}
std::cout<<ans;
}
## F. 两串求交
把两串拼起来,求出那些与排前一个的后缀不来自同一个字符串的后缀的 $height$ 的最大值。
```cpp
#include <iostream>
#include <cstring>
#define N 200005
std::string s1,s2,S;
int n,m,p,sa[N<<1],rk[N<<1],id[N],buc[N],lrk[N],hi[N],str[N];
void buildSA()
{
n=S.size(),m=128;
for(int i=1;i<=n;i++) buc[rk[i]=S[i-1]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[i]]--]=i;
for(int w=1;;w<<=1,m=p)
{
int cur=0;
for(int i=n-w+1;i<=n;i++) id[++cur]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) id[++cur]=sa[i]-w;
memset(buc,0,sizeof buc);
for(int i=1;i<=n;i++) buc[rk[id[i]]]++;
for(int i=1;i<=m;i++) buc[i]+=buc[i-1];
for(int i=n;i>=1;i--) sa[buc[rk[id[i]]]--]=id[i];
memcpy(lrk,rk,sizeof lrk),p=0;
for(int i=1;i<=n;i++)
{
if(lrk[sa[i]]==lrk[sa[i-1]]&&lrk[sa[i]+w]==lrk[sa[i-1]+w]) rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p>=n) break;
}
for(int i=1,h=0;i<=n;i++)
{
if(!rk[i]) continue;
if(h) h--;
while(S[i+h-1]==S[sa[rk[i]-1]+h-1]) h++;
hi[rk[i]]=h;
}
}
int main()
{
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
std::cin>>s1>>s2;
for(int i=1;i<=s1.size();i++) str[i]=1;
for(int i=s1.size()+2;i<=s1.size()+s2.size()+1;i++) str[i]=2;
S=s1+'#'+s2,buildSA();
int ans=0;
for(int i=2;i<=n;i++) if(str[sa[i]]+str[sa[i-1]]==3) ans=std::max(ans,hi[i]);
std::cout<<ans;
}
G. 公共串计数
未完待续……
H. 区间颠倒
未完待续……
I. 串的存在
未完待续……

浙公网安备 33010602011771号