P3346 [ZJOI2015]诸神眷顾的幻想乡

思路

注意到叶子节点(度数为1)只有20个,可以分别以这20个节点为根,把所有子串插入SAM中,统计最后的本质不同的子串个数

所以就是广义SAM了

然后注意要判断一下有无重复插入

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;
const int MAXN = 4000400;
int Nodecnt,trans[MAXN][10],maxlen[MAXN],suflink[MAXN],minlen[MAXN],cnt,v[MAXN<<1],fir[MAXN],nxt[MAXN<<1],d[MAXN],w_p[MAXN],root,n,c;
void addedge(int ui,int vi){
    ++cnt;
    v[cnt]=vi;
    d[vi]++;
    nxt[cnt]=fir[ui];
    fir[ui]=cnt;
}
int New_state(int _maxlen,int _minlen,int *_trans,int _suflink){
    ++Nodecnt;
    maxlen[Nodecnt]=_maxlen;
    minlen[Nodecnt]=_minlen;
    suflink[Nodecnt]=_suflink;
    if(_trans){
        for(int i=0;i<c;i++){
            trans[Nodecnt][i]=_trans[i];
        }
    }
    return Nodecnt;
}   
int add_len(int u,int c){
    if(trans[u][c]){
        int v=trans[u][c];
        if(maxlen[v]==maxlen[u]+1)
            return v;
        else{
            int y=New_state(maxlen[u]+1,0,trans[v],suflink[v]);
            suflink[v]=y;
            minlen[v]=maxlen[y]+1;
            while(u&&(trans[u][c]==v)){
                trans[u][c]=y;
                u=suflink[u];
            }
            minlen[y]=maxlen[suflink[y]]+1; 
            return y;
        }   
    }
    else{
        int z=New_state(maxlen[u]+1,0,NULL,0);
        while(u&&(trans[u][c]==0)){
            trans[u][c]=z;
            u=suflink[u];
        }
        if(!u){
            minlen[z]=1;
            suflink[z]=1;
            return z;
        }
        int v=trans[u][c];
        if(maxlen[v]==maxlen[u]+1){
            minlen[z]=maxlen[v]+1;
            suflink[z]=v;
            return z;
        }
        int y=New_state(maxlen[u]+1,0,trans[v],suflink[v]);
        suflink[v]=suflink[z]=y;
        minlen[v]=minlen[z]=maxlen[y]+1;
        while(u&&(trans[u][c]==v)){
            trans[u][c]=y;
            u=suflink[u];
        }
        minlen[y]=maxlen[suflink[y]]+1;
        return z;
    }
}
void dfs(int o,int fa,int last){
    last=add_len(last,w_p[o]);
    for(int i=fir[o];i;i=nxt[i]){
        if(v[i]==fa)
            continue;
        dfs(v[i],o,last);
    }
}
signed main(){
    scanf("%lld %lld",&n,&c);
    for(int i=1;i<=n;i++){
        scanf("%lld",&w_p[i]);
    }
    for(int i=1;i<n;i++){
        int a,b;
        scanf("%lld %lld",&a,&b);
        addedge(a,b);
        addedge(b,a);
    }
    Nodecnt=1;
    root=1;
    for(int i=1;i<=n;i++){
        if(d[i]==1){
            dfs(i,0,root);
        }
    }
    int ans=0;
    for(int i=2;i<=Nodecnt;i++){
        // printf("maxlen[%d]=%d minlen[%d]=%d\n",i,maxlen[i],i,minlen[i]);
        ans=ans+maxlen[i]-minlen[i]+1;
    }
    printf("%lld\n",ans);
    return 0;
}
posted @ 2019-04-01 15:06  dreagonm  阅读(122)  评论(0编辑  收藏  举报