树形dp

下边是从天涯空间里找出来的练习(转自notonlysuccess)


http://acm.pku.edu.cn/JudgeOnline/problem?id=3345
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3201  
http://acm.pku.edu.cn/JudgeOnline/problem?id=3107   √
http://acm.pku.edu.cn/JudgeOnline/problem?id=1655  √
http://acm.pku.edu.cn/JudgeOnline/problem?id=2378  √
http://acm.pku.edu.cn/JudgeOnline/problem?id=3140  √
http://acm.hdu.edu.cn/showproblem.php?pid=2242
http://acm.timus.ru/problem.aspx?space=1&num=1018 
http://acm.pku.edu.cn/JudgeOnline/problem?id=1947  
http://acm.pku.edu.cn/JudgeOnline/problem?id=2057
http://acm.pku.edu.cn/JudgeOnline/problem?id=2486  
http://acm.pku.edu.cn/JudgeOnline/problem?id=1848
http://acm.pku.edu.cn/JudgeOnline/problem?id=2152

zoj 3201 Tree of Tree

话说vector在用之前一点要clear啊亲!!sgementation fault了N次!!T_T

思路:就是普通的树形dp,个人不太习惯把树转成二叉树的做法。转移方程是:

f[r][j + k ] = max(f[r][j + k], f[r][j] + f[som[r]][k]);

实现也肯容易,f 初始化成0即可。ps:注意是无向图

MY CODE:

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <vector>
5
6 using namespace std;
7
8 const int N = 110;
9
10 vector<int> g[N];
11
12 int f[N][N], v[N], m;
13
14 void dfs(int r, int pr) {
15 f[r][1] = v[r];
16 int len = g[r].size();
17 if(!len) return ;
18 int c, i, j, k;
19 for(i = 0; i < len; i++) {
20 c = g[r][i];
21 if(c == pr) continue;
22 dfs(c, r);
23 for(j = m; j >= 1; j--)
24 for(k = 1; k + j <= m; k++)
25 f[r][j + k] = max(f[r][j + k], f[r][j] + f[c][k]);
26 }
27 }
28
29 int main() {
30 //freopen("data.in", "r", stdin);
31
32 int n, i, x, y, ans;
33 while(cin >> n >> m) {
34 memset(f, 0, sizeof(f));
35 memset(v, 0, sizeof(v));
36 for(i = 0; i < n; i++) g[i].clear();
37 for(i = 0; i < n; i++) cin >> v[i];
38 for(i = 1; i < n; i++) {
39 cin >> x >> y;
40 g[x].push_back(y);
41 g[y].push_back(x);
42 }
43 dfs(0, -1);
44 ans = 0;
45 for(i = 0; i < n; i++) {
46 ans = max(ans, f[i][m]);
47 }
48 cout << ans << endl;
49 }
50 return 0;
51 }


poj 3107 Godfather

题意:是给一颗树,求删掉一个结点后使得得到的两个子树最小,求共有多少个这样的点。

思路:dfs搜出每个子树所包含的结点数。然后取 m = max(一个结点R所有子树的结点数num[r],总体删掉以R为根的子树所省的结点数n - num[r]),所有m的最小值就是所求的结果。

MY CODE:

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <algorithm>
5
6 using namespace std;
7
8 const int N = 50010;
9 const int inf = ~0u>>2;
10
11 struct node {
12 int next;
13 int c;
14 } g[N<<1];
15
16 int ind, t, n, mins;
17 int ans[N], head[N], num[N];
18 bool vis[N];
19
20 void add(int u, int v) {
21 g[t].c = v;
22 g[t].next = head[u];
23 head[u] = t++;
24 }
25
26 void dfs(int r) {
27 vis[r] = true;
28 int c, i, m = -1;
29 num[r] = 1;
30 for(i = head[r]; i; i = g[i].next) {
31 c = g[i].c;
32 if(vis[c]) continue;
33 dfs(c);
34 num[r] += num[c];
35 m = max(m, num[c]);
36 }
37 m = max(m, n - num[r]);
38 ans[r] = m;
39 mins = min(mins, m);
40 }
41
42 int main() {
43 //freopen("data.in", "r", stdin);
44
45 int i, x, y;
46 while(~scanf("%d", &n)) {
47 memset(head, 0, sizeof(head));
48 memset(vis, false, sizeof(vis));
49 memset(g, 0, sizeof(g));
50 memset(ans, 0, sizeof(ans));
51 memset(num, 0, sizeof(num));
52 ind = 0, t = 1, mins = inf;
53
54 for(i = 1; i < n; i++) {
55 scanf("%d%d", &x, &y);
56 add(x, y);
57 add(y, x);
58 }
59 dfs(1);
60 for(i = 1; i <= n; i++) {
61 if(mins == ans[i]) printf("%d ", i);
62 }
63 cout << endl;
64 }
65 return 0;
66 }


