bzoj 4033 树上染色 - 树形动态规划

  有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑
色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的
收益。问收益最大值是多少。

Input

第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。
输入保证所有点之间是联通的。
N<=2000,0<=K<=N

Output

输出一个正整数,表示收益的最大值。

Sample Input

5 2
1 2 3
1 5 1
2 3 1
2 4 2

Sample Output

17
【样例解释】
将点1,2染黑就能获得最大收益。

  动态规划的第一步——设计状态,f[i][j]表示以i节点为根的子树中染了j个黑点的"收益"。

  不过这样没有黑点的位置,这么多个点,总不可能用N进制来表示点的位置。所以只能换个思路。

  对于当前考虑的这棵子树,我知道染了j个节点,那么我知道在这棵子树内的白点数和子树外的白点数和黑点数。因此我可以计算出节点i到它的父节点的那条边的对答案的贡献,对于子节点转移到父节点就是一个用dp合并的过程,因此解决了状态转移的问题,时间复杂度为O(nk)。

  注意dp时不合法的状态一定不能转移(看代码吧,或者自己想想也可以,状态转移前有个if)

  (现在觉得以前的树归写得好丑)

Code

  1 /**
  2  * bzoj
  3  * Problem#4033
  4  * Accepted
  5  * Time:630ms
  6  * Memory:17092k
  7  */
  8 #include<iostream>
  9 #include<fstream> 
 10 #include<sstream>
 11 #include<algorithm>
 12 #include<cstdio>
 13 #include<cstring>
 14 #include<cstdlib>
 15 #include<cctype>
 16 #include<cmath>
 17 #include<ctime>
 18 #include<map>
 19 #include<stack>
 20 #include<set>
 21 #include<queue>
 22 #include<vector>
 23 #ifndef WIN32
 24 #define AUTO "%lld"
 25 #else
 26 #define AUTO "%I64d"
 27 #endif
 28 using namespace std;
 29 typedef bool boolean;
 30 #define inf 0xfffffff
 31 #define smin(a, b) (a) = min((a), (b))
 32 #define smax(a, b) (a) = max((a), (b))
 33 template<typename T>
 34 inline boolean readInteger(T& u) {
 35     char x;
 36     int aFlag = 1;
 37     while(!isdigit((x = getchar())) && x != '-' && x != -1);
 38     if(x == -1)    {
 39         ungetc(x, stdin);
 40         return false;
 41     }
 42     if(x == '-') {
 43         aFlag = -1;
 44         x = getchar();
 45     }
 46     for(u = x - '0'; isdigit((x = getchar())); u = u * 10 + x - '0');
 47     u *= aFlag;
 48     ungetc(x, stdin);
 49     return true;
 50 }
 51 
 52 ///map template starts
 53 typedef class Edge{
 54     public:
 55         int end;
 56         int next;
 57         int w;
 58         Edge(const int end = 0, const int next = 0, const int w = 0):end(end), next(next), w(w){}
 59 }Edge;
 60 
 61 typedef class MapManager{
 62     public:
 63         int ce;
 64         int *h;
 65         Edge *edge;
 66         MapManager(){}
 67         MapManager(int points, int limit):ce(0){
 68             h = new int[(const int)(points + 1)];
 69             edge = new Edge[(const int)(limit + 1)];
 70             memset(h, 0, sizeof(int) * (points + 1));
 71         }
 72         inline void addEdge(int from, int end, int w){
 73             edge[++ce] = Edge(end, h[from], w);
 74             h[from] = ce;
 75         }
 76         inline void addDoubleEdge(int from, int end, int w){
 77             addEdge(from, end, w);
 78             addEdge(end, from, w);
 79         }
 80         Edge& operator [] (int pos) {
 81             return edge[pos];
 82         }
 83 }MapManager;
 84 #define m_begin(g, i) (g).h[(i)]
 85 ///map template ends
 86 
 87 template<typename T>class Matrix{
 88     public:
 89         T *p;
 90         int lines;
 91         int rows;
 92         Matrix():p(NULL){    }
 93         Matrix(int rows, int lines):lines(lines), rows(rows){
 94             p = new T[(lines * rows)];
 95         }
 96         T* operator [](int pos){
 97             return (p + pos * lines);
 98         }
 99 };
100 #define matset(m, i, s) memset((m).p, (i), (s) * (m).lines * (m).rows)
101 
102 int n, k;
103 MapManager g;
104 Matrix<long long> f;
105 int* size;
106 
107 inline void init() {
108     readInteger(n);
109     readInteger(k);
110     g = MapManager(n, 2 * n);
111     f = Matrix<long long>(n + 1, k + 1);
112     size = new int[(const int)(n + 1)];
113     matset(f, 0, sizeof(long long));
114     for(int i = 1, a, b, c; i < n; i++) {
115         readInteger(a);
116         readInteger(b);
117         readInteger(c);
118         g.addDoubleEdge(a, b, c);
119     }
120 }
121 
122 void treedp(int node, int fa, int len) {
123     size[node] = 1;
124     for(int i = m_begin(g, node); i != 0; i = g[i].next) {
125         int& e = g[i].end;
126         if(e == fa)     continue;
127         treedp(e, node, g[i].w);
128         size[node] += size[e];
129         for(int j = min(size[node], k); j >= 0; j--) {
130             for(int s = 0; s <= size[e] && s <= j; s++) {
131                 if(j - s <= size[node] - size[e])
132                     smax(f[node][j], f[node][j - s] + f[e][s]);
133             }
134         }
135     }
136     for(int i = 0; i <= min(size[node], k); i++)
137             f[node][i] += (i * 1LL * (k - i) + (size[node] - i) * 1LL * (n - k - size[node] + i)) * len;
138 }
139 
140 inline void solve() {
141     treedp(1, 0, 0);
142     printf(AUTO"\n", f[1][k]);
143 }
144 
145 int main() {
146     init();
147     solve();
148     return 0;
149 }
posted @ 2017-03-22 21:56  阿波罗2003  阅读(480)  评论(0编辑  收藏  举报