学习笔记:换根 DP

换根 DP

引入

换根 DP 其实是一种树形 DP 的拓展与延伸。对于这样一类题可以使用换根 DP 来做:

  1. 给定一棵树,需要求出某一个符合条件的最优的点。
  2. 对于每一个点,都存在一个答案。
  3. 点与点之间可以通过某种方式进行转移。

对于这样的问题,通常来讲可以很快地想到一个暴力做法:对于每一个点都执行一次深度优先遍历,算出每一个点的答案,最后从中找出最优的。

然而,这样的做法时间复杂度达到了 $O(n^2)$,在一些数据范围较大的题目上妥妥 TLE。

这个时候便可以考虑用换根 DP 来做。首先在树上钦定任意一个节点为根节点,并以这个根节点为基准执行一次深度优先遍历。再依据转移方程执行一次深度优先遍历求出每一个点的答案。一共执行了两次深度优先遍历,总的时间复杂度由 $O(n^2)$ 降到了 $O(n)$。

例题

[POI2008] STA-Station

给定一个 $n$ 个点的树,请求出一个结点,使得以这个结点为根时,所有结点的深度之和最大。

一个结点的深度之定义为该节点到根的简单路径上边的数量。

思路

看到“求出一个节点”、“使得以这个节点为根时”之类的字样就应该理所当然地想到用换根 DP 来做。首先以节点 $1$ 为根执行一次深度优先遍历,后续主要问题在于如何考虑相邻节点之间的状态转移方程。

对于下图这样一种情况,不难发现,当从节点 $1$ 转移到节点 $2$ 时:

  1. 以节点 $2$ 为根且不包含节点 $1$ 的子树上的所有点的贡献都减少 $1$。
  2. 以节点 $1$ 为根且不包含节点 $2$ 的子树上的所有点的贡献都增加 $1$。

当从节点 $x$ 转移到节点 $y$ 时,则有:$$ f_y=f_x-s_x+(n-s_x) $$ 其中 $s$ 表示以节点 $1$ 为根时该点子树的大小,$n$ 为节点个数,将原式变形可得:$$ f_y=f_x-2\times s_x+n $$ 依据这个状态转移方程再执行一次深度优先遍历即可。

代码

#include <iostream>
#define int long long
#define MAXN 1000005
using namespace std;
int n, u, v;
struct edge{int to, nxt;}e[MAXN << 1];
int head[MAXN], cnt;
int siz[MAXN], dep[MAXN], f[MAXN];
int ans, ret;
int read(){
    int t = 1, x = 0;char ch = getchar();
    while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
    while(isdigit(ch)){x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar();}
    return x * t;
}
void add(int u, int v){
    cnt++;e[cnt].to = v;e[cnt].nxt = head[u];head[u] = cnt;
    cnt++;e[cnt].to = u;e[cnt].nxt = head[v];head[v] = cnt;
}
void dfs1(int now, int fat){
    siz[now] = 1;dep[now] = dep[fat] + 1;
    for(int i = head[now] ; i != 0 ; i = e[i].nxt){
        if(e[i].to != fat){
            dfs1(e[i].to, now);
            siz[now] += siz[e[i].to];
        }
    }
}
void dfs2(int now, int fat){
    for(int i = head[now] ; i != 0 ; i = e[i].nxt){
        if(e[i].to != fat){
            f[e[i].to] = f[now] + n - siz[e[i].to] - siz[e[i].to];
            dfs2(e[i].to, now);
        }
    }
}
signed main(){
    n = read();
    for(int i = 1 ; i < n ; i ++){
        u = read();v = read();add(u, v);
    }
    dfs1(1, 0);
    for(int i = 1 ; i <= n ; i ++)f[1] += dep[i];
    dfs2(1, 0);
    for(int i = 1 ; i <= n ; i ++){
        if(ans < f[i]){
            ans = f[i];ret = i;
        }
    }
    cout << ret << endl;return 0;
}
posted @ 2023-09-20 08:48  tsqtsqtsq  阅读(34)  评论(0)    收藏  举报  来源