loj 6401 字符串

    建出 SAM。 对于每个子串,处理出这个子串最长合法的后缀。

  

#include<bits/stdc++.h>
#define N 200007
using namespace std;
vector<int> v[N];
int last=1,tot=1,ch[N][26],mx[N],len[N],q,p,fa[N],nq;
char s[N],g[N];
void Sam(int x){
    q=++tot; len[q]=len[last]+1;
    for (;last&&!ch[last][x];last=fa[last]) ch[last][x]=q;
    if (!last) fa[q]=1; else{
        int p=ch[last][x];
        if (len[last]+1==len[p]) fa[q]=p;
        else {
            nq=++tot; len[nq]=len[last]+1;
            memcpy(ch[nq],ch[p],sizeof ch[p]);
            fa[nq]=fa[p]; fa[q]=fa[p]=nq;
            for (;last&&ch[last][x]==p;last=fa[last]) ch[last][x]=nq;
        }
    } last=q;
}
int Len,now,sum[N],l,k,ma[N];
void Lxf(int x){
    for (auto i:v[x]) 
        Lxf(i),mx[x]=max(mx[i],mx[x]);
}
long long nb;
signed main () {
    scanf("%s",s+1);
    scanf("%s",g+1);
    Len=strlen(s+1);
    for (int i=1;i<=Len;i++) {
        Sam(s[i]-'a'); 
        sum[i]=sum[i-1]+(g[i]=='0');
    } now=1;
    scanf("%d",&k); 
    for (int i=1;i<=Len;i++) {
        now=ch[now][s[i]-'a']; 
        while (sum[i]-sum[p]>k) p++;
        mx[now]=max(mx[now],i-p);
    }
    for (int i=1;i<=tot;i++) v[fa[i]].push_back(i);
    Lxf(1);
    for (int i=1;i<=tot;i++) 
      nb+=max(0,min(mx[i],len[i])-len[fa[i]]);
    printf("%lld\n",nb); 
}

 

posted @ 2018-09-04 21:00  泪寒之雪  阅读(286)  评论(0编辑  收藏  举报