HDU 5293 Tree chain problem 树形DP

题意:
给出一棵\(n\)个节点的树和\(m\)条链,每条链有一个权值。
从中选出若干条链,两两不相交,并且使得权值之和最大。

分析:
题解

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include <iostream>
#include <string>
using namespace std;
#define REP(i, a, b) for(int i = a; i < b; i++)
#define PER(i, a, b) for(int i = b - 1; i >= a; i--)
#define SZ(a) ((int)a.size())
#define MP make_pair
#define PB push_back
#define EB emplace_back
#define ALL(a) a.begin(), a.end()
typedef long long LL;
typedef pair<int, int> PII;

const int maxn = 100000 + 10;

int n, m;
vector<int> G[maxn], Q[maxn];
int u[maxn], v[maxn], w[maxn];
int l[maxn], r[maxn], dfs_clock;
int dp[maxn], sum[maxn];

int dep[maxn];
int anc[maxn][20];

void dfs(int u, int fa) {
    l[u] = ++dfs_clock;
    anc[u][0] = fa;
    for(int i = 0; anc[u][i]; i++)
        anc[u][i+1] = anc[anc[u][i]][i];
    dep[u] = dep[fa] + 1;
    for(int v : G[u]) if(v != fa) {
        dfs(v, u);
    }
    r[u] = ++dfs_clock;
}

int LCA(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    PER(i, 0, 20)
        if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
    if(u == v) return u;
    PER(i, 0, 20) if(anc[u][i] != anc[v][i])
        u = anc[u][i], v = anc[v][i];
    return anc[u][0];
}

int C[maxn << 1];
#define lowbit(x) (x&(-x))
void add(int x, int v) {
    while(x <= n * 2) {
        C[x] += v;
        x += lowbit(x);
    }
}
int query(int x) {
    int ans = 0;
    while(x) {
        ans += C[x];
        x -= lowbit(x);
    }
    return ans;
}

void init() {
    REP(i, 1, n + 1) G[i].clear(), Q[i].clear();
    memset(anc, 0, sizeof(anc));
    dfs_clock = 0;
    memset(C, 0, sizeof(C));
    memset(dp, 0, sizeof(dp));
    memset(sum, 0, sizeof(sum));
}

void upd(int& a, int b) { if(a < b) a = b; }

void solve(int x) {
    for(int y : G[x]) if(y != anc[x][0]) {
        solve(y);
        sum[x] += dp[y];
    }
    dp[x] = sum[x];
    for(int q : Q[x]) {
        upd(dp[x], sum[x] + query(l[u[q]]) + query(l[v[q]]) + w[q]);
    }
    add(l[x], sum[x] - dp[x]);
    add(r[x], dp[x] - sum[x]);
}

int main() {
    int T; scanf("%d", &T);
    while(T--) {
        scanf("%d%d", &n, &m);
        init();
        REP(i, 1, n) {
            int u, v; scanf("%d%d", &u, &v);
            G[u].PB(v);
            G[v].PB(u);
        }
        dfs(1, 0);
        REP(i, 0, m) {
            scanf("%d%d%d", u + i, v + i, w + i);
            int lca = LCA(u[i], v[i]);
            Q[lca].PB(i);
        }
        solve(1);
        printf("%d\n", dp[1]);
    }

    return 0;
}
posted @ 2017-06-06 23:38 AOQNRMGYXLMV 阅读(...) 评论(...) 编辑 收藏