回滚树形dp(按dfs序dp)——hdu6035

本题前面的操作别的博客里都有。难点在于颜色ci的贡献,如何一次dfs求出答案

先来考虑如何在一次dfs中单独对颜色i进行计算

  用遍历dfs序的方式,在深搜过程中,碰到带有颜色 i 的点 u,u每个颜色不为i的子节点v都会贡献一个联通块,

  v的贡献的联通块大小是size[v]-sum{v中层次最高的以颜色i的结点为根的子树大小}

  那我们要先求出v中层次最高的以颜色i的结点为根的子树大小,所以用sum表示目前为止颜色i的所有子树的大小,用last存下v进入dfs前的sum,即不算v下颜色i的子树时的sum

  到v中去dfs,然后把这些子树的大小加到sum中

  当遍历完v的所有子树后,我们会发现 sum-last 就是v下层次最高的以颜色i的结点为根的子树大小的总和

  那么v贡献出的联通块大小就是 k = size[v]-(sum-last)

  然后去u的下一棵子树v'进行同样的操作,直到遍历完u的所有子树,此时sum已经变成了到u为止(不包含u)的以颜色i为根的子树大小之和

  为了计算u的颜色为i的祖先,需要把u也并入sum里,那么只需要在sum里加入u自己,再加上所有v贡献的联通块即可

  现在已经维护完所有的信息,可以推出u的dfs,往上回滚求出u的祖先结点的信息

再来考虑如何在一次dfs中对所有出现的颜色进行计算,

  我们可以在上面的递归中发现,每中颜色在求贡献只用到了size[],还有每种颜色对应的sum,那么用sum[c]数组来维护颜色c代表的sum即可,就可以在一次dfs中维护多种颜色的贡献

本题和虚树有些类似的地方,首先把每种颜色当成是一个询问,就类似虚树的询问了

然后是回滚dfs的过程,自叶子往上求(其实是按照dfs序)的方式:在初次碰到u时记录进入dfs前的状态,然后dfs处理完其所有子节点的状态后再来计算u的状态

#include<bits/stdc++.h>
#include<vector>
using namespace std;
#define maxn 200005 
#define ll long long

ll ans,color[maxn],size[maxn],sum[maxn];
vector<int>G[maxn];

void dfs1(int u,int pre){
    size[u]=1;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v==pre)continue;
        dfs1(v,u);
        size[u]+=size[v];
    }
}

//这个树形dp最重要的是理解sum[]数组的含义,sum[x]的更新像虚树的加边一样是自叶子节点往上回滚的 
void dfs2(int u,int pre){
    ll other=0;//other表示为size[u]减去u下所有最高的以color[u]为根的大小 
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v==pre)continue;
        ll last=sum[color[u]];//记录前前面子树里颜色u的子树(虚树)里的值
        dfs2(v,u);
        ll diff=sum[color[u]]-last;//v的子树里颜色为color[u]的个数
        //v树下不包含color[u]的联通块的大小 
        ans+=(size[v]-diff-1)*(size[v]-diff)/2;
        other+=size[v]-diff; 
    }
    sum[color[u]]+=other+1;//+1是因为u本身也是color[u] 
}
int f[maxn],tot;
int main(){
    ll n,t=0;
    while(cin>>n){
        ++t;
        for(int i=1;i<=n;i++)G[i].clear();
        memset(f,0,sizeof f);
        tot=0;
        
        for(int i=1;i<=n;i++){
            scanf("%d",&color[i]);
            f[color[i]]=1;
        }
        for(int i=1;i<=n;i++)tot+=f[i];
        for(int i=1;i<n;i++){
            int u,v;
            scanf("%d%d",&u,&v);
            G[u].push_back(v);
            G[v].push_back(u);
        }        
        
        if(tot==1){
            printf("Case #%d: %lld\n",t,n*(n-1)/2);
            continue;
        }
        
        memset(size,0,sizeof size);
        memset(sum,0,sizeof sum);
        
        ans=0;
        dfs1(1,1);dfs2(1,1);
        for(int i=1;i<=n;i++)
            if(f[i])
                ans+=(n-sum[i])*(n-sum[i]-1)/2;
        ll tmp=(n-1)*n/2*tot;
        printf("Case #%d: %lld\n",t,tmp-ans);
        
    }
}

 

posted on 2019-07-17 13:44  zsben  阅读(340)  评论(0编辑  收藏  举报

导航