poj 1655 Balancing Act

题意跟3107一样,不过要求输出的是最小m对应的结点 + 最小m的值;

MY CODE:

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <algorithm>
5
6 using namespace std;
7
8 const int N = 20010;
9 const int inf = ~0u>>2;
10
11 struct node {
12 int next;
13 int c;
14 } g[N<<1];
15
16 int ind, t, n, mins;
17 int ans[N], head[N], num[N];
18 bool vis[N];
19
20 void add(int u, int v) {
21 g[t].c = v;
22 g[t].next = head[u];
23 head[u] = t++;
24 }
25
26 void dfs(int r) {
27 vis[r] = true;
28 int c, i, m = -1;
29 num[r] = 1;
30 for(i = head[r]; i; i = g[i].next) {
31 c = g[i].c;
32 if(vis[c]) continue;
33 dfs(c);
34 num[r] += num[c];
35 m = max(m, num[c]);
36 }
37 m = max(m, n - num[r]);
38 ans[r] = m;
39 mins = min(mins, m);
40 }
41
42 int main() {
43 //freopen("data.in", "r", stdin);
44
45 int i, x, y;
46 int z;
47 cin >> z;
48 while(z--) {
49 while(~scanf("%d", &n)) {
50 memset(head, 0, sizeof(head));
51 memset(vis, false, sizeof(vis));
52 memset(g, 0, sizeof(g));
53 memset(ans, 0, sizeof(ans));
54 memset(num, 0, sizeof(num));
55 ind = 0, t = 1, mins = inf;
56
57 for(i = 1; i < n; i++) {
58 scanf("%d%d", &x, &y);
59 add(x, y);
60 add(y, x);
61 }
62 dfs(1);
63 for(i = 1; i <= n; i++) {
64 if(mins == ans[i]) {printf("%d ", i); break;}
65 }
66 printf("%d\n", mins);
67 }
68 }
69 return 0;
70 }


poj 2378 Tree Cutting

题意同上边两个题,都是删掉树上的某个结点让剩下所有子树的结点数最少,这道题加了一个条件,所有子树的结点数必须小于N/2 否则输出NONE

MY CODE:

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <vector>
5 #include <algorithm>
6
7 using namespace std;
8
9 const int N = 10010;
10
11 struct node {
12 int c;
13 int next;
14 } g[N<<1];
15
16 int ans[N], num[N], head[N];
17 int maxs, t, ind, n;
18 bool vis[N];
19
20 void add(int u, int v) {
21 g[t].c = v;
22 g[t].next = head[u];
23 head[u] = t++;
24 }
25
26 void dfs(int r) {
27 vis[r] = true;
28 num[r] = 1;
29 int c, i, m = 0;
30 for(i = head[r]; i; i = g[i].next) {
31 c = g[i].c;
32 if(vis[c]) continue;
33 dfs(c);
34 num[r] += num[c];
35 m = max(m, num[c]);
36 }
37 m = max(m, n - num[r]);
38 if(m < maxs) {
39 ind = 0;
40 ans[ind++] = r;
41 maxs = m;
42 } else if(m == maxs) {
43 ans[ind++] = r;
44 }
45 }
46
47 int main() {
48 //freopen("data.in", "r", stdin);
49
50 int x, y, i;
51 while(cin >> n) {
52 memset(g, 0, sizeof(g));
53 memset(ans, 0, sizeof(ans));
54 memset(num, 0, sizeof(num));
55 memset(head, 0, sizeof(head));
56 maxs = n; ind = 0; t = 1;
57
58 for(i = 1; i < n; i++) {
59 cin >> x >> y;
60 add(x, y);
61 add(y, x);
62 }
63 dfs(1);
64 if(ind == 0) {puts("NONE"); continue;}
65 sort(ans, ans + ind);
66 for(i = 0; i < ind; i++) {
67 printf("%d\n", ans[i]);
68 }
69 }
70 return 0;
71 }


poj 3140 Contestants Division

题意:是给一颗树,所有的结点带有相应的权值,要求把这颗树分成两个。求这两个子树总权值的差的最小值。

