签到题 [换根dp]

签到题


\color{red}{正解部分}

这道题 每个子树看成一个子问题, 求出每个子树的答案, 然后往上合并得到总答案 .

设当前节点有 22 个子树, 权值和节点数量 分别是 sum1,size1,sum2,size2sum_1, size_1, sum_2, size_2,子树内的答案为 ans1,ans2ans_1, ans_2

则先往 11 儿子走对答案的贡献为:
ans1+sum1+ans2+(size1+1)×sum2ans_1+ sum_1+ ans_2 + (size_1+1)\times sum_2
22 儿子走对答案的贡献为:
ans2+sum2+ans1+(size2+1)×sum1ans_2+ sum_2+ ans_1 + (size_2+1)\times sum_1,

当 走11儿子 比 走22儿子 更优时,

ans1+sum1+ans2+(size1+1)×sum2<ans2+sum2+ans1+(size2+1)×sum1ans_1+ sum_1+ ans_2 + (size_1+1)\times sum_2 < ans_2+ sum_2+ ans_1 + (size_2+1)\times sum_1

化简得 size1×sum2<size2×sum1size_1 \times sum_2 < size_2\times sum_1 .

所以以 sizex×sumysize_x \times sum_y 从小到大排序后, 从小到大按顺序 dfsdfs 即可实现答案最优 .


现在已经解决了当根固定时的答案, 考虑如何计算 所有节点作为根的 最优值,

可以想到 先求出以 11 为根 的答案, 然后进行 换根,


现在已经计算出了 ansxans_x, 且要将 根的位置xyx \rightarrow y, 要求 yy 为根的答案,
首先观察 树的信息 哪里发生了变化,

  1. yy为根 的子树 从 xx 的子树中移除掉了, sizex=sizey,sumx=sumysize_x -=size_y,sum_x-=sum_y
  2. xx为根 的子树 成为了 yy 的新子树, sizey+=sizex,sumy+=sumxsize_y += size_x, sum_y += sum_x .

ansxans_x 的影响为 ansx=ansy+sizey×sumy+sizey×sumyans_x -= ans_y + size_{y前子树}\times sum_y + size_y \times sum_{y后子树},
其中 ansyans_y 在往下递归的时候使用子树信息计算即可 .


\color{red}{实现部分}

#include<bits/stdc++.h>
#define reg register
#define pb push_back
typedef long long ll;

int read(){
        char c;
        int s = 0, flag = 1;
        while((c=getchar()) && !isdigit(c))
                if(c == '-'){ flag = -1, c = getchar(); break ; }
        while(isdigit(c)) s = s*10 + c-'0', c = getchar();
        return s * flag;
}

const int maxn = 200005;

int N;
int num0;
int A[maxn];
int size[maxn];
int head[maxn];

ll tot;
ll Ans;
ll sum[maxn];
ll ans[maxn];

struct Edge{ int nxt, to; } edge[maxn << 1];

void Add(int from, int to){
        edge[++ num0] = (Edge){ head[from], to };
        head[from] = num0;
}

bool cmp(int a, int b){ return size[a]*sum[b] < size[b]*sum[a]; }

void DFS_1(int k, int fa){
        std::vector <int> B;
        sum[k] = A[k], size[k] = 1;
        for(reg int i = head[k]; i; i = edge[i].nxt){ 
                int to = edge[i].to; 
                if(to == fa) continue ; B.pb(to); 
                DFS_1(to, k);
                sum[k] += sum[to], size[k] += size[to];
        }
        std::sort(B.begin(), B.end(), cmp); 
        ans[k] = A[k]; ll last = 1;
        for(reg int i = 0; i < B.size(); i ++){
                int to = B[i]; 
                ans[k] += ans[to] + last * sum[to], last += size[to];
        }
}

void DFS_2(int k, int fa){
        std::vector <int> B;
        for(reg int i = head[k]; i; i = edge[i].nxt) B.pb(edge[i].to);
        std::sort(B.begin(), B.end(), cmp);
        ans[k] = A[k]; ll last = 1;
        for(reg int i = 0; i < B.size(); i ++){
                int to = B[i];
                ans[k] += ans[to] + last * sum[to];
                last += size[to];
        }
        Ans = std::min(Ans, ans[k]);
        last = 1; ll suf = tot - A[k];
        for(reg int i = 0; i < B.size(); i ++){
                int to = B[i]; suf -= sum[to];
                if(to != fa){
                        ll t1 = ans[k], t2 = ans[to];
                        ans[k] -= ans[to] + last*sum[to] + size[to]*suf;
                        size[k] -= size[to], sum[k] -= sum[to];
                        size[to] += size[k], sum[to] += sum[k];
                        DFS_2(to, k);
                        size[to] -= size[k], sum[to] -= sum[k];
                        size[k] += size[to], sum[k] += sum[to];
                        ans[k] = t1, ans[to] = t2;
                }
                last += size[to];
        }
}

int main(){
        N = read();
        for(reg int i = 1; i < N; i ++){ int u = read(), v = read(); Add(u, v), Add(v, u); }
        for(reg int i = 1; i <= N; i ++) A[i] = read(), tot += A[i];
        DFS_1(1, 1); 
        Ans = ans[1]; DFS_2(1, 1);
        printf("%lld\n", Ans);
        return 0;
}
posted @ 2019-09-21 21:37  XXX_Zbr  阅读(167)  评论(0编辑  收藏  举报