【HDU6035】 Colorful Tree

题目的意思是:给定一个点带颜色的树,两点之间的距离定义为路径上不同颜色的个数。求所有点对间的距离和。

做法有点分治,还有传说中的虚树DP,树上差分。

点分治法:

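  以分治中心为根(先不考虑根结点本身),考虑每个点对颜色的贡献:若点 y 的颜色在 y 到根的路径上是第一次出现,则 y 的子树大小就是 y 对这种颜色的贡献,因为该子树中所有点到根的路径上都含有这种颜色。把已处理子树中所有这样的贡献累加为 sumcol。对于根的另一棵子树中的一个点 x,用 sumcol 减去 x 到根路径上各颜色已有的贡献(这些颜色会被重复计数),再加上"x 到根路径上的颜色种类数 × 已处理的点数(含根)",就是 x 与所有已处理点之间的距离和。根结点自身的颜色对 x 的贡献单独计算。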

  

#include <algorithm>
#include  <iterator>
#include  <iostream>
#include   <cstring>
#include   <cstdlib>
#include   <iomanip>
#include    <bitset>
#include    <cctype>
#include    <cstdio>
#include    <string>
#include    <vector>
#include     <stack>
#include     <cmath>
#include     <queue>
#include      <list>
#include       <map>
#include       <set>
#include   <cassert>

/*
 
 ⊂_ヽ
   \\ Λ_Λ  来了老弟
    \('ㅅ')
     > ⌒ヽ
    /   へ\
    /  / \\
    ﾚ ノ   ヽ_つ
   / /
   / /|
  ( (ヽ
  | |、\
  | 丿 \ ⌒)
  | |  ) /
 'ノ )  Lﾉ
 
 */

using namespace std;
// Segment-tree argument shorthands: `f lson` expands to `f (l , mid , rt << 1)`.
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
// Statement-style debug print (the expansion already ends with ';').
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue

typedef long long ll;
typedef unsigned long long ull;
//typedef __int128 bll;
typedef pair<ll ,ll > pll;
typedef pair<int ,int > pii;
typedef pair<int,pii> p3;

// priority_queue<int> q;                            // max-heap
// priority_queue<int,vector<int>,greater<int> > q;  // min-heap
#define fi first
#define se second
//#define endl '\n'

#define boost ios::sync_with_stdio(false);cin.tie(0)
#define rep(a, b, c) for(int a = (b); a <= (c); ++ a)
// Expression macros: no trailing ';', so they also work inside larger
// expressions such as `x = max3(a, b, c) + 1;`.
#define max3(a,b,c) max(max(a,b), c)
#define min3(a,b,c) min(min(a,b), c)

const ll oo = 1ll<<17;
const ll mos = 0x7FFFFFFF;   // 2147483647 (INT_MAX)
// NOTE: the literal 0x80000000 does not fit in int, so it has type unsigned int
// and this constant equals +2147483648 (INT_MAX + 1), not -2147483648.
const ll nmos = 0x80000000;
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3f;  // about 4.56e18
const int mod = 1e9+7;
const double esp = 1e-8;
const double PI=acos(-1.0);
const double PHI=0.61803399;    // golden-ratio conjugate (1 / phi)
const double tPHI=0.38196601;   // 1 - PHI

// Reads the next decimal integer from stdin into x and returns it.
// Non-digit characters before the number are skipped; a '-' anywhere in that
// skipped prefix makes the result negative. If EOF is reached before any digit,
// x becomes 0 and the function returns instead of looping forever.
template<typename T>
inline T read(T&x){
    x=0;int f=0;
    // Must be int, not char: getchar() reports end of input as EOF, which a char
    // cannot reliably hold, and the skip loop below must be able to see it.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')) f|=(ch=='-'),ch=getchar();
    while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return x=f?-x:x;
}

// In-place "chmax"/"chmin": overwrite x with y when y is larger / smaller.
inline void cmax(int &x, int y) { if (y > x) x = y; }
inline void cmax(ll &x, ll y)   { if (y > x) x = y; }
inline void cmin(int &x, int y) { if (y < x) x = y; }
inline void cmin(ll &x, ll y)   { if (y < x) x = y; }

/*-----------------------showtime----------------------*/
const int maxn = 2e5+9;
// Vertex colours (vertices are numbered from 1).
int col[maxn];
// Adjacency lists of the tree.
vector<int>mp[maxn];
// Final answer, and (during one centroid step) the sum of pp[] over all colours.
ll ans = 0;
ll sumcol = 0;
// findRoot bookkeeping: subtree size, size of the largest remaining part,
// the best centroid found so far, and the size of the current component.
int sz[maxn];
int wt[maxn];
int root;
int curn;
// vis[v] != 0 once v has been taken as a centroid.
int vis[maxn];
// DFS over the component containing u (vertices with vis set are treated as removed).
// Fills sz[] with subtree sizes and wt[] with the largest part left if the vertex
// were deleted, measured against a component of curn vertices. `root` ends up as
// the last visited vertex whose wt is <= the wt of the previous `root`.
void findRoot(int u, int fa) {
    sz[u] = 1;
    wt[u] = 0;
    for (int v : mp[u]) {
        if (v != fa && !vis[v]) {
            findRoot(v, u);
            sz[u] += sz[v];
            if (sz[v] > wt[u]) wt[u] = sz[v];
        }
    }
    // The part "above" u (everything outside its subtree) also counts.
    if (curn - sz[u] > wt[u]) wt[u] = curn - sz[u];
    if (wt[u] <= wt[root]) root = u;
}
// (Previously: map<int, int> pp;)
// pp[c]: within one centroid step, how many already-processed vertices have
// colour c on their path to the centroid (centroid excluded).
ll pp[maxn];
// youmeiyou[c]: how many times colour c occurs on the current DFS path in gao().
int youmeiyou[maxn];
// ss: the centroid currently being processed.
int ss;
void gao(int u, int fa, vector<pii>& vv, int cnt, ll sumfa, ll sum) {
    ll res = 0;
    if(youmeiyou[col[u]] == 0)
        vv.pb(pii(col[u], sz[u])), cnt++, res += pp[col[u]];
    
    youmeiyou[col[u]]++;
    ans += sumcol - sumfa - res + 1ll * cnt * sum;
    //sum-color[根的颜色]+size[root]
    if(youmeiyou[col[ss]] == 0) ans += sum - pp[col[ss]];
    for(int i=0; i<mp[u].size(); i++) {
        int v = mp[u][i];
        if(fa == v || vis[v]) continue;
        gao(v, u, vv, cnt, sumfa + res, sum);
    }
    youmeiyou[col[u]] --;
}

void solve(int u) {
    vis[u] = 1;
    findRoot(u, -1);
    ll sum = 1;
    sumcol = 0;
    queue<int>needclear;
    needclear.push(col[u]);
    for(int i=0; i<mp[u].size(); i++) {
        int v = mp[u][i];
        if(vis[v]) continue;
        vector<pii>vv;
        ss = u;
        gao(v, -1, vv, 0, 0, sum);
        
        for(int j=0; j<vv.size(); j++){
            int c = vv[j].fi;
            if(pp[c])pp[c] += vv[j].se;
            else {
                pp[c] = vv[j].se;
                needclear.push(c);
            }
            sumcol += vv[j].se;
        }
        sum += sz[v];
    }
    
    while(!needclear.empty()) {
        pp[needclear.front()] = 0;
        needclear.pop();
    }
    for(int i=0; i<mp[u].size(); i++) {
        int v = mp[u][i];
        if(!vis[v]) {
            root = 0;   wt[0] = inf; curn = sz[v];
            findRoot(v, -1);
            solve(root);
        }
    }
}
// Reads test cases until EOF; prints the sum of distances over all vertex pairs.
int main(){
    int n;
    int caseNo = 0;
    while (scanf("%d", &n) != EOF) {
        memset(vis, 0, sizeof(vis));
        for (int i = 1; i <= n; ++i) scanf("%d", &col[i]);
        for (int i = 1; i <= n; ++i) mp[i].clear();
        for (int e = 1; e < n; ++e) {
            int a, b;
            scanf("%d%d", &a, &b);
            mp[a].pb(b);
            mp[b].pb(a);
        }

        ans = 0;
        root = 0;
        wt[0] = inf;          // sentinel so the first vertex always becomes root
        curn = n;
        findRoot(1, -1);
        solve(root);
        printf("Case #%d: %lld\n", ++caseNo, ans);
    }
    return 0;
}
/*
 6
 1 2 3 1 2 3
 1 2
 1 3
 3 4
 3 5
 4 6
 */
View Code

 虚树 + 树上差分法:

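  对于一种颜色 c,删去所有颜色为 c 的点后,树被分割成许多连通块;同一个连通块内两点之间的路径不经过颜色 c。因此对于颜色不是 c 的点 x,颜色 c 对 x 的贡献(即与 x 之间路径上含有颜色 c 的点数)就是 n - size,其中 size 是包含 x 的连通块的大小;而对于颜色为 c 的点,贡献为 n。对所有点、所有颜色求和后,每个无序点对被计算了两次,每个点与自身被计算了一次,所以答案为 (总和 - n) / 2。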

  由于有多种颜色,我们可以对每种颜色分别构建对应的虚树,选择根结点、这种颜色的点以及这些点的(颜色不同的)直接儿子作为关键点。借鉴树上差分的思想,先把 n - size 记在每个连通块最上面的点上,并在连通块下方颜色为 c 的点处减去;最后从根向下求一次前缀和,即可得到每个点的贡献。

  

#include <bits/stdc++.h>

using namespace std;
#define pb push_back
#define fi first
#define se second
// Statement-style debug helpers (their expansions already end with ';').
#define debug(x) cerr<<#x << " := " << x << endl;
#define bug cerr<<"-----------------------"<<endl;
// Arguments are parenthesized so expressions like FOR(i, 0, f ? 2 : 0)
// keep their intended precedence.
#define FOR(a, b, c) for(int a = (b); a <= (c); ++ a)

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<pii, int>PII;

// Generic fallback: read one value with iostreams.
template<class T> void _R(T &x) { cin >> x; }
// scanf-based overloads for common scalar types.
void _R(int &x)    { scanf("%d", &x); }
void _R(ll &x)     { scanf("%lld", &x); }
void _R(double &x) { scanf("%lf", &x); }
void _R(char &x)   { scanf(" %c", &x); }  // leading space skips whitespace
void _R(char *x)   { scanf("%s", x); }    // NOTE(review): no width limit; buffer must be big enough
// R(a, b, ...) reads its arguments left to right; R() ends the recursion.
void R() {}
template<class First, class... Rest> void R(First &first, Rest &... rest) {
    _R(first);
    R(rest...);
}


// Reads the next decimal integer from stdin into x and returns it.
// Non-digit characters before the number are skipped; a '-' anywhere in that
// skipped prefix makes the result negative. If EOF is reached before any digit,
// x becomes 0 and the function returns instead of looping forever.
template<typename T>
inline T read(T&x){
    x=0;int f=0;
    // Must be int, not char: getchar() reports end of input as EOF, which a char
    // cannot reliably hold, and the skip loop below must be able to see it.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')) f|=(ch=='-'),ch=getchar();
    while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return x=f?-x:x;
}

// 64-bit "infinity": 0x3f in every byte, so inf + inf still fits in a long long.
const ll inf = 0x3f3f3f3f3f3f3f3fLL;
// Modulus 10^9 + 7.
const int mod = 1000000007;

/**********showtime************/
            const int maxn = 2e5+9;
            // col[v]: colour of vertex v; vis[c]: colour c occurs in the current case.
            int col[maxn];
            int vis[maxn];
            // mp: tree adjacency lists; xu_mp: child lists of the virtual tree being built.
            vector<int> mp[maxn];
            vector<int> xu_mp[maxn];
            // node[c]: vertices of colour c; xu: key vertices of the current colour.
            vector<int> node[maxn];
            vector<int> xu;
            // Subtree size, DFS entry time, depth (the root gets depth 1), DFS clock.
            int sz[maxn];
            int dfn[maxn];
            int dp[maxn];
            int tim;
            // fa[v][k]: the 2^k-th ancestor of v (the root is its own parent).
            int fa[maxn][20];
            // fen: difference values summed by pushdown(); ans: accumulated total.
            ll fen[maxn];
            ll ans;

            // Roots the subtree of u (whose parent is o): assigns pre-order time dfn,
            // depth dp[u] = dp[o] + 1, subtree size sz, and the binary-lifting table fa.
            void dfs(int u, int o) {
                dfn[u] = ++tim;
                sz[u] = 1;
                dp[u] = dp[o] + 1;
                fa[u][0] = o;
                for (int k = 1; k < 20; ++k) {
                    const int mid = fa[u][k - 1];
                    fa[u][k] = fa[mid][k - 1];
                }
                for (int v : mp[u]) {
                    if (v != o) {
                        dfs(v, u);
                        sz[u] += sz[v];
                    }
                }
            }

            int lca(int u, int v) {
                if(dp[u] < dp[v]) swap(u, v);

                for(int i=19; i>=0; i--) {
                    if(dp[fa[u][i]] >= dp[v])
                        u = fa[u][i];
                }
                if(u == v) return u;

                for(int i=19; i>=0; i--) {
                    if(fa[u][i] != fa[v][i])
                        u = fa[u][i], v = fa[v][i];
                }
                return fa[u][0];
            }

            // Orders vertices by DFS entry time (pre-order).
            bool cmp(int x, int y) {
                return dfn[y] > dfn[x];
            }
            // used[v]: xu_mp[v] was written during the current build and must be cleared.
            int used[maxn];
            // nsz[v] = n - (sz[v] - cdp[v]); for the top vertex of a block this is
            // n minus the block size.
            int nsz[maxn];
            // curcol: colour being processed; n: number of vertices in this case.
            int curcol, n;
            // cdp[v]: vertices under v cut off by current-colour vertices (see gaoNewSz).
            int cdp[maxn];
            // Computes, on the virtual tree for the current colour (curcol), the size of
            // each connected block (maximal part with no curcol vertex):
            //   cdp[u] = number of vertices in u's subtree that lie in or below a
            //            curcol-coloured virtual descendant, i.e. are cut off from u;
            //   nsz[u] = n - (sz[u] - cdp[u]); when u is the top vertex of a block
            //            (the root, or a non-curcol child of a curcol vertex),
            //            sz[u] - cdp[u] is exactly that block's size.
            // o is u's parent in the virtual tree.
            void gaoNewSz(int u, int o) {
                cdp[u] = 0;
                for(int v : xu_mp[u]) {
                    if(v == o) continue;
                    gaoNewSz(v, u);
                    // A curcol vertex cuts off its whole subtree; otherwise take
                    // whatever was cut off further down.
                    if(col[v] == curcol)
                        cdp[u] += sz[v];
                    else cdp[u] += cdp[v];
                }
                nsz[u] = n - (sz[u] - cdp[u]);
            }
            // Writes this colour's contribution into the difference array fen[]:
            //   * at the top vertex of each block (vertex 1, or a non-curcol vertex whose
            //     virtual parent has colour curcol) add nsz[top];
            //   * at each curcol vertex subtract val, the value of the block above it.
            // A root-to-vertex prefix sum of fen then gives the contribution for every
            // non-curcol vertex, and 0 for curcol vertices.
            // fa is u's virtual-tree parent; val is the enclosing block's value (0 if none).
            void gaoSub(int u, int fa, int val) {
                const bool onColour = (col[u] == curcol);
                int blockVal = val;
                if (onColour) {
                    fen[u] -= val;
                } else if (u == 1 || col[fa] == curcol) {
                    blockVal = nsz[u];
                    fen[u] += blockVal;
                }
                const int childVal = onColour ? 0 : blockVal;
                for (int v : xu_mp[u]) {
                    if (v != fa) gaoSub(v, u, childVal);
                }
            }

            // Builds the virtual tree for the current colour (curcol) from the key
            // vertices in xu, runs gaoNewSz/gaoSub on it from vertex 1, and finally
            // clears every adjacency list it wrote.
            // Assumes xu contains vertex 1 (the dfs root). After sorting by dfn it comes
            // first and stays at the bottom of the stack, so it is an ancestor of any
            // vertex pushed while the stack holds a single element.
            // Edges are stored parent -> child in xu_mp; used[]/que record which lists
            // must be cleared afterwards.
            void build(vector <int> & xu) {
                sort(xu.begin(), xu.end(), cmp);   // increasing dfn (pre-order)
                stack<int>st;                      // current root-to-leaf chain of the virtual tree
                queue<int>que;                     // vertices whose xu_mp list was written

                for(int i=0; i<xu.size(); i++) {
                    int u = xu[i];
                    if(st.size() <= 1) st.push(u);
                    else {
                        int x = st.top(); st.pop();
                        int o = lca(x, u);
                        // u lies below the end of the chain: just extend the chain.
                        if(o == x) {
                            st.push(x);
                            st.push(u);
                            continue;
                        }
                        // Unwind the chain down to o: each popped vertex below o becomes
                        // the parent of the previously popped one, and the last of them is
                        // attached to o (which is pushed if it is not already on the stack).
                        while(!st.empty()) {
                            int y = st.top(); st.pop();

                            if(dfn[y] > dfn[o]) {
                                // y is still a descendant of o: it is x's virtual parent.
                                xu_mp[y].pb(x);
                                if(used[y] == 0) used[y] = 1, que.push(y);
                                x = y;
                            }
                            else if(dfn[y] == dfn[o]) {
                                // o is already on the stack: attach x to it and keep it.
                                xu_mp[y].pb(x);
                                st.push(y);
                                if(used[y] == 0) used[y] = 1, que.push(y);
                                break;
                            }
                            else {
                                // y is a proper ancestor of o: o becomes x's parent and is
                                // pushed on top of y.
                                xu_mp[o].pb(x);
                                st.push(y);
                                st.push(o);
                                if(used[o] == 0) used[o] = 1, que.push(o);
                                break;
                            }
                        }
                        st.push(u);
                    }
                }
                // Link the chain that remains on the stack, bottom to top.
                while(st.size() > 1) {
                    int u = st.top(); st.pop();
                    int v = st.top();
                    xu_mp[v].pb(u);
                 //xu_mp[u].pb(v);
                 //   if(used[u] == 0) used[u] = 1, que.push(u);
                    if(used[v] == 0) used[v] = 1, que.push(v);
                }
                while(!st.empty())st.pop();

                // Compute block sizes, then write the difference values for this colour.
                gaoNewSz(1, 1);
                gaoSub(1, 1, 0);

                // Reset only the adjacency lists this build touched.
                while(!que.empty()) {
                    int u = que.front();
                    xu_mp[u].clear();
                    used[u] = 0;
                    que.pop();
                }
            }

            //树上差分,最后的更新
            void pushdown(int u, int fa, ll val) {
                ans  += fen[u] + val + n;
                val += fen[u];
                for(int v : mp[u]) {
                    if(v == fa) continue;
                    pushdown(v, u, val);
                }
            }

int main(){
            int cas = 0;
            while(~scanf("%d", &n)){
                ans = 0;tim = 0;
                // Reset per-vertex state for this test case.
                for(int i=1; i<=n; i++){
                    mp[i].clear();
                    fen[i] = 0;
                    vis[i] = 0;
                    dp[i] = 0;          // dp[1] must be 0 so dfs(1, 1) gives the root depth 1
                    node[i].clear();
                }
                // Read colours, remembering which colours occur and where.
                for(int i=1; i<=n; i++) {
                    read(col[i]);
                    vis[col[i]] = 1;
                    node[col[i]].pb(i);
                }
                for(int i=1; i<n; i++) {
                    int u,v;
                    read(u); read(v);
                    mp[u].pb(v);
                    mp[v].pb(u);
                }

                dfs(1, 1);

                for(int i=1; i<maxn; i++) {
                    if(!vis[i]) continue;
                    // Key vertices for colour i: vertex 1 (if it has another colour),
                    // every colour-i vertex, and each of their children of another colour.
                    xu.clear();
                    if(col[1] != i) xu.pb(1);
                    for(int v : node[i]) {
                        xu.pb(v);
                        for(int k : mp[v]) {
                            if(col[k] != i && dp[k] > dp[v])
                                xu.pb(k);
                        }
                    }
                    curcol = i;
                    build(xu);
                    // Clear this colour's bookkeeping now. The reset loop above only
                    // covers indices 1..n of the *current* case, so a colour > n left
                    // over from a larger earlier case would otherwise be processed again
                    // with a stale vertex list.
                    vis[i] = 0;
                    node[i].clear();
                }
                pushdown(1, 1, 0);
                // Each unordered pair was counted twice and each vertex once with
                // itself, hence (ans - n) / 2.
                printf("Case #%d: %lld\n", ++cas, (ans - n )/ 2);
            }
            return 0;
}
View Code

 

附上网上流行的虚树建立模板(使用前栈底 s[1] 放入根结点、top = 1,之后按 dfn 从小到大依次插入关键点):

// Widely used virtual-tree insertion routine, kept as a reference snippet. It relies
// on names not defined in this file: the stack s[] with index top, dfn[], LCA(), and
// add_edge(). Presumably s[1] holds the root with top == 1 before the first call, and
// vertices are inserted in increasing dfn order (TODO confirm against the original
// source). The chain left on the stack at the end still has to be linked by the caller.
void insert(int x) {
                // Only the root is on the stack: extend the chain with x.
                if(top == 1) {s[++top] = x; return ;}
                int lca = LCA(x, s[top]);
                // x lies below the end of the chain: extend the chain.
                if(lca == s[top]){ s[++top] = x;return ;}
                // Pop chain vertices deeper than lca, linking each to the vertex under it.
                while(top > 1 && dfn[s[top - 1]] >= dfn[lca]) add_edge(s[top - 1], s[top]), top--;
                if(lca != s[top]) add_edge(lca, s[top]), s[top] = lca;// lca not on the stack yet: link it and replace the top
                s[++top] = x;
            }

 

posted @ 2019-07-03 22:08  ckxkexing  阅读(396)  评论(0编辑  收藏  举报