ccz181078

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: :: 管理 ::
树包含N个点和N-1条边。树的边有2中颜色红色('r')和黑色('b')。给出这N-1条边的颜色,求有多少节点的三元组(a,b,c)满足:节点a到节点b、节点b到节点c、节点c到节点a的路径上,每条路径都至少有一条边是红色的。
注意(a,b,c), (b,a,c)以及所有其他排列被认为是相同的三元组。输出结果对1000000007取余的结果。
 
 
Input
第1行:1个数N(1 <= N <= 50000)
第2 - N行:每行2个数加一个颜色,表示边的起始点和结束的以及颜色。
Output
输出1个数,对应符合条件的3元组的数量。

预处理对每条边(a,b),从a出发经过b的路径有多少条是经过/不经过红边的

存在两种情况:

1.a,b,c两两间路径不经过第三点,这时三条路径有唯一公共点,公共点到a,b,c的路径至少有两条有红边,由此可以统计

2.a,b,c中有两点的路径经过第三点,这时枚举被经过的点进行统计

#include<cstdio>
typedef long long i64;
const int N=50007,R=N*40;
char buf[R+7],*ptr=buf-1;
int _(){
    int x=0,c=*++ptr;
    while(c<48)c=*++ptr;
    while(c>47)x=x*10+c-48,c=*++ptr;
    return x;
}
int _c(){
    int c=*++ptr;
    while(c<'a')c=*++ptr;
    return c=='r';
}
int n;
int es[N*2],enx[N*2],e0[N],ev[N*2],ep=2;
int f1[N],f2[N],f3[N],sz[N];
i64 ans=0;
void dfs1(int w,int pa){
    sz[w]=1;
    for(int i=e0[w];i;i=enx[i]){
        int u=es[i];
        if(u!=pa){
            dfs1(u,w);
            sz[w]+=sz[u];
            if(ev[i])f1[w]+=f3[u]=sz[u];
            else f1[w]+=f3[u]=f1[u];
        }
    }
}
void dfs2(int w,int pa){
    for(int i=e0[w];i;i=enx[i]){
        int u=es[i];
        if(u!=pa){
            if(ev[i])f2[u]=n-sz[u];
            else f2[u]=f2[w]+f1[w]-f1[u];
            dfs2(u,w);
        }
    }
}
void dfs3(int w,int pa){
    static i64 fs[N],fl[N],fr[N],ss[N],sl[N],sr[N];
    static int p;
    for(int i=e0[w];i;i=enx[i]){
        int u=es[i];
        if(u!=pa)dfs3(u,w);
    }
    p=0;
    for(int i=e0[w];i;i=enx[i]){
        int u=es[i];
        ++p;
        if(u!=pa){
            fs[p]=fl[p]=fr[p]=f3[u];
            ss[p]=sl[p]=sr[p]=sz[u]-fs[p];
        }else{
            fs[p]=fl[p]=fr[p]=f2[w];
            ss[p]=sl[p]=sr[p]=n-sz[w]-fs[p];
        }
    }
    i64 a0=ans;
    for(int i=2;i<=p;++i)fl[i]+=fl[i-1],sl[i]+=sl[i-1];
    for(int i=p-1;i;--i)fr[i]+=fr[i+1],sr[i]+=sr[i+1];
    for(int i=2;i<p;++i){
        ans+=fl[i-1]*(fs[i]*sr[i+1]+ss[i]*fr[i+1]);
        ans+=(sl[i-1]+fl[i-1])*fs[i]*fr[i+1];
    }
    for(int i=1;i<p;++i)ans+=fl[i]*fs[i+1];
}
int main(){
    fread(buf,1,R,stdin);
    n=_();
    for(int i=1;i<n;++i){
        int a=_(),b=_(),c=_c();
        es[ep]=b;enx[ep]=e0[a];ev[ep]=c;e0[a]=ep++;
        es[ep]=a;enx[ep]=e0[b];ev[ep]=c;e0[b]=ep++;
    }
    dfs1(1,0);
    dfs2(1,0);
    dfs3(1,0);
    printf("%lld",ans%1000000007);
    return 0;
}

 

posted on 2016-09-09 21:54  nul  阅读(217)  评论(0编辑  收藏  举报