Loading

P3177-HAOI树上染色

3177-树上染色

dp状态方程比较难想。

不如先写三维的dp

dp[i][j][k]表示以j为根节点,已经遍历到第i个子节点,有k个节点染成黑色的时候该子树对结果作出的贡献。

\[dp[now][fath][k]=max(dp[last][fath][k-p]+dp[all][son][p]+val \]

\[val=w[fath][son]*((siz[son]-p)*(n-m-(siz[son]-p))+(p*(m-p)))) \]

由于树形dp考虑边界实在是太难写了,于是借鉴[JSOI潜入行动][https://www.luogu.com.cn/problem/P4516]的写法,可以用一个临时数组存遍历到上一个子节点的时候更新得到的dp值,这个数组比dp数组少一维。然后就可以滚动掉第一维并且不用考虑遍历顺序了...

\[dp[fath][j+k]=max(lin[j]+dp[son][k]+val) \]

#include <bits/stdc++.h>
using namespace std;
#define ios ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define endl '\n'
#define debugg(x) cout<<#x<<'='<<x<<endl;
#define debug1(x,y,z) cout<<#x<<' '<<x<<' '<<#y<<' '<<y<<' '<<#z<<' '<<z<<endl;
#define debug cout<<endl<<"********"<<endl;
#define ll long long
#define ull unisgned long long
#define ld long double
#define itn int
#define pii pair<int,int>
#define rep(I, A, B) for (int I = (A); I <= (B); ++I)
#define don(I, A, B) for (int I = (A); I >= (B); --I)
#define mod (ll)(1e9+7)
#define mid ((lo+ro)>>1)
void fre() {
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
}
void fc() {
    fclose(stdin);
    fclose(stdout);
}
const int maxn = 2e3 + 10;
const ll inf = 0x7fffffff;
struct node {
    int v;
    int w;
};
int n, m;
vector<node> all[maxn];
ll dp[maxn][maxn];
ll lin[maxn];
void add(int u, int v, int w) {
    all[u].push_back((node) {
        v, w
    });
}
int dfs(int now, int fat) {
    if (all[now].size() == 1 && all[now][0].v == fat)
        return 1;

    int sum = 1, t = 0;

    for (int i = 0; i < all[now].size(); i++) {
        int v = all[now][i].v;

        if (v == fat)
            continue;

        t = dfs(v, now);
        ll w = all[now][i].w;
        rep(j, 0, min(sum, m)) lin[j] = dp[now][j];
        rep(j, 0, min(sum, m)) {
            rep(k, 0, min(m - j, t)) {
                dp[now][j + k] = max(dp[now][j + k], lin[j] + dp[v][k] + w * (k * (m - k) + (t - k) * (n - m - t + k)));
            }
        }
        sum += t;
    }

    return sum;
}
int main() {
    ios;
    cin >> n >> m;

    for (int i = 1; i < n; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w);
        add(v, u, w);
    }

    dfs(1, 0);
    cout << dp[1][m] << endl;
    return 0;
}
posted @ 2021-04-01 16:32  14long  阅读(58)  评论(0)    收藏  举报