Loading

「学习笔记」DP学习笔记 2

树形DP

树形 DP,即在树上进行的 DP。由于树固有的递归性质,树形 DP 一般都是 递归 进行的。

题目

CF1528A
多组数据 (\(t\) 组)
给你大小为 \(n\) 的一棵树,\(i\) 号节点有权值范围 \([l_i,r_i]\),让你对每个节点赋予一个权值 \(a_i\),使得每个节点权值都在规定的范围里并且对于每条边 \((u,v)\)\(\sum{|a_u-a_v|}\) 最大,并求出这个最大值。
\(2\le n\le 10^5,\sum n\le 2\times 10^5,1\le l_i\le r_i\le10^9\)

对于每一个节点,它的值只可能会是 \(l_i\)\(r_i\) 两种情况。
证明:
对于节点 \(i\),我们假设与它相连的其他点的权值我们都已经确定了,这些点的权值会有 \(\leq a_i\)\(> a_i\) 这两种情况。我们初设 \(a_i = l_i\),随后逐渐 \(+1\),如果权值 \(\leq a_i\) 的点的个数大于权值 \(> a_i\) 的个数,则答案会增加,转化成函数就是一个单调递增函数;如果权值 \(\leq a_i\) 的点的个数小于等于权值 \(> a_i\) 的个数,则答案会减小或不变,但是随着 \(a_i\) 的增加,权值 \(> a_i\) 的点的个数也在逐渐减少,所以函数图像是一个开口向上的单峰函数或单调递减函数。
单调函数和开口向上的单峰函数,极值都在左右两个端点上,所以每个点的取值只有 \(l_i\)\(r_i\) 两种。
状态:\(dp(i, 0/1)\)\(i\) 号节点取最小值或最大值的最大价值。
转移:\(dp(i, 0/1) = \max\{dp(i, 0/1), dp(v, 0/1) + |a_i - a_v| \}\)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

