[CTSC2010]珠宝商

https://www.luogu.org/problemnew/show/P4218

题解

树上路径问题可以想到用点分治。

考虑固定一分治点,然后算所有跨过这个点的所有路径的答案和。

假定我们分治重心为\(x\),那么一条合法的路径可以表示为\(u->x\ x->v\)

然后我们还要考虑两条路径拼接的情况。

那么考虑对于模板串\(m\)的一个分割点,\(u->x\)的路径一定是\(m\)的这个分割点前的串的后缀,\(x->v\)一定是后面的串的前缀。

后缀这个东西相当于是\(parent\)树上的子树,前缀的话得对反串建\(parent\)树。

所以我们每次的操作就是在当前串的前面加一个字母。

这个怎么做?

考虑这个操作其实就是在\(parent\)树上走,所以把\(parent\)树处理一下就好了。

然后在\(parent\)树上打完标记之后,最后要把标记全部下放到代表模板串前缀的点,然后算贡献。

但是这样每次后续都需要\(O(m)\)的处理,考虑到如果对于一个小的子树都要\(O(m)\)的做非常浪费,所以我们就设一个阈值,当块小的时候暴力做。

然后一朵菊花就卡掉了。。

所以还需要在处理儿子节点时候分大小讨论一下。。

懒得写。。

代码

#include<bits/stdc++.h>
#define N 100009
#define P pair<int,int>
#define mm make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
ll ans;
int head[N],tot,tong1[N],tong2[N],size[N],dp[N],root,sum,n,m,n1,to[N];
char s[N],s1[N];
bool vis[N];
inline ll rd(){
    ll x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
struct edge{int n,to;}e[N<<1];
inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;}
struct SAM{
    int ch[N][26],fa[N],l[N],cnt,last,id[N],rig[N],ct[N],tong[N],pr[N][26],siz[N],rk[N];
    SAM(){cnt=last=1;}
    inline void ins(int x,int now){
        int p=last,np=++cnt;last=np;l[np]=l[p]+1;//cout<<np<<"   WA "; 
        id[np]=now;rig[np]=now;
        for(;!ch[p][x];p=fa[p])ch[p][x]=np;siz[np]=1;
        if(!p)fa[np]=1;
        else{
            int q=ch[p][x];
            if(l[p]+1==l[q])fa[np]=q;
            else{
              int nq=++cnt;l[nq]=l[p]+1;
              memcpy(ch[nq],ch[q],sizeof(ch[nq]));
              fa[nq]=fa[q];fa[q]=fa[np]=nq;
              for(;ch[p][x]==q;p=fa[p])ch[p][x]=nq;
            }
        }
    }
    inline void build(int tag){
        for(int i=0;i<=m;++i)to[i]=0;
        for(int i=1;i<=cnt;++i)to[l[i]]++;
        for(int i=1;i<=m;++i)to[i]+=to[i-1];
        for(int i=1;i<=cnt;++i)rk[to[l[i]]--]=i;
        for(int i=cnt;i>=1;--i){
            int x=rk[i];
            siz[fa[x]]+=siz[x];
            if(!rig[fa[x]])rig[fa[x]]=rig[x];
            if(!tag)pr[fa[x]][s1[rig[x]-l[fa[x]]]-'a']=x;
            else pr[fa[x]][s1[rig[x]+l[fa[x]]]-'a']=x;
        }
    }
    inline void push(){
        for(int i=1;i<=cnt;++i){
            int x=rk[i];
            ct[x]+=ct[fa[x]];
            if(id[x])tong[id[x]]+=ct[x];
        }
        for(int i=0;i<=cnt;++i)ct[i]=0;
    }
    inline P add(P now,int x,int tag){
        int xx=now.fi,yy=now.se;
        if(yy==l[xx]){
           yy++;xx=pr[xx][x];
           return mm(xx,yy);	
        }
        if(!tag){
          if(s1[rig[xx]-yy]-'a'!=x)return mm(0,yy);
          return mm(xx,yy+1);
        }
        if(tag){
          if(s1[rig[xx]+yy]-'a'!=x)return mm(0,yy);
          return mm(xx,yy+1);
        }
    }
}T1,T2;
void getroot(int u,int fa){
    dp[u]=0;size[u]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        getroot(v,u);
        size[u]+=size[v];
        dp[u]=max(dp[u],size[v]);
    }
    dp[u]=max(dp[u],sum-size[u]);
    if(dp[u]<dp[root])root=u;
}
void getsize(int u,int fa){
    size[u]=1;
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        getsize(v,u);
        size[u]+=size[v];
    }
}
void dfs1(int u,int now,int fa,int top,int tag){
    now=T1.ch[now][s[u]-'a'];
    if(u==top)tag=1;
    if(!now)return;
    if(tag)ans+=T1.siz[now];
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        dfs1(v,now,u,top,tag);
    }
}
inline void work1(int u,int fa,int top){
    dfs1(u,1,0,top,0);
    for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){
        int v=e[i].to;
        work1(v,u,top);
    }
}
void getdeep(P now1,P now2,int u,int fa){
    if(now1.fi)now1=T1.add(now1,s[u]-'a',0);
    if(now2.fi)now2=T2.add(now2,s[u]-'a',1);
    if(now1.fi)T1.ct[now1.fi]++;
    if(now2.fi)T2.ct[now2.fi]++;
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]&&e[i].to!=fa){
        int v=e[i].to;
        getdeep(now1,now2,v,u);
    }
}
inline void work2(int u){
    P now1=T1.add(mm(1,0),s[u]-'a',0),now2=T2.add(mm(1,0),s[u]-'a',1);
    if(now1.fi)T1.ct[now1.fi]++;
    else return;
    if(now2.fi)T2.ct[now2.fi]++;
    T1.push();T2.push();
    for(int i=1;i<=m;++i)if(T1.tong[i]){
        ans++;tong1[i]++;tong2[i]++;
        T1.tong[i]--;T2.tong[i]--;
    }
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){
        int v=e[i].to;
        getdeep(now1,now2,v,u);
        T1.push();T2.push(); 
        for(int j=1;j<=m;++j){
            ans+=1ll*T1.tong[j]*tong2[j]+1ll*T2.tong[j]*tong1[j];
            tong1[j]+=T1.tong[j];
            tong2[j]+=T2.tong[j]; 
            T1.tong[j]=T2.tong[j]=0;
        }
    }
    for(int i=1;i<=m;++i)tong1[i]=tong2[i]=0;
}
inline void calc(int u){
    if(size[u]<=n1)work1(u,0,u);
    else work2(u);
    vis[u]=1;
}
void solve(int u){
    calc(u);
    for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){
        int v=e[i].to;
        root=n+1;sum=size[v];
        getroot(v,u);getsize(root,0);
        solve(v);
    }
}
int main(){
    n=rd();m=rd();n1=sqrt(n);
    int x,y;
    for(int i=1;i<n;++i){
        x=rd();y=rd();
        add(x,y);add(y,x);
    }
    scanf("%s",s+1);
    scanf("%s",s1+1);
    for(int i=1;i<=m;++i)T1.ins(s1[i]-'a',i);
    for(int i=m;i>=1;--i)T2.ins(s1[i]-'a',i);
    T1.build(0);T2.build(1);	
    root=n+1;dp[root]=n+1;sum=n;
    getroot(1,0);getsize(root,0);
    solve(root);
    cout<<ans;
    return 0;
}
posted @ 2019-05-07 10:52  comld  阅读(278)  评论(0编辑  收藏  举报