【bzoj4566】[Haoi2016]找相同字符【后缀自动机】

题目传送门
题解:在文本串上建后缀自动机,用模式串在后缀自动机上跑。扫一遍模式串,在后缀自动机上走,走不了就跳fail再走。
走的过程中,维护模式串与文本串匹配的最大长度,并且统计答案。
怎么统计答案呢?
我们知道,状态x表示的字符串的长度为len[x],对应了len[x]len[fail[x]]个子串(每个子串都是当前状态的前缀)。因此,状态x表示的子串总数为(len[x]len[fail[x]])cnt[x]。同时,状态x的fail一定是x表示字符串的后缀。因此,如果当前在后缀自动机上走到了状态x,最大匹配长度为maxl,就把答案加上x沿fail到根节点上所有状态表示的子串总数。注意x这个状态是特殊的,它与模式串做能匹配的子串个数是(maxllen[fail[x]])cnt[x]。其他跳fail到根节点路径上的状态则直接加上其表示的子串总数,因为这些状态是与模式串完全匹配的。(难以描述,自行理解一下吧)
具体实现详见代码。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=200005;
int l1,l2,a[N*2];
ll ans=0;
char s1[N],s2[N];
struct SAM{
    int last,tot,len[N*2],fail[N*2],cnt[N*2],ch[N*2][26],c[N*2];
    SAM(){
        last=tot=1;
    }
    void insert(int x){
        int p=last,np=++tot;
        len[np]=len[p]+1;
        last=np;
        for(;p&&!ch[p][x];p=fail[p]){
            ch[p][x]=np;
        }
        if(!p){
            fail[np]=1;
        }else{
            int q=ch[p][x];
            if(len[q]==len[p]+1){
                fail[np]=q;
            }else{
                int nq=++tot;
                len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                fail[nq]=fail[q];
                fail[q]=fail[np]=nq;
                for(;p&&ch[p][x]==q;p=fail[p]){
                    ch[p][x]=nq;
                }
            }
        }
        cnt[np]=1;
    }
    void init(){
        for(int i=1;i<=tot;i++){
            c[len[i]]++;
        }
        for(int i=1;i<=l1;i++){
            c[i]+=c[i-1];
        }
        for(int i=1;i<=tot;i++){
            a[c[len[i]]--]=i;
        }
        for(int i=tot;i>=1;i--){
            cnt[fail[a[i]]]+=cnt[a[i]];
        }
    }
    void get(int now,int l){
        while(len[fail[now]]>l){
            now=fail[now];
        }
        while(fail[now]){
            ans+=1LL*cnt[now]*(l-len[fail[now]]);
            l=len[fail[now]];
            now=fail[now];
        }
    }
}sam;
int main(){
    scanf("%s%s",s1,s2);
    l1=strlen(s1);
    l2=strlen(s2);
    for(int i=0;i<l1;i++){
        sam.insert(s1[i]-'a');
    }
    sam.init();
    int now=1,len=0;
    for(int i=0;i<l2;i++){
        while(now&&!sam.ch[now][s2[i]-'a']){
            now=sam.fail[now];
        }
        if(!now){
            now=1;
            len=0;
        }else{
            len=min(len,sam.len[now])+1;
            now=sam.ch[now][s2[i]-'a'];
        }
        sam.get(now,len);
    }
    printf("%lld\n",ans);
    return 0;
}
posted @ 2018-03-24 11:22  ez_2016gdgzoi471  阅读(95)  评论(0编辑  收藏  举报