luoguP2664 树上游戏

https://www.luogu.org/problemnew/show/P2664

对每种颜色,取出该颜色的所有点以及这些点的子节点建出虚树;去掉该颜色的点之后,树被分成若干个连通块,对每个连通块做一次 DP,再用树上差分统计贡献就行了

当然要考虑哪些东西要被加进去

对于一个不含当前颜色的连通块,块内每个点都要加上 n - 连通块大小 的贡献(画个图就能看出来)

然后输出的时候每个点的答案要 + n(因为从自己出发到任何一个点的路径肯定包含自己这种颜色)

博主差分的时候写挂了导致要 #define int long long,而且常数巨大

#include <bits/stdc++.h>
#define int long long
using namespace std;

// NOTE: `int` is macro-expanded to `long long` in this file.
// N:  capacity of the per-node arrays (nodes are 1-based; index n + 1 is used by main as a
//     virtual root above node 1).
// LG: highest binary-lifting level; pre[][] stores levels 0..LG.
const int N = 1e5 + 5, LG = 17;

// col[c]: nodes whose colour is c.
// G:      adjacency lists. Holds the input tree while init() runs; afterwards main overwrites
//         it, colour by colour, with the virtual tree (parent -> children lists).
// mdf[t]: coloured nodes collected by dfs1 for the component whose top node is t.
// G2:     adjacency lists of the input tree; never modified after reading.
vector <int> col[N], G[N], mdf[N], G2[N];
// pre[u][i]: 2^i-th ancestor of u (0 when it does not exist).   dep[u]: depth, root has depth 1.
// sta: stack used while building the virtual tree.              a[u]: colour of node u.
// s:   difference array on the input tree, prefix-summed from the root by dfs2.
// id[u]: DFS pre-order index.                                   siz[u]: subtree size.
// book[u]: 1 while u has the colour currently being processed (and permanently for n + 1).
// f[u]: value computed by dfs1 for uncoloured virtual-tree nodes (0 for coloured ones).
// h:   buffer of key nodes of the current colour, 1-based.
int pre[N][LG + 1], dep[N], sta[N], a[N], s[N], id[N], siz[N], book[N], f[N], h[N * 2];
// n: number of nodes.  len: top index of sta.  dfn: pre-order counter used by init.
// maxn: largest colour value read.  k: number of entries currently in h.
int n, len, dfn, maxn, k;

// Depth-first preprocessing of the input tree stored in G.
// For node u with parent fa it records the pre-order index, depth, subtree size and the
// binary-lifting ancestor table pre[u][0..LG].
void init(int u, int fa) {
    id[u] = ++dfn;
    siz[u] = 1;
    dep[u] = dep[fa] + 1;
    pre[u][0] = fa;
    for(int lvl = 1; lvl <= LG; lvl++) {
        int half = pre[u][lvl - 1];
        pre[u][lvl] = pre[half][lvl - 1];
    }
    for(size_t e = 0; e < G[u].size(); e++) {
        int v = G[u][e];
        if(v == fa) continue;
        init(v, u);
        siz[u] += siz[v];
    }
}

// Returns the k-th ancestor of x using the binary-lifting table pre.
// Bits of k are consumed from level LG down to level 0.
int jump(int x, int k) {
    int bit = LG;
    while(bit >= 0) {
        if((k & (1 << bit)) != 0) {
            x = pre[x][bit];
        }
        --bit;
    }
    return x;
}

// Lowest common ancestor of x and y in the tree rooted by init().
int LCA(int x, int y) {
    // Keep y as the deeper endpoint, then lift it to x's depth.
    if(dep[y] < dep[x]) swap(x, y);
    y = jump(y, dep[y] - dep[x]);
    if(x != y) {
        // Lift both while their ancestors differ; they end up just below the LCA.
        for(int bit = LG; bit >= 0; bit--) {
            if(pre[x][bit] == pre[y][bit]) continue;
            x = pre[x][bit];
            y = pre[y][bit];
        }
        x = pre[x][0];
    }
    return x;
}

// Orders nodes by DFS pre-order index (strictly ascending).
bool cmp(int x, int y) {
    return id[y] > id[x];
}

