树形DP入门及两道简单题( hdu1520 hdu2196 )
树形dp
树形DP是指在“树”这种数据结构上进行的DP,一般来说题目会暗示你去求一个最大值或最小值(比如最小代价,最大收益之类的)。而且一般来讲这种问题的规模比较大,没办法枚举,贪心也不能得到最优解,所以要用到动规。
而且,树实在是太适合做动规了......因为树本身就具有“子结构”的性质(子树),所以在写状态转移方程的时候比线性的dp更加直观。(但是更难写
一般来说都要用到dfs。(大概
来几个入门题帮助理解下
hdu1520——Anniversary Party
题意大概是给定一棵树,树上的每个节点都有对应的权值,要求不能同时选一个节点和他的父亲节点,求可以取到的最大权值。
显然每个节点对应的状态只有选和不选两种,所以我们可以定义两个状态:
- dp[i][0]为不选当前节点时的最优解
- dp[i][1]为选择当前节点时的最优解
同时有两个状态转移方程:
- 不选择当前节点,子节点可选可不选 : dp[i][0]+=max(dp[son][1],dp[son][0])
- 选择当前节点,子节点不能选: dp[i][1]+=dp[son][0]
本题可以使用stl的vector建立关系树。先找到一个根节点,然后向下dfs,在回溯时进行dp。下面是ac代码:
#include <bits/stdc++.h> using namespace std; const int N = 6005; int val[N], dp[N][2], fa[N], n; vector<int> tree[N]; void dfs(int u) { dp[u][0] = 0; dp[u][1] = val[u]; for (int i = 0; i < tree[u].size(); i++) { int son = tree[u][i]; dfs(son); dp[u][0] += max(dp[son][1], dp[son][0]); dp[u][1] += dp[son][0]; } } int main() { while (~scanf("%d", &n)) { for (int i = 1; i <= n; i++) { scanf("%d", &val[i]); tree[i].clear(); fa[i] = -1; } while (1) { int a, b; scanf("%d%d", &a, &b); if (a == 0 && b == 0) break; fa[a] = b; tree[b].push_back(a); } int t = 1; while (fa[t] != -1) t = fa[t]; dfs(t); printf("%d\n", max(dp[t][1], dp[t][0])); } return 0; }
hdu2196——computer
还是先说一下题意:一颗有根树,根节点的编号是1,对其中一个任意节点,求离他最远的节点距离。数据范围N<10000。
显然,如果对每个结点进行一次bfs,复杂度是O(n2),必然会tle。这时我们注意到题意要求求最大值,由此想到可以尝试一下dp。
先来分析一下状态,对于一个结点,距离它最远的距离有两种情况:
- 最远距离在以该结点为根的子树上,直接dfs整棵树就可以。
- 最远距离在除了该子树的部分到该结点的路径上,也就是该结点父节点的除了这个子树的最远距离(实际上我们要找出最远和次远距离)+根与该结点的距离。
然后来列出状态转移方程:dp[i][最远]=max(dp[i][最远],dp[i][向上最远]) , dp[i][向上最远]=max(dp[fa][向上最远],dp[fa][向下最远])+dis(i,fa).
具体操作要用到两遍dfs,第一遍用来获取向下的最远,第二遍时dp。下面直接上代码:
#include <bits/stdc++.h> using namespace std; const int N = 10010; struct node { int id, cost; }; vector<node> tree[N]; int dp[N][3]; int n; void dfs1(int fa) { int one = 0, two = 0; for (int i = 0; i < tree[fa].size(); i++) { node son = tree[fa][i]; dfs1(son.id); int cost = dp[son.id][0] + son.cost; if (cost >= one) { two = one; one = cost; } if (cost < one && cost > two) { two = cost; } } dp[fa][0] = one; dp[fa][1] = two; } void dfs2(int fa) { for (int i = 0; i < tree[fa].size(); i++) { node son = tree[fa][i]; if (dp[son.id][0] + son.cost == dp[fa][0]) //如果son在最长距离子树上 dp[son.id][2] = max(dp[fa][1], dp[fa][2]) + son.cost; else dp[son.id][2] = max(dp[fa][0], dp[fa][2]) + son.cost; dfs2(son.id); } } int main() { while (~scanf("%d", &n)) { for (int i = 1; i <= n; i++) tree[i].clear(); memset(dp, 0, sizeof(dp)); for (int i = 2; i <= n; i++) { int x, y; scanf("%d%d", &x, &y); node tmp; tmp.cost = y; tmp.id = i; tree[x].push_back(tmp); } dfs1(1); dp[1][2] = 0; dfs2(1); for (int i = 1; i <= n; i++) printf("%d\n", max(dp[i][0], dp[i][2])); } //system("pause"); return 0; }

浙公网安备 33010602011771号