BZOJ 3257 树的难题

BZOJ 3257

题目描述

给定一棵树,树上每个点有黑,白,灰三种颜色。边有边权。要求切割某些边,使得每个连通块内没有黑点或者至多有1个白点。最小化切割代价。

解题思路

每个子树显然可以分开单独考虑。

\(f[u][i][j]\) 表示子树内有i个黑点,j个白点的最小代价。 \(i \leq 1, j\leq2\)

这样直接合并就好,但是要注意,白点合并的总个数大于2,也可以放到等于2的情况。

枚举子树内黑白点的状况,类似背包。

#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 3e5 + 11;
int n, col[N];
LL f[N][2][3], g[N], tmp[2][3], F[2][3], G[2][3];
int head[N], nex[N<<1], to[N<<1], wei[N<<1], size;
void add(int x, int y, int z){
    to[++size] = y;
    nex[size] = head[x];
    head[x] = size;
    wei[size] = z;
}
void dfs(int u, int fa){
    if(col[u] == 0) f[u][1][0] = 0;
    else if(col[u] == 1) f[u][0][1] = 0;
    else f[u][0][0] = 0;
    for(int i = head[u];i;i = nex[i]){
        int v = to[i];
        if(v == fa)continue;
        dfs(v, u);
        memset(tmp, 127 / 2, sizeof tmp);
        memcpy(F, f[u], sizeof F);
        memcpy(G, f[v], sizeof G);
        for(int a = 0;a <= 1; a++){
            for(int b = 0;b <= 1; b++){
                for(int c = 0;c <= 1; c++){
                    for(int d = 0;d <= 1; d++){
                        if(b + d != 2)tmp[a|c][b|d] = min(tmp[a|c][b|d], F[a][b] + G[c][d]);
                    }
                }
            }
        }
        for(int a = 0;a <= 1; a++){
            for(int b = 0;b <= 1; b++){
                tmp[a][b] = min(tmp[a][b], F[a][b] + g[v] + wei[i]);
            }
        }
        tmp[0][2] = min(tmp[0][2], F[0][2] + wei[i] + g[v]);
        for(int j = 0;j <= 2; j++){
            for(int k = 0;k <= 2; k++){
                if(j + k > 1)tmp[0][2] = min(tmp[0][2], F[0][j] + G[0][k]);
            }
        }
        memcpy(f[u], tmp, sizeof tmp);
    }
    for(int i = 0;i <= 1; i++){
        for(int j = 0;j <= 2; j++){
            g[u] = min(g[u], f[u][i][j]);
        }
    }
    /*printf("u=%d g=%lld\n", u, g[u]);
    for(int i = 0;i <= 1; i++){
        for(int j = 0;j <= 2; j++){
            printf("f[%d][%d]=%lld\n", i, j, f[u][i][j]);
        }
    }*/
}
void work(){
    memset(head, 0, sizeof head);
    size = 0;
    memset(f, 127 / 2, sizeof f);
    memset(g, 127 / 2, sizeof g);
    int u, v, w;
    cin>>n;
    for(int i = 1;i <= n; i++){
        scanf("%d", &col[i]);
    }
    for(int i = 1;i < n; i++){
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w); add(v, u, w);
    }
    dfs(1, 0);
    cout<<g[1]<<endl;
}
int main(){
    int T;
    cin>>T;
    while(T--){
        work();
    }
    return 0;
}
posted @ 2020-08-04 09:44  LawrenceD  阅读(101)  评论(0)    收藏  举报