思路:大体就是先求出原树的总权值sum,然后dfs每一个结点为根的子树的总权值 ans = min(abs(sum - num[i] - num[i]));

MY CODE:313+MS

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <vector>
5 #include <algorithm>
6
7 using namespace std;
8
9 const int N = 100010;
10
11 struct node {
12 int c;
13 int next;
14 } g[N<<1];
15
16 int head[N], val[N];
17 int maxs, t, ind;
18 long long sum, ans, num[N];
19 bool vis[N];
20
21 void add(int u, int v) {
22 g[t].c = v;
23 g[t].next = head[u];
24 head[u] = t++;
25 }
26
27 void dfs(int r) {
28 vis[r] = true;
29 num[r] = val[r];
30 int c, i;
31 long long tmp;
32
33 for(i = head[r]; i; i = g[i].next) {
34 c = g[i].c;
35 if(vis[c]) continue;
36 dfs(c);
37 num[r] += num[c];
38 }
39 tmp = sum - 2*num[r];
40 if(tmp < 0) tmp *= -1;
41 ans = ans < tmp ? ans : tmp;
42 }
43
44 int main() {
45 //freopen("data.in", "r", stdin);
46
47 int x, y, i, m, n, cas = 0;
48 while(scanf("%d%d", &n, &m), n||m) {
49 memset(g, 0, sizeof(g));
50 for(i = 0; i <= n; i++) {
51 head[i] = num[i] = 0;
52 vis[i] = false;
53 }
54 t = 1; sum = 0;
55 for(i = 1; i <= n; i++) {
56 scanf("%d", val + i);
57 sum += val[i];
58 }
59
60 for(i = 0; i < m; i++) {
61 scanf("%d%d", &x, &y);
62 add(x, y);
63 add(y, x);
64 }
65 ans = sum;
66 dfs(1);
67 printf("Case %d: %I64d\n", ++cas, ans);
68 }
69 return 0;
70 }


HDU_2242 考研路茫茫——空调教室

思路:状态转移很简单,主要是缩点建图上。将图中存在的回路缩成一个点,因为是无向图,所以可是看成是求强连通分量。。。

ps:先学习完tarjan又做这道题,各种坎坷。。。T^T

渣代码:

 

View Code
  1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <vector>
5 #include <stack>
6
7 using namespace std;
8
9 const int N = 10006;
10 const int M = 20012;
11 const int inf = ~0u>>2;
12
13 struct node {
14 int to;
15 int next;
16 } g[M*2];
17
18 int head[N], blong[N];
19 int dfn[N], low[N];
20 int val[N], val1[N];
21 int x[M], y[M];
22 int ind, id, cnt;
23 int SUM, ans;
24
25 bool vis[N];
26 stack<int> s;
27
28 void init() {
29 memset(head, 0, sizeof(head));
30 memset(dfn, 0, sizeof(dfn));
31 memset(low, 0, sizeof(low));
32 memset(val, 0, sizeof(val));
33 memset(val1, 0, sizeof(val1));
34 memset(blong, 0, sizeof(blong));
35 memset(vis, 0, sizeof(vis));
36 id = 1; ind = cnt = SUM = 0;
37 while(!s.empty()) s.pop();
38 }
39
40 void add(int u, int v) {
41 g[id].to = v;
42 g[id].next = head[u];
43 head[u] = id++;
44 }
45
46 void tarjan(int u, int pre) {
47 int i, v, flag;
48 dfn[u] = low[u] = ++ind;
49 flag = 0;
50 vis[u] = 1;
51 s.push(u);
52 for(i = head[u]; i; i = g[i].next) {
53 v = g[i].to;
54 if(v == pre && !flag) {flag = 1; continue;}
55 if(!dfn[v]) {
56 tarjan(v, u);
57 low[u] = min(low[u], low[v]);
58 } else if(vis[v]) {
59 low[u] = min(low[u], dfn[v]);
60 }
61 }
62 if(dfn[u] == low[u]) {
63 cnt++;
64 do {
65 v = s.top();
66 s.pop();
67 blong[v] = cnt;
68 val1[cnt] += val[v];
69 } while(u != v);
70 }
71 }
72
73 int ABS(int x) {
74 return x < 0 ? -x : x;
75 }
76
77 int dfs(int u, int p) {
78 int sum, i, v;
79 sum = val1[u];
80 for(i = head[u]; i; i = g[i].next) {
81 v = g[i].to;
82 if(v == p) continue;
83 sum += dfs(v, u);
84 }
85 ans = min(ans, ABS(SUM - 2*sum));
86 return sum;
87 }
88
89 int main() {
90 //freopen("data.in", "r", stdin);
91
92 int i, n, m, u, v;
93 while(~scanf("%d%d", &n, &m)) {
94 init();
95 for(i = 0; i < n; i++) {
96 scanf("%d", val + i);
97 SUM += val[i];
98 }
99 for(i = 0; i < m; i++) {
100 scanf("%d%d", &x[i], &y[i]);
101 add(x[i], y[i]);
102 add(y[i], x[i]);
103 }
104 tarjan(0, -1);
105 if(cnt == 1) {cout << "impossible" << endl; continue;}
106 memset(head, 0, sizeof(head));
107 id = 1;
108 for(i = 0; i < m; i++) {
109 u = x[i];
110 v = y[i];
111 if(blong[u] == blong[v]) continue;
112 add(blong[u], blong[v]);
113 add(blong[v], blong[u]);
114 }
115 ans = inf;
116 dfs(1, 0);
117 printf("%d\n", ans);
118 }
119 return 0;
120 }


