HDU 4616(树形DP)

这道题目麻烦的地方是陷阱的处理,用dp[ u ][ j ][ 0/1 ]表示以u为根的某一子节点经过j个陷阱后到达u的最大权值和,0/1表示起点是否有陷阱。

在dfs的过程中,当处理到u的儿子v的时候,先去用dp[u]和dp[v]的和去更新ans。然后再用dp[v]更新dp[u]。这样相当于子链的连接。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

#define Max 1000001
#define MAXN 1001009
#define MOD 1000000007
#define rin freopen("in.txt","r",stdin)
#define rout freopen("1.out","w",stdout)
#define Del(a,b) memset(a,b,sizeof(a))
typedef long long LL;
using namespace std;
const int N = 100005;
int T;
struct Node {
    int next;
    int to;
} e[MAXN];
int tot, ans, c;
int head[N], vis[N];
int dp[N][4][3];
int trap[N];
int gift[N];
void Init(int n) {
    Del(head, -1);
    ans = 0;
    tot = 0;
    Del(dp,0);
}
void addedge(int from, int to) {
    e[tot].to = to;
    e[tot].next = head[from];
    head[from] = tot++;
}
void dfs(int u, int father) {
    dp[u][trap[u]][0]=gift[u];
    dp[u][trap[u]][1]=gift[u];
    for (int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].to;
        if (v == father)
            continue;
        dfs(v, u);
        for (int j = 0; j <= c; j++) {
            for (int k = 0; j + k <= c; k++) {
                if (j != c)
                    ans = max(ans, dp[u][j][0] + dp[v][k][1]);
                if (k != c)
                    ans = max(ans, dp[u][j][1] + dp[v][k][0]);
                if (j + k < c)
                    ans = max(ans, dp[u][j][0] + dp[v][k][0]);  //起点和终点都可以为非陷阱
                if (k + j <= c)
                    ans = max(ans, dp[u][j][1] + dp[v][k][1]);  //起点和终点都可以为陷阱
            }
        }

        for (int j = 0; j + trap[u] <= c; j++) {
            dp[u][j + trap[u]][0] = max(dp[u][j + trap[u]][0],
                    dp[v][j][0] + gift[u]);
            if (j != 0) {
                dp[u][j + trap[u]][1] = max(dp[u][j + trap[u]][1],
                        dp[v][j][1] + gift[u]);
            }
        }
    }

}
int main() {
    //rin;
    int n, T;
    while (cin >> T) {
        while (T--) {
            scanf("%d%d", &n, &c);
            Init(n);
            for (int i = 0; i < n; i++) {
                scanf("%d%d", &gift[i], &trap[i]);
            }
//            for (int i = 0; i < n; i++) {
//                printf("%d %d\n", gift[i], trap[i]);
//            }
            for (int i = 1; i < n; i++) {
                int x, y;
                scanf("%d%d", &x, &y);
                addedge(y, x);
                addedge(x, y);
                //printf("%d %d\n", x, y);
            }
            dfs(0, -1);
            printf("%d\n", ans);
        }
    }
    return 0;
}

 

posted @ 2017-08-18 15:06  Belleaholic  阅读(213)  评论(0)    收藏  举报