【BZOJ4566】[Haoi2016]找相同字符 后缀数组+单调栈

【BZOJ4566】[Haoi2016]找相同字符

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。

Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

Output

输出一个整数表示答案

Sample Input

aabb
bbaa

Sample Output

10

题解:本题跟差异那道题很相似,理论上可以直接一遍sa搞定,但是我比较懒,直接求了3遍sa。

子串相同的方案数=后缀的相同前缀长度总和,两个串的相同后缀长度总和=两个串连一起的总和-两个串内部的长度和

具体做法请见差异

 

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int maxn=400010;
int n,m,len1,len2;
int r[maxn],ra[maxn],rb[maxn],st[maxn],sa[maxn],h[maxn],rank[maxn];
int q[maxn],t,ls[maxn],rs[maxn];
long long ans,sum;
char s1[maxn],s2[maxn];
void work()
{
    int i,j,k,*x=ra,*y=rb,p;
    for(i=0;i<m;i++) st[i]=0;
    for(i=0;i<n;i++) st[x[i]=r[i]]++;
    for(i=1;i<m;i++) st[i]+=st[i-1];
    for(i=n-1;i>=0;i--)  sa[--st[x[i]]]=i;
    for(p=j=1;p<n;j<<=1,m=p)
    {
        for(p=0,i=n-j;i<n;i++)   y[p++]=i;
        for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
        for(i=0;i<m;i++) st[i]=0;
        for(i=0;i<n;i++) st[x[y[i]]]++;
        for(i=1;i<m;i++) st[i]+=st[i-1];
        for(i=n-1;i>=0;i--)  sa[--st[x[y[i]]]]=y[i];
        for(swap(x,y),x[sa[0]]=0,i=p=1;i<n;i++)
            x[sa[i]]=(y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+j]==y[sa[i]+j])?p-1:p++;
    }
    for(i=1;i<n;i++) rank[sa[i]]=i;
    for(i=k=0;i<n-1;h[rank[i++]]=k)
        for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
    sum=0,t=0,h[0]=h[n]=-1;
    for(i=1;i<=n;i++)
    {
        while(t&&h[q[t]]>=h[i])  rs[q[t--]]=i;
        q[++t]=i;
    }
    t=0;
    for(i=n-1;i>=0;i--)
    {
        while(t&&h[q[t]]>h[i])   ls[q[t--]]=i;
        q[++t]=i;
    }
    for(i=1;i<n;i++) sum+=(long long)(i-ls[i])*(rs[i]-i)*h[i];
}
int main()
{
    scanf("%s%s",s1,s2);
    len1=strlen(s1),len2=strlen(s2);
    int i;
    for(i=0;i<len1;i++)  r[i]=s1[i]-'a'+2;
    r[len1]=0,n=len1+1,m=27;
    work(),ans-=sum;
    for(i=0;i<len2;i++)  r[i+len1+1]=s2[i]-'a'+2;
    r[len1+len2+1]=1;
    n=len1+len2+2,m=28;
    work(),ans+=sum;
    for(i=0;i<len2;i++)  r[i]=s2[i]-'a'+2;
    r[len2]=0,n=len2+1,m=27;
    work(),ans-=sum;
    printf("%lld",ans);
    return 0;
}
posted @ 2017-05-09 10:49  CQzhangyu  阅读(356)  评论(0编辑  收藏  举报