Ural 1018. Binary Apple Tree

题意:给一个苹果树(二叉树)每个树枝上有一定的苹果。要求剪掉Q个树枝后苹果树上剩下的苹果最多。

思路:转移方程很明显,设f[t][i]表示以t为根减掉i个苹果枝后的最优解。

当t从一个孩子而来时:

f[t][i] = max(f[l][i - 1] + map[l][t], f[r][i-1] + map[r][t]);

当t从两个孩子转移而来时:

for(j = 0; j <= i-2; j++) f[t][i] = max(f[t][i], f[l][j] + f[r][i-j-2] + map[l][t] + map[r][t]);

这里需要先存图,再建二叉数。用vector写了半天没写对,还是老老实实建二叉数吧。

MY CODE:

View Code
 1 #include <iostream>
2 #include <cstring>
3 #include <cstdio>
4
5 using namespace std;
6
7 const int N = 110;
8
9 struct node {
10 int l;
11 int r;
12 } node[N];
13
14 int f[N][N], map[N][N];
15 int n, q;
16
17 void creat(int t) {
18 int flag, i;
19 for(flag = 0, i = 1; i <= n; i++) {
20 if(map[t][i] && !node[i].l) {
21 flag = 1;
22 if(!node[t].l) node[t].l = i;
23 else {node[t].r = i; break;}
24 }
25 }
26 if(!flag) return ;
27 creat(node[t].l);
28 creat(node[t].r);
29 }
30
31 void dfs(int t) {
32 if(!node[t].l) return ;
33 int l, r, i, j, tmp;
34 l = node[t].l; r = node[t].r;
35 dfs(l); dfs(r);
36
37 f[t][1] = max(map[l][t], map[r][t]);
38 tmp = map[l][t] + map[r][t];
39
40 for(i = 2; i <= q; i++) {
41 f[t][i] = max(f[l][i-1] + map[l][t], f[r][i-1] + map[r][t]);
42 for(j = 0; j <= i - 2; j++) {
43 f[t][i] = max(f[t][i], f[l][j] + f[r][i-j-2] + tmp);
44 }
45 }
46 }
47
48 int main() {
49 //freopen("data.in", "r", stdin);
50
51 int i, x, y, z;
52 while(cin >> n >> q) {
53 memset(node, 0, sizeof(node));
54 memset(f, 0, sizeof(f));
55
56 for(i = 1; i < n; i++) {
57 scanf("%d%d%d", &x, &y, &z);
58 map[x][y] = map[y][x] = z;
59 }
60 creat(1);
61 dfs(1);
62 printf("%d\n", f[1][q]);
63 }
64 return 0;
65 }


poj 1947 Rebuilding Roads

终于把这道题折腾出来了。不过还是看了大牛的代码。。。

思路:f[i][j]表示以i为根保留j个结点所需切掉的边数

f[i][j] = min(f[i][j], f[i][j-k] + f[son[i]][k]);

