[树形dp] Jzoj P3347 树的难题

Description

 

Input

输入文件 为split.in 。
第一行 包含 一个正整数 T,表示有T组测试数据 。接下来 依次是 T组测试数 据。
每组测试数 据的第一行包含个正整数N。
第二行包含 N个 0、1、2之一 的整数,依次 表示点 1到点 N的颜色。其中0表示黑色, 1表示白色, 2表示灰色。
接下来 N-1行 ,每行为三个整数 ui、vi、ci,表示 一条权值等于 ci的边 (ui, vi)。
 

Output

输出文件为 split.out 。
输出 T行 ,每一个整数, 依次 表示 每组测试数据 的答案。
 

Sample Input

1
5
0 1 1 1 0
1 2 5
1 3 3
5 2 5
2 4 16

Sample Output

10
【样例解释】
花费 10 的代价删去 边(1, 2)和边(2, 5)。
 
 

Data Constraint

对于 10% 的数据: 1 ≤ N ≤ 10。
对于 30% 的数据: 1 ≤ N ≤ 50 0。
对于 60% 的数据: 1 ≤ N ≤ 50 000 。
对于 100% 的数据: 1 ≤ N ≤ 300 000 ,1 ≤ T ≤ 5,0 ≤ ci ≤ 10^9。

 

题解

 

  • 显然就是个树形dp
  • 设置状态f[i][0],f[i][1],f[i][2]分别表示使以i为根的原树的子树中删去一些边后合法时,以i为根的当前子树中
  • ①不含黑色节点
  • ②含有1或0个白色节点
  • ③不含有白色节点
  • 再进行状态转移即可
  • bzoj一交就过,纪中STM爆栈搞得我又要打bfs

代码

 1 #include <cstdio>
 2 #include <cstring>
 3 #include <iostream>
 4 #define min(a,b) (a<b?a:b)
 5 #define mem(a,b) memset(a,b,sizeof(a))
 6 #define ll long long
 7 #define fb(i,x) for(i=head[x];i;i=nex[i])
 8 #define fo(i,a,b) for(i=a;i<=b;i++)
 9 using namespace std;
10 const ll N=3e5+10,inf=1e9*N;
11 ll n,x,y,tail,cnt,ans,t,head[N],col[N],f[N],g[N],h[N],Q[N],cur[N],sum[N],fa[N];
12 struct edge { int to,from,v; }e[N*2];
13 void insert(ll x,ll y,ll z)
14 {
15     e[++cnt].to=y,e[cnt].v=z,e[cnt].from=head[x],head[x]=cnt;
16     e[++cnt].to=x,e[cnt].v=z,e[cnt].from=head[y],head[y]=cnt;
17 }
18 void dfs()
19 {
20     Q[tail=1]=1,cur[1]=head[1];
21     if (!col[1]) f[1]=inf; 
22     if (col[1]==1) g[1]=inf;
23     while (tail)
24     {
25         x=Q[tail];
26         if (cur[x])
27         {
28             y=e[cur[x]].to;
29             if (y==Q[tail-1]) { cur[x]=e[cur[x]].from; continue; }
30             fa[y]=e[cur[x]].v,Q[++tail]=y,cur[y]=head[y],cur[x]=e[cur[x]].from;
31             if (!col[y]) f[y]=inf;
32             if (col[y]==1) g[y]=inf;
33         }
34         else
35         {
36             y=x,x=Q[--tail];
37             if (col[y]!=1) h[y]=g[y]-sum[y];
38             if (!col[x]) g[x]+=min(min(f[y],h[y])+fa[y],g[y]),sum[x]=max(min(min(f[y],h[y])+fa[y],g[y])-h[y],sum[x]);
39             if (col[x]==1) f[x]+=min(min(g[y],h[y])+fa[y],f[y]),h[x]+=min(min(f[y],h[y])+fa[y],g[y]);
40             if (col[x]==2) f[x]+=min(min(g[y],h[y])+fa[y],f[y]),g[x]+=min(min(f[y],h[y])+fa[y],g[y]),sum[x]=max(min(min(f[y],h[y])+fa[y],g[y])-h[y],sum[x]);
41         }
42     }
43 }
44 int main()
45 {
46     for(scanf("%lld",&t);t;t--)
47     {
48         mem(head,0),cnt=0,mem(f,0),mem(g,0),mem(h,0),mem(cur,0),mem(sum,0);
49         scanf("%lld",&n);
50         for (ll i=1;i<=n;i++) scanf("%lld",&col[i]);
51         for (ll i=1,x,y,z;i<n;i++) scanf("%lld%lld%lld",&x,&y,&z),insert(x,y,z);
52         dfs(),printf("%lld\n",min(min(f[1],g[1]),h[1]));
53     }
54 }

 

posted @ 2019-07-10 16:12 BEYang_Z 阅读(...) 评论(...) 编辑 收藏