ccz181078

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: :: 管理 ::

Description

给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m
的模式串s,其中每一位仍然是A到z的大写字母。Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径
形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分.
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,
重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以X
YXYXY不能看作是S重复若干次得到的。

Input

每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(
第i个字符对应了第i个结点).
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,
为模式串S。
1<=C<=10,3<=N<=10000003<=M<=1000000

Output

给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模
式串的若干次重复.

点分治统计答案,hash判定当前点到中心的路径是否是由模式串重复得到的串的前/后缀,为保证复杂度,递归至子树大小不足m时结束,时间复杂度O(Tnlog(n/m))

#include<bits/stdc++.h>
int _(){
    int x=0,c=getchar();
    while(c<48)c=getchar();
    while(c>47)x=x*10+c-48,c=getchar();
    return x;
}
typedef unsigned long long u64;
const int N=1e6+7,p=2939;
int T,n,m;
char s1[N],s2[N];
int es[N*2],enx[N*2],e0[N],ep=2,sz[N],SZ,CG;
u64 h1[N],h2[N],ans;
bool ed[N];
void f2(int w,int pa){
    sz[w]=1;
    for(int i=e0[w],u;i;i=enx[i]){
        u=es[i];
        if(u==pa||ed[u])continue;
        f2(u,w);
        sz[w]+=sz[u];
    }
}
int md,dw;
void f3(int w,int pa){
    if(++dw>md)md=dw;
    bool is=sz[w]*2>=SZ;
    for(int i=e0[w],u;i;i=enx[i]){
        u=es[i];
        if(u==pa||ed[u])continue;
        f3(u,w);
        if(sz[u]*2>SZ)is=0;
    }
    if(is)CG=w;
    --dw;
}
int mod[N];
int t1[N],t2[N],a1[N],a2[N];
void f4(int w,int pa,u64 h,int dep){
    h=h*p+s1[w];
    if(h==h1[dep])++a1[dep[mod]];
    if(h==h2[dep])++a2[dep[mod]];
    for(int i=e0[w],u;i;i=enx[i]){
        u=es[i];
        if(u==pa||ed[u])continue;
        f4(u,w,h,dep+1);
    }
}
void f1(int w){
    f2(w,0);
    SZ=sz[w];
    md=dw=0;
    f3(w,0);
    w=CG;
    ed[w]=1;
    if(md*2<m||SZ<m)return;
    memset(t1,0,sizeof(int)*(m+1));
    memset(t2,0,sizeof(int)*(m+1));
    for(int i=e0[w];i;i=enx[i]){
        int u=es[i];
        if(ed[u])continue;
        memset(a1,0,sizeof(int)*(m+1));
        memset(a2,0,sizeof(int)*(m+1));
        f4(u,w,s1[w],2);
        a1[m]=a1[0];
        a2[m]=a2[0];
        a1[m+1]=a1[1];
        a2[m+1]=a2[1];
        ans+=a1[0]+a2[0];
        for(int x=0;x<m;++x){
            ans+=u64(a1[m+1-x])*t2[x];
            ans+=u64(a2[m+1-x])*t1[x];
        }
        for(int x=0;x<m;++x)t1[x]+=a1[x],t2[x]+=a2[x];
    }
    for(int i=e0[w];i;i=enx[i]){
        int u=es[i];
        if(!ed[u])f1(u);
    }
}
int main(){
    for(T=_();T;--T){
        ans=0;
        n=_();m=_();
        memset(e0,0,sizeof(int)*(n+1));
        memset(ed,0,n+1);
        scanf("%s",s1+1);
        for(int i=1,a,b;i<n;++i){
            a=_();b=_();
            es[ep]=b;enx[ep]=e0[a];e0[a]=ep++;
            es[ep]=a;enx[ep]=e0[b];e0[b]=ep++;
        }
        scanf("%s",s2+1);
        u64 pp=1;
        for(int i=1,j=1,k=m;i<=n;++i,pp*=p){
            mod[i]=i%m;
            h1[i]=s2[j]*pp+h1[i-1];
            h2[i]=s2[k]*pp+h2[i-1];
            if(m<++j)j=1;
            if(!--k)k=m;
        }
        if(m>1)f1(1);
        else for(int i=1;i<=n;++i)if(s1[i]==s2[1])++ans;
        printf("%llu\n",ans);
    }
    return 0;
}

 

posted on 2017-01-18 13:52  nul  阅读(682)  评论(0编辑  收藏  举报