MY CODE:

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4 #include <vector>
5
6 using namespace std;
7
8 const int N = 160;
9 const int inf = ~0u>>2;
10
11 struct node {
12 int c;
13 int next;
14 } g[N];
15
16 int f[N][N], head[N];
17 int n, m, ans, ind;
18 bool in[N];
19
20 void add(int u, int v) {
21 g[ind].c = v;
22 g[ind].next = head[u];
23 head[u] = ind++;
24 }
25
26 void dfs(int r) {
27 int i, j, k, c;
28 for(i = 1; i <= m; i++) {
29 f[r][i] = inf;
30 }
31 f[r][1] = 0;
32 for(i = head[r]; i; i = g[i].next) {
33 c = g[i].c;
34 //printf("%d %d\n", r, c);
35 dfs(c);
36 for(j = m; j >= 1; j--) {
37 f[r][j]++;
38 for(k = 1; k < j; k++) {
39 f[r][j] = min(f[r][j], f[r][j-k] + f[c][k]);
40 }
41 }//printf("%d %d\n", r, f[r][m]);
42 }
43
44 }
45
46 int main() {
47 //freopen("data.in", "r", stdin);
48
49 int i, x, y;
50 while(cin >> n >> m) {
51 memset(head, 0, sizeof(head));
52 memset(in, 0, sizeof(in));
53 ind = 1;
54 for(i = 1; i < n; i++) {
55 scanf("%d%d", &x, &y);
56 add(x, y);
57 //printf("%d %d\n", x, y);
58 in[y] = true;
59 }
60 for(i = 1; i <= n; i++) {
61 if(!in[i]) {
62 dfs(i);
63 ans = f[i][m];
64 break;
65 }
66 }
67 for(i = 1; i <= n; i++) {
68 ans = min(ans, f[i][m] + 1);
69 }
70 printf("%d\n", ans);
71
72 }
73 return 0;
74 }

 

poj 2486 Apple Tree

思路:定义go[i][j]表示以i为根,用掉j步不回到i可以吃到的苹果数,gb[i][j]表示以i为根,用掉j不且回到i可以吃到的苹果数

设当前结点为r,它的一个子结点为v,则可以分为三种情况:

1、回到r,且回到v (r->v , v->r多出两步)

2、不回r,回到v(先走到v且回v gb[v][j-k], 然后回到r,再走出去不回来go[v][k])(r->v , v->r多出两步)

3、不回r,回到v (既先走除v以外的其他子树并且回来gb[r][k],然后走到v,不会来go[v][j-k])(r->v 多出一步)

渣代码:

View Code
 1 #include <iostream>
2 #include <cstdio>
3 #include <cstring>
4
5 using namespace std;
6
7 const int N = 110;
8 const int M = 210;
9
10 struct node {
11 int to;
12 int next;
13 } g[M<<1];
14
15 int head[N], val[N];
16 int go[N][M], gb[N][M];
17 bool vis[N];
18 int t, n, k;;
19
20 void add(int u, int v) {
21 g[t].to = v;
22 g[t].next = head[u];
23 head[u] = t++;
24 }
25
26 void init() {
27 memset(vis, 0, sizeof(vis));
28 memset(val, 0, sizeof(val));
29 memset(go, 0, sizeof(go));
30 memset(gb, 0, sizeof(gb));
31 memset(head, 0, sizeof(head));
32 t = 1;
33 }
34
35 void dfs(int r) {
36 int v, i, j, l;
37 vis[r] = true;
38 for(i = 0; i <= k; i++) {
39 go[r][i] = gb[r][i] = val[r];
40 }
41 for(i = head[r]; i; i = g[i].next) {
42 v = g[i].to;
43 if(vis[v]) continue;
44 dfs(v);
45 for(j = k; j >= 0; j--) {
46 for(l = 0; l <= j; l++) {
47 gb[r][j+2] = max(gb[r][j+2], gb[v][l] + gb[r][j-l]);
48 go[r][j+2] = max(go[r][j+2], gb[v][l] + go[r][j-l]);
49 go[r][j+1] = max(go[r][j+1], go[v][l] + gb[r][j-l]);
50 }
51 }
52 }
53 }
54
55 int main() {
56 //freopen("data.in", "r", stdin);
57
58 int i, u, v;
59 while(~scanf("%d%d", &n, &k)) {
60 init();
61 for(i = 1; i <= n; i++) {
62 scanf("%d", val + i);
63 }
64 for(i = 1; i < n; i++) {
65 scanf("%d%d", &u, &v);
66 add(u, v);
67 add(v, u);
68 }
69 dfs(1);
70 printf("%d\n", max(go[1][k], gb[1][k]));
71 }
72 return 0;
73 }




posted @ 2012-01-10 11:07  AC_Von  阅读(307)  评论(0编辑  收藏  举报