树形DP

定义

树形DP(树形动态规划),故名思义:在树结构上进行动态规划的算法。

基础思想

在讨论树形DP前,先了解树和动态规划两个知识点,这里不做讨论。

由于树的性质,树形DP一般是递归进行的。我们一般先处理子树的问题,然后再合并到父节点上,类似于树的后序遍历。所以,通常都是使用 \(dfs\) 的方式来遍历树,递归返回后根据子节点的信息进行更新。

一般代码形式:

void dfs(int u, int fa)
{
    初始化
    dp[u] = ... 
    遍历树
    for (int v : g[u])
    {
        if (v == fa) continue;
        dfs(v, u); 先递归子树
        
        根据子节点,更新信息
    }
}

经典例题

没有上司的舞会

题意

给定一课包含 \(n\) 个节点的树,以及树上 \(n - 1\) 条边。每个节点有一个权值 \(w_i\)。要求当前节点和子节点不能同时选择的情况下,最大权值之和是多少。

思路

分析题目可以发现,当前节点的选择取决于子节点的选择:

  • 如果当前的节点不选择,则其子节点可以选也可以不选,取一个最大权值即可。
  • 如果当前的节点要选择,则其子节点一定不能选,因为两者不能同时选择

由此可以得到状态定义:

  • \(dp[i][0]\) :节点 \(i\) 不选择的状态下最大权值之和
  • \(dp[i][1]\) :节点 \(i\) 选择的状态下最大权值之和

则状态转移为( \(j\)\(i\) 的子节点):

  • \(dp[i][0] = \sum_jmax(dp[j][1], dp[j][0])\)
  • \(dp[i][1] = \sum_jdp[j][0]\)

时间复杂度:\(O(n)\)

代码

#include <iostream>
#include <cstring>

using namespace std;

const int N = 6000 * 2 + 10;

int n;
int h[N], e[N], w[N], ne[N], idx; // 链式前向星
int dp[N][2];
bool st[N];

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

void dfs(int u)
{
    dp[u][1] = w[u]; // 当前节点要选择,则初始权值为当前节点权值
    for (int i = h[u]; ~i; i = ne[i])
    {
        int v = e[i]; // 子节点
        dfs(v); // 递归子节点
        
        // 状态转移,根据子节点信息更新当前节点的权值
        dp[u][0] += max(dp[v][0], dp[v][1]);
        dp[u][1] += dp[v][0];
    }
}

int main()
{
    memset(h, -1, sizeof h); // 初始化
    
    cin >> n;
    
    for (int i = 1; i <= n; i ++) cin >> w[i]; // 每个节点的权值
    
    for (int i = 1; i < n; i ++)
    {
        int a, b;
        cin >> b >> a;
        add(a, b); // 建边,此题为有向图
        st[b] = 1; // 标记有入度的节点
    }
    
    int root = 1;
    while (st[root]) root ++; // 找根节点
    
    dfs(root); // 树形dp
    
    cout << max(dp[root][0], dp[root][1]); // 最终答案会合并到根节点,取一个最大权值
    
    return 0;
}

树的重心

题意

给定一颗树,树中包含 \(n\) 个结点(编号 \(1∼n\))和 \(n−1\) 条无向边。

请你找到树的重心,并输出将重心删除后,剩余各个连通块中点数的最大值。

重心定义:重心是指树中的一个结点,如果将这个点删除后,剩余各个连通块中点数的最大值最小,那么这个节点被称为树的重心。

思路

根据树的重心的定义,关键在某个节点作为根节点后最大子树最小。我们要从递归的角度出发,先递归当前节点的所有子节点,可以得到子节点为根的子树的大小,那么还有一个父节点向上的子树的大小,可以根据总节点数减去当前子树大小求得,取一个最大值最小的那个节点即为树的重心。

得到状态的定义:

  • \(dp[i]\):以 \(i\) 为根节点的子树的大小

状态转移:

  • \(dp[i] = \sum_jdp[j] + 1\)
  • \(ans = \min(ans, \max_j(dp[j], n - dp[i]))\)

时间复杂度:\(O(n)\)

代码

#include <iostream>
#include <cstring>

using namespace std;

const int N = 2e5 + 10;

int n, ans = 1e9;
int h[N], e[N], ne[N], idx;
int dp[N], c[2];

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

void dfs(int u, int fa)
{
    dp[u] = 1; // 初始化,当前子树初始大小为1
    int res = 0;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int v = e[i];
        if (v == fa) continue;
        
        dfs(v, u); // 递归子节点
        
        dp[u] += dp[v]; // 加上子节点为根的子树的大小
        res = max(res, dp[v]); // 取子节点为根的子树的大小的最大值
    }
    res = max(res, n - dp[u]); // 对比父节点为根的向上的子树的大小
    ans = min(ans, res); // 最大子树最小
    
    if (res <= n / 2) c[c[0] != 0] = u; // 根据重心定义,取重心的节点编号,附加获取重心的方法
}

int main()
{
    memset(h, -1, sizeof h);
    
    cin >> n;
    
    for (int i = 1; i < n; i ++)
    {
        int a, b;
        cin >> a >> b;
        add(a, b), add(b, a); // 建双向边
    }
    
    dfs(1, -1); // 默认以1为根节点
    
    cout << ans;
    
    return 0;
}

树的直径/树的最长路径

题意

洛谷上的题只考虑边权为1,这里讲第二题包含边权为负的情况。

给定一棵树,树中包含 \(n\) 个结点(编号 \(1∼n\) )和 \(n−1\) 条无向边,每条边都有一个权值。
现在请你找到树中的一条最长路径。

树上最长路径即是树的直径。

思路

假设当前节点是最长路径上的一点,那么最长路径长度一定是当前节点为根时向下延申的最长路径长度加上次长路径(与最长路径无公共边)长度,容易根据反证法证明。

在实现时,我们并不需要让每个节点都作为根节点进行计算,只需要在以 \(1\) 为根节点的树上,取每个节点作为子树根节点时最大路径和次大路径和的最大值即可。

时间复杂度:\(O(n)\)

代码

#include <iostream>
#include <cstring>

using namespace std;

const int N = 2e4 + 10;

int n;
int h[N], w[N], e[N], ne[N], idx;
int ans;

void add(int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}

int dfs(int u, int fa)
{
    int d1 = 0, d2 = 0;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;
        int d = dfs(j, u) + w[i];

        if (d > d1) d2 = d1, d1 = d;
        else if (d > d2) d2 = d;
    }

    ans = max(ans, d1 + d2);
    return d1;
}

int main()
{
    memset(h, -1, sizeof h);
    cin >> n;

    for (int i = 1; i < n; i ++)
    {
        int a, b, c;
        cin >> a >> b >> c;

        add(a, b, c), add(b, a, c);
    }

    dfs(1, -1);

    cout << ans;

    return 0;
}

posted @ 2025-02-26 13:03  Natural_TLP  阅读(47)  评论(0)    收藏  举报