POJ 1848 (一道不错的树形dp)

题意:N个点的一颗树。问最少添加多少条边可以让每个点都在一个(且仅一个)环中。

不得不佩服,这题dp设计出来的人。。。偶是弱菜,只能膜拜了。

这位大牛的解说,很详细:http://hi.baidu.com/19930705cxjff/blog/item/1df66e4a4ff3022e08f7ef5d.html
 

 

首先明确一点,题中的环至少需要3个顶点。因此,对于树中的每个顶点,有3种状态。
f[x][0]表示以x为根的树,变成每个顶点恰好在一个环中的图,需要连的最少边数。
f[x][1]表示以x为根的树,除了根x以外,其余顶点变成每个顶点恰好在一个环中的图,需要连的最少边数。
f[x][2]表示以x为根的树,除了根x以及和根相连的一条链(算上根一共至少2个顶点)以外,其余顶点变成每个顶点恰好在一个环中的图,需要连的最少边数。
有四种状态转移(假设正在考虑的顶点是R,有k个儿子):
A.根R的所有子树自己解决(取状态0),转移到R的状态1。即R所有的儿子都变成每个顶点恰好在一个环中的图,R自己不变。



B.根R的k-1个棵树自己解决,剩下一棵子树取状态1和状态2的最小值,转移到R的状态2。剩下的那棵子树和根R就构成了长度至少为2的一条链。


C.根R的k-2棵子树自己解决,剩下两棵子树取状态1和状态2的最小值,在这两棵子树之间连一条边,转移到R的状态0。


D.根R的k-1棵子树自己解决,剩下一棵子树取状态2(子树里还剩下长度至少为2的一条链),在这棵子树和根之间连一条边,构成一个环,转移到R的状态0。

 
 
ps:写代码要有胆量,不要怕TLE。。。直接暴力枚举这些值就行。

附代码:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <cstring>
#include <algorithm>
#include <string>
#include <set>
#include <ctime>
#include <queue>
#include <map>
#include <sstream>

#define CL(arr, val)    memset(arr, (val), sizeof(arr))
#define REP(i, n)       for((i) = 0; (i) < (n); ++(i))
#define FOR(i, l, h)    for((i) = (l); (i) <= (h); ++(i))
#define FORD(i, h, l)   for((i) = (h); (i) >= (l); --(i))
#define L(x)    (x) << 1
#define R(x)    (x) << 1 | 1
#define MID(l, r)   ((l) + (r)) >> 1
#define Min(x, y)   (x) < (y) ? (x) : (y)
#define Max(x, y)   (x) < (y) ? (y) : (x)
#define E(x)    (1 << (x))
#define iabs(x)  ((x) > 0 ? (x) : -(x))

typedef long long LL;
const double eps = 1e-12;
const int inf = 10000;  

using namespace std;

const int N = 110;

struct node {
    int to;
    int next;
} g[N<<2];    

int head[N], t;

bool vis[N];
int f[N][3];

void init() {
    CL(head, -1); t = 0;
}

void add(int u, int v) {
    g[t].to = v; g[t].next = head[u]; head[u] = t++;
}

void dfs(int t) {
    int i, j, k, v, sum = 0, tmp, n;
    vector<int> ch;
    for(i = head[t]; i != -1; i = g[i].next) {
        v = g[i].to;
        if(!vis[v]) {
            vis[v] = true;
            dfs(v);
            sum += f[v][0];
            ch.push_back(v);
        }
    }
    n = ch.size();
    if(n == 0) {
        //printf("%d here!\n", t);
        f[t][0] = inf;
        f[t][1] = 0;
        f[t][2] = inf;
        return ;
    }


    f[t][1] = min(inf, sum);
    f[t][0] = f[t][2] = inf;

    for(i = 0; i < n; ++i) {
        v = ch[i];
        f[t][2] = min(f[t][2], sum - f[v][0] + min(f[v][1], f[v][2]));
        f[t][0] = min(f[t][0], sum - f[v][0] + f[v][2] + 1);
    }
    for(i = 0; i < n; ++i) {
        v = ch[i];
        for(j = 0; j < n; ++j) {
            if(i == j)  continue;
            k = ch[j];
            tmp = sum - f[v][0] - f[k][0] + min(f[v][1], f[v][2]) + min(f[k][1], f[k][2]);
            f[t][0] = min(f[t][0], tmp + 1);
        }
    }
}

int main() {
    //freopen("data.in", "r", stdin);

    int n, x, y, i;
    while(~scanf("%d", &n)) {
        init();
        for(i = 1; i < n; ++i) {
            scanf("%d%d", &x, &y);
            add(x, y); add(y, x);
        }
        CL(vis, false); vis[1] = true;
        dfs(1);
        //for(i = 1; i <= n; ++i) printf("%d | %-11d %-11d %-11d\n", i, f[i][0], f[i][1], f[i][2]);
        if(f[1][0] >= inf)  puts("-1");
        else    printf("%d\n", f[1][0]);
    }
    return 0;
}

 

 
 
 

 

 

posted @ 2012-08-12 16:23  AC_Von  阅读(973)  评论(0编辑  收藏  举报