// Walks the virtual tree of the current colour (stored in G as parent -> children lists)
// and records that colour's contribution in the difference array s.
//
//   u   - current virtual-tree node.
//   top - top node of the uncoloured component that u belongs to; mdf[top] collects the
//         coloured nodes found directly below that component. Only meaningful while u is
//         uncoloured; coloured nodes simply pass it through.
//
// For an uncoloured u, f[u] becomes siz[u] minus everything cut off below it: the whole
// subtree of every coloured virtual child, and siz - f of every uncoloured virtual child.
// For a coloured u, every virtual child starts a new region whose top is `son`, the real
// child of u on the path towards that virtual child; (n - region size) is added at s[son]
// and subtracted again at the coloured nodes bounding the region from below, so that after
// dfs2's prefix sum only the region's own nodes keep the value.
void dfs1(int u, int top) {
    // Uncoloured leaf of the virtual tree: nothing below it is cut off.
    if(book[u] == 0 && G[u].size() == 0) {
        f[u] = siz[u]; return;
    }
    if(book[u] == 1) {
        f[u] = 0;
        for(vector <int> :: iterator it = G[u].begin(); it != G[u].end(); it++) {
            int child = *it;
            // Real-tree child of u that is an ancestor of (or equal to) `child`.
            int son = jump(child, dep[child] - dep[u] - 1);
            if(book[child] == 1) {
                dfs1(child, top);
                int sz = siz[son] - siz[child];
                s[son] += (n - sz); s[child] -= (n - sz);
            } else {
                // Fresh component rooted at `son`: reset its boundary list before descending.
                mdf[son].clear(); dfs1(child, son);
                int sz = siz[son] - siz[child];
                int allsz = sz + f[child];
                s[son] += (n - allsz);
                // Cancel the addition at every coloured node bounding the component.
                for(vector <int> :: iterator itt = mdf[son].begin(); itt != mdf[son].end(); itt++) s[*itt] -= (n - allsz);
            }
        }
    } else {
        f[u] = siz[u];
        for(vector <int> :: iterator it = G[u].begin(); it != G[u].end(); it++) {
            int child = *it;
            dfs1(child, top);
            if(book[child] == 1) {
                // Coloured child: its whole subtree leaves the component and it becomes
                // one of the component's lower boundary nodes.
                mdf[top].push_back(child);
                f[u] -= siz[child];
            } else {
                // Uncoloured child: only the part cut off below it leaves the component.
                f[u] -= (siz[child] - f[child]);
            }
        }
    }
}

void dfs2(int u, int fa) {
    s[u] += s[fa];
    for(vector <int> :: iterator it = G2[u].begin(); it != G2[u].end(); it++) {
        if(*it != fa) dfs2(*it, u);
    }
}

signed main() {
    cin >> n;
    for(int i = 1; i <= n; i++) {
        scanf("%lld", &a[i]);
        col[a[i]].push_back(i);
        maxn = max(maxn, a[i]);
    }
    for(int i = 1; i < n; i++) {
        int a, b;
        scanf("%lld %lld", &a, &b);
        G[a].push_back(b);
        G[b].push_back(a);
        G2[a].push_back(b);
        G2[b].push_back(a);
    }
    init(1, 0);
    G[n + 1].push_back(1); book[n + 1] = 1; 
    for(int i = 1; i <= maxn; i++) {
        k = col[i].size();
        for(int j = 0; j < k; j++) h[j + 1] = col[i][j], book[col[i][j]] = 1;
        int tmp = k; bool have = 0;
        for(int j = 1; j <= tmp; j++) {
        	int u = h[j];
        	if(u == 1) have = 1;
        	for(vector <int> :: iterator it = G2[u].begin(); it != G2[u].end(); it++) {
        		if(*it != pre[u][0]) h[++k] = *it;
            }
        }
        if(!have) {
            for(vector <int> :: iterator it = G2[1].begin(); it != G2[1].end(); it++) {
                h[++k] = *it;
            }
        }
        sort(h + 1, h + k + 1, cmp);
        k = unique(h + 1, h + k + 1) - h - 1;
        sort(h + 1, h + k + 1, cmp);
        sta[len = 1] = 1; G[1].clear();
        for(int j = 1; j <= k; j++) {
            if(h[j] == 1) continue;
            int lca = LCA(h[j], sta[len]);
            if(lca != sta[len]) {
                while(id[lca] < id[sta[len - 1]]) {
                    G[sta[len - 1]].push_back(sta[len]);
                    len--;
                }
                if(id[lca] > id[sta[len - 1]]) {
                    G[lca].clear();
                    G[lca].push_back(sta[len]);
                    sta[len] = lca;
                } else G[lca].push_back(sta[len]), len--;
            }
            G[h[j]].clear(); sta[++len] = h[j];
        }
        for(int j = 1; j < len; j++) G[sta[j]].push_back(sta[j + 1]);
        dfs1(n + 1, 0); for(int j = 0; j < tmp; j++) book[col[i][j]] = 0;
    }
    dfs2(1, 0);
    for(int i = 1; i <= n; i++) printf("%lld\n", s[i] + n);
    return 0;
}
posted @ 2018-09-12 17:39  LJC00118  阅读(179)  评论(1)  编辑  收藏  举报
/*
*/