inline ll read() {
	ll x = 0;
	int fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 2e5 + 5;

int T, n;
ll val[N][2], dp[N][2];
vector<int> e[N];

void dfs(int u, int fat) {
	for (int v : e[u]) {
		if (v == fat)	continue;
		dfs(v, u);
		for (int i = 0; i < 2; ++ i) {
			ll maxn = 0;
			for (int j = 0; j < 2; ++ j) {
				maxn = max(maxn, dp[v][j] + abs(val[u][i] - val[v][j]));
			}
			dp[u][i] += maxn;
		}
	}
}

void work() {
	n = read();
	for (int i = 1; i <= n; ++ i) {
		dp[i][1] = dp[i][0] = 0;
		val[i][0] = read(), val[i][1] = read();
		e[i].clear();
	}
	for (int i = 1, x, y; i < n; ++ i) {
		x = read(), y = read();
		e[x].emplace_back(y);
		e[y].emplace_back(x);
	}
	dfs(1, 0);
	printf("%lld\n", max(dp[1][0], dp[1][1]));
}

int main() {
	T = read();
	while (T --) {
		work();
	}
	return 0;
}

P3174 [HAOI2009] 毛毛虫
对于一棵树,我们可以将某条链和与该链相连的边抽出来,看上去就象成一个毛毛虫,点数越多,毛毛虫就越大。求最大的毛毛虫的大小。

状态:\(f_u\):以 \(u\) 为毛毛虫的头,最大的毛毛虫的大小,\(g_u\):以 \(u\) 为毛毛虫的头,次大的毛毛虫的大小
转移:\(f_u = \max(f_u, f_v), g_u = \max\{f_u, f_v\}\)
对于答案的判定,我们要把 \(i\) 点的 \(f_u\)\(g_u\) 加起来,加上 \(1\) (加上自己),同时还要处理与它相连的边,即他的所有孩子的数量(包括父节点)\(- 2\)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

inline ll read() {
	ll x = 0;
	int fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 3e5 + 5;

int n, m, ans;
int f[N], g[N];
vector<int> e[N];

void dfs(int u, int fat) {
	for (int v : e[u]) {
		if (v == fat)	continue;
		dfs(v, u);
		if (f[v] >= f[u]) {
			g[u] = f[u];
			f[u] = f[v];
		}
		else if (f[v] > g[u]) {
			g[u] = f[v];
		}
	}
	int cnt = e[u].size() - (fat != 0);
	ans = max(ans, f[u] + g[u] + 1 + max(0, cnt - 1 - (fat == 0)));
	f[u] += (1 + max(0, cnt - 1));
}

int main() {
	n = read(), m = read();
	for (int i = 1, a, b; i <= m; ++ i) {
		a = read(), b = read();
		e[a].emplace_back(b);
		e[b].emplace_back(a);
	}
	dfs(1, 0);
	printf("%d\n", ans);
	return 0;
}

P2899 [USACO08JAN]Cell Phone Network G
\(n\) 个点的树,你要放置数量最少的信号塔,保证所有点要么有信号塔要么与信号塔相邻。求最少数量。

状态:\(dp(i, 0 / 1 / 2)\) : \(i\) 号节点被自己覆盖/被孩子覆盖/被父亲覆盖
转移:

\[dp(u, 0) = dp(u, 0) + \sum_{v \in son_u} \min\{dp(v, 0), dp(v, 1), dp(v, 2)\}\\ dp(u, 2) = dp(u, 2) + \sum_{v \in son_u} \min\{dp(v, 1), dp(v, 0)\}\\ dp(u, 1) = dp(u, 1) + \sum_{v \in son_u} \min\{dp(v, 1), dp(v, 0)\}\\ dp(u, 1) = dp(u, 1) + dp(id, 1) - dp(id, 0) \quad (当所有的儿子 dp(v, 0) < dp(v, 1) 时,id 为 dp(v, 1) 中最小的那个 v) \]

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 1e4 + 5;

int n;
int dp[N][3];
vector<int> e[N];

void dfs(int u, int fat) {
	dp[u][0] = 1;
	int p = 0, id = 0, fg = 0;
	for (int v : e[u]) {
		if (v == fat)	continue;
		dfs(v, u);
		fg = 1;
		dp[u][0] += min(dp[v][0], min(dp[v][1], dp[v][2]));
		dp[u][2] += min(dp[v][1], dp[v][0]);
		dp[u][1] += min(dp[v][1], dp[v][0]);
		if (dp[v][0] <= dp[v][1]) {
			p = v;
		}
		if (dp[v][0] < dp[id][0]) {
			id = v;
		}
	}
	if (fg && p == 0) {
		dp[u][1] += (dp[id][0] - dp[id][1]);
	}
	if (!fg) {
		dp[u][1] = 1e9;
	}
}

int main() {
	dp[0][0] = 1e9;
	n = read<int>();
	for (int i = 1, x, y; i < n; ++ i) {
		x = read<int>(), y = read<int>();
		e[x].emplace_back(y);
		e[y].emplace_back(x);
	}
	dfs(1, 0);
	printf("%d\n", min(dp[1][1], dp[1][0]));
	return 0;
}

P2986
\(n\) 个农场形成一棵树,边有长度。第 \(i\) 个农场里有 \(c_i\) 只奶牛。
集会会在一个农场举行,所有的奶牛会沿最近的道路到达集会农场。求所有奶牛路程之和的最小值。
\(1 \leq n \leq 10^5.\)

换根 DP
状态:\(dp_x\) 表示子树的奶牛到 \(x\) 的代价, \(f_x\) 表示所有奶牛到 \(x\) 的代价, \(siz_x\) 子树奶牛数。
转移:第一遍算出 \(dp_x\),很好算。第二遍根据 \(dp_x\) 自上而下算出 \(f_x\)
\(g_v = g_x − siz_v \cdot w(x, v) + (n − siz_v) \cdot w(x, v).\)

/*
  The code was written by yifan, and yifan is neutral!!!
 */

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, ll> pil;

template<typename T>
inline T read() {
	T x = 0;
	bool fg = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') {
		fg |= (ch == '-');
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = (x << 3) + (x << 1) + (ch ^ 48);
		ch = getchar();
	}
	return fg ? ~x + 1 : x;
}

const int N = 1e5 + 5;

int n;
ll sum, ans;
int c[N], siz[N];
ll dp[N], f[N];
vector<pil> e[N];

void dfs(int u, int fat) {
	siz[u] = c[u];
	for (pil it : e[u]) {
		int v = it.first;
		ll w = it.second;
		if (v == fat)	continue;
		dfs(v, u);
		siz[u] += siz[v];
		dp[u] = dp[u] + dp[v] + w * siz[v];
	}
}

void Dp(int u, int fat) {
	if (u == 1) {
		f[u] = dp[u];
	}
	bool fg = 1;
	for (pil it : e[u]) {
		int v = it.first;
		ll w = it.second;
		if (v == fat)	continue;
		fg = 0;
		f[v] = 1ll * f[u] - siz[v] * w + (sum - siz[v]) * w;
		ans = min(ans, f[v]);
		Dp(v, u);
	}
	if (fg) {
		f[u] = 1e18;
	}
}

int main() {
	n = read<int>();
	for (int i = 1; i <= n; ++ i) {
		c[i] = read<int>();
		sum += c[i];
	}
	for (int i = 1, x, y; i < n; ++ i) {
		x = read<int>(), y = read<int>();
		ll z = read<ll>();
		e[x].emplace_back(y, z);
		e[y].emplace_back(x, z);
	}
	dfs(1, 0);
	ans = dp[1];
	Dp(1, 0);
	printf("%lld\n", ans);
	return 0;
}
posted @ 2023-06-27 08:01  yi_fan0305  阅读(23)  评论(0编辑  收藏  举报