BZOJ 3238: [Ahoi2013]差异((单调栈+后缀数组)/(后缀树))

[传送门[(https://www.lydsy.com/JudgeOnline/problem.php?id=3238)

解题思路

  首先原式可以把\(len\)那部分直接算出来,然后通过后缀数组求\(lcp\)。算\(\sum lcp\)的时候,刚开始傻了想要直接算贡献,结果越写越乱,后来想想只需要用单调栈把每个点的控制范围算出来即可,正着做一遍反着做一遍。注意还要考虑两个\(h[i]\)相邻并相等时的影响。还有一种比较自然的解法是后缀树,\(lcp\)其实就为两个点的\(lca\)的深度,所以建出后缀树后直接按拓扑序\(dp\)一下即可。

代码

后缀数组:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>

using namespace std;
const int MAXN = 500005;
typedef long long LL; 

int n,m,height[MAXN],num,stk[MAXN],top,l[MAXN],r[MAXN];
int sa[MAXN],rk[MAXN],x[MAXN<<1],y[MAXN<<1],c[MAXN];
char s[MAXN];
LL ans;

inline void get_SA(){
    for(int i=1;i<=n;i++) x[i]=s[i],c[x[i]]++;
    for(int i=2;i<=m;i++) c[i]+=c[i-1];
    for(int i=n;i;i--) sa[c[x[i]]--]=i;
    for(int k=1;k<=n;k<<=1){num=0;
        for(int i=n-k+1;i<=n;i++) y[++num]=i;
        for(int i=1;i<=n;i++) if(sa[i]>k) y[++num]=sa[i]-k;
        memset(c,0,sizeof(c));
        for(int i=1;i<=n;i++) c[x[i]]++;
        for(int i=2;i<=m;i++) c[i]+=c[i-1];
        for(int i=n;i;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0;
        swap(x,y);num=1;x[sa[1]]=1;
        for(int i=2;i<=n;i++)  
            x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k])?num:++num;
        if(num==n) break;m=num;
    }
}

inline void get_height(){
    for(int i=1;i<=n;i++) rk[sa[i]]=i;int j,k=0;
    for(int i=1;i<=n;i++){
        if(rk[i]==1) continue;
        if(k) k--;j=sa[rk[i]-1];
        while(i+k<=n && j+k<=n && s[i+k]==s[j+k]) k++;
        height[rk[i]]=k;
    }
}

void solve(){
	for(int i=1;i<=n;i++){
		while(top && height[i]<=height[stk[top]]) l[stk[top]]=i-1,top--;
		if(height[i]) stk[++top]=i;
	}
	while(top) l[stk[top--]]=n;
	for(int i=n;i;i--){
		while(top && height[i]<height[stk[top]]) r[stk[top]]=i+1,top--;
		if(height[i]) stk[++top]=i;
	}	
	while(top) r[stk[top--]]=1;
	for(int i=1;i<=n;i++) ans-=(LL)height[i]*(l[i]-i+1)*(i-r[i]+1)*2;
}

int main(){
    scanf("%s",s+1);n=strlen(s+1);m='z';
    get_SA();get_height();
//    for(int i=1;i<=n;i++) cout<<sa[i]<<" ";cout<<endl;
//    for(int i=1;i<=n;i++) cout<<height[i]<<" ";cout<<endl;
    ans=(LL)n*(n-1)/2*(n+1);
    solve();printf("%lld\n",ans);
    return 0;
}

后缀树:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#define int long long

using namespace std;
const int MAXN = 500005<<1;
typedef long long LL;

char s[MAXN];
int n,siz[MAXN],lst,cnt;
int fa[MAXN],ch[MAXN][27],l[MAXN],a[MAXN],c[MAXN];
LL ans;

inline void Insert(int c){
    int p=lst,np=++cnt;lst=np;l[np]=l[p]+1;
    for(;p && !ch[p][c];p=fa[p]) ch[p][c]=np;
    if(!p) fa[np]=1;
    else{
        int q=ch[p][c];
        if(l[q]==l[p]+1) fa[np]=q;
        else {
            int nq=++cnt;l[nq]=l[p]+1;
            memcpy(ch[nq],ch[q],sizeof(ch[q]));
            fa[nq]=fa[q];fa[q]=fa[np]=nq;
            for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
        }
    }
    siz[np]=1;
}

signed main(){
    scanf("%s",s+1);n=strlen(s+1);lst=cnt=1;
    for(int i=n;i;i--) Insert(s[i]-'a'+1);
    ans=(LL)n*(n-1)/2*(n+1);
    for(int i=1;i<=cnt;i++) c[l[i]]++;
    for(int i=1;i<=cnt;i++) c[i]+=c[i-1];
    for(int i=1;i<=cnt;i++) a[c[l[i]]--]=i;
    for(int i=cnt;i;i--){
        ans-=(LL)siz[a[i]]*siz[fa[a[i]]]*l[fa[a[i]]]*2;
        siz[fa[a[i]]]+=siz[a[i]];
    }
    printf("%lld\n",ans);
    return 0;
}

posted @ 2018-12-12 14:24  Monster_Qi  阅读(136)  评论(0编辑  收藏  举报