HDU 5977 Garden of Eden

都不好意思写题解了，
跑了 4000 多 ms。
纪念一下自己 AC 的第二道题。
（我还有一道 Free tour II，WA 了 20 多发还没 AC……呜呜呜）

#include<bits/stdc++.h>
using namespace std;
// Shared integer alias and problem limits.
typedef long long ll;
const int N = 50005;              // node capacity (presumably n <= 5e4, plus slack)
const int INF = 0x3f3f3f3f;       // large sentinel used to seed minimum searches
const double pi = acos(-1.0);     // not referenced in this file; template leftover

// Template helper macros (not referenced in this file).
#define sz(X) ((int)X.size())
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
// Presumably renames any `index` identifier to avoid clashing with POSIX ::index().
#define index Index

// Problem input: n nodes and k types; K = (1<<k)-1 is the mask holding every type.
int n, k, K;
ll ans;          // answer accumulated by work() and printed by main()
int ty[N];       // ty[v] is the single bit for v's type (set in main)

// Adjacency list stored as linked edge records.
struct Node {
    int to;      // endpoint of this directed edge
    int nx;      // next edge leaving the same node, or -1 at the end of the list
};
Node E[N << 1];
int head[N], tot, vis[N];

// Pushes directed edge u -> v onto the front of u's edge list.
void add(int u, int v) {
    Node &edge = E[tot];
    edge.to = v;
    edge.nx = head[u];
    head[u] = tot;
    ++tot;
}
/*************** Centroid ("weight root") search ************/
// all      : node count of the component being searched (set by getRoot)
// num      : smallest "largest remaining part" seen so far
// center   : node that achieved `num`
// pp[v]    : size of the largest part left over if v were removed
// nodes[v] : size of v's subtree in the current DFS
int all, num, center;
int pp[N], nodes[N];
// DFS over the unvisited component from x (coming from `pre`). Fills nodes[]
// and pp[], and records in `center` the node with the smallest pp[] so far.
void findRoot(int x, int pre) {
    int heaviest = 0;   // largest child subtree of x
    nodes[x] = 1;
    for (int e = head[x]; e != -1; e = E[e].nx) {
        int child = E[e].to;
        if (child == pre || vis[child]) continue;
        findRoot(child, x);
        nodes[x] += nodes[child];
        heaviest = max(heaviest, nodes[child]);
    }
    // The part "above" x is everything in the component outside x's subtree.
    pp[x] = max(heaviest, all - nodes[x]);
    if (pp[x] < num) {
        num = pp[x];
        center = x;
    }
}
int getRoot(int root,int sn) {
    num = INF; all = sn; center = root;
    findRoot(root, -1);
    return center;  
}
/****************treecdq**********/
ll has[1050];
ll dp[12][1050];
void getdp(int x, int pre, int num) {
    has[num] ++;
    for(int i = head[x]; ~i; i = E[i].nx) {
        int y = E[i].to; if(y == pre || vis[y]) continue;
        getdp(y,x, num|ty[y]);
    }
}
// Counts ordered pairs (a, b) of unvisited nodes reachable from x (a == b
// allowed). Each node's mask is the OR of types from x down to it, seeded with
// chu | ty[x]. A pair counts when the two masks together cover all types
// (OR == K).
ll Cal(int x, int chu) {
    memset(has, 0, sizeof(has));
    getdp(x, x, chu | ty[x]);

    // Superset-sum transform, one bit per layer.
    for (int mask = 0; mask <= K; ++mask) dp[0][mask] = has[mask];
    for (int layer = 1; layer <= k; ++layer) {
        const int bit = 1 << (layer - 1);
        for (int mask = 0; mask <= K; ++mask) {
            ll total = dp[layer - 1][mask];
            if ((mask & bit) == 0) total += dp[layer - 1][mask | bit];
            dp[layer][mask] = total;
        }
    }

    // b completes a exactly when b's mask contains every type a is missing.
    ll pairs = 0;
    for (int mask = 0; mask <= K; ++mask) pairs += has[mask] * dp[k][mask ^ K];
    return pairs;
}
// Treats x as a centroid. It adds to `ans` the pairs counted through x, then
// recurses into the centroid of each remaining component.
void work(int x) {
    vis[x] = 1;
    ans += Cal(x, 0);                 // every pair, with masks measured from x ...
    for (int e = head[x]; e != -1; e = E[e].nx) {
        int y = E[e].to;
        if (vis[y]) continue;
        // ... minus pairs lying entirely in y's branch, whose real path does not
        // pass through x. This must run before that branch is decomposed (marked vis).
        ans -= Cal(y, ty[x]);
        int sub = getRoot(y, nodes[y]);
        work(sub);
    }
}

// Reads test cases until EOF (or until a header cannot be parsed). For each
// case it builds the tree, runs the centroid decomposition, and prints `ans`.
int main(){
    // Require both header fields. `~scanf(...)` stops only on EOF; on a
    // malformed header scanf returns 0/1 without consuming input and the
    // loop would spin forever.
    while (scanf("%d %d", &n, &k) == 2) {
        K = (1 << k) - 1;                  // mask with all k type bits set
        memset(vis, 0, sizeof(vis));
        memset(head, -1, sizeof(head)); tot = 0;
        for (int i = 1; i <= n; ++i) {
            int a; scanf("%d", &a); a--;   // input types appear to be 1..k
            ty[i] = 1 << a;
        }
        for (int i = 1; i < n; ++i) {
            int a, b; scanf("%d %d", &a, &b);
            add(a, b); add(b, a);          // undirected edge
        }
        ans = 0;
        work(getRoot(1, n));
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2016-11-06 21:22  basasuya  阅读(104)  评论(0)  编辑  收藏  举报