【树形DP】codeforces K. Send the Fool Further! (medium)

http://codeforces.com/contest/802/problem/K

【题意】

给定一棵树,Heidi从根结点0出发沿着边走,每个结点最多经过k次,求这棵树的最大花费是多少(同一条边走n次花费只算一次)

【思路】

对于结点v:

  • 如果在v的某棵子树停下,那么可以“遍历”k棵子树(有的话)
  • 如果还要沿着v返回v的父节点p,那么只能“遍历”k-1棵子树(有的话)。

用dp[v][1]表示第一种情况,dp[v][0]表示第二种情况;最后要求的就是dp[0][0]。

1. 对于dp[v][1],把所有的子树从大到小排序

(t=k-1)

2. 对于dp[v][0],枚举子结点dp[u][0]中的u,剩下的k-1个dp[u][1]取最大的,所以我们可以这样预处理:

sum=

(t=k)

  • 如果u<k,则target=sum-dp[u][1]+dp[u][0]
  • 否则,        target=sum-dp[t][1]+dp[u][0](t是从大到小排序后的第k-1个)

这样,dp[0][0]就是所求结果(dp[0][0]一定大于dp[0][1]),时间复杂度是O(nlogn)

【官方题解】

【Accepted】

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<string>
  5 #include<cmath>
  6 #include<vector>
  7 #include<algorithm>
  8 
  9 using namespace std;
 10 int n,m;
 11 vector< vector< pair<int,int> > > g;
 12 const int maxn=1e5+5;
 13 int dp[maxn][2];
 14 void dfs(int v,int p,int edge)
 15 {
 16     //从p到v的花费要算在v里
 17     dp[v][0]+=edge;
 18     dp[v][1]+=edge;
 19     vector< pair<int,int> > s;
 20     //只有根结点没有父节点,非根结点有父节点,减去1
 21     if(v==0)
 22     {
 23         s.resize(g[v].size());
 24     }
 25     else
 26     {
 27         s.resize(g[v].size()-1);
 28     }
 29     //遍历
 30     int num=0;
 31     for(int i=0;i<g[v].size();i++)
 32     {
 33         int to=g[v][i].first;
 34         if(to==p)
 35         {
 36             continue;
 37          }
 38         dfs(to,v,g[v][i].second);
 39         s[num++]={dp[to][1],to};
 40     }
 41     //从大到小排序
 42     sort(s.begin(),s.end());
 43     reverse(s.begin(),s.end());
 44     //要记录各个子结点的rank,后面dp[v][0]枚举u是要分类
 45     int pos[maxn];
 46     for(int i=0;i<s.size();i++)
 47     {
 48         pos[s[i].second]=i;
 49     }
 50     //计算dp[v][1]
 51     for(int i=0;i<min(m-1,(int)s.size());i++)
 52     {
 53         dp[v][1]+=s[i].first;
 54     }
 55     //计算dp[v][0]
 56     int sum=0;
 57     for(int i=0;i<min(m,(int)s.size());i++)
 58     {
 59         sum+=s[i].first;
 60     }
 61     int maxu=-1;
 62     //枚举
 63     for(int i=0;i<g[v].size();i++)
 64     {
 65         int to=g[v][i].first;
 66         if(to==p)
 67         {
 68             continue;
 69         }
 70         if(pos[to]<m)
 71         {
 72             maxu=max(maxu,sum-dp[to][1]+dp[to][0]);
 73         }
 74         else
 75         {
 76             maxu=max(maxu,sum-s[m-1].first+dp[to][0]);
 77         }
 78     }
 79     if(maxu>-1)
 80     {
 81         dp[v][0]+=maxu;
 82     }
 83 }
 84 int main()
 85 {
 86     while(~scanf("%d%d",&n,&m))
 87     {
 88         memset(dp,0,sizeof(dp));
 89         g.resize(n);
 90         int u,v,c;
 91         for(int i=0;i<n-1;i++)
 92         {
 93             scanf("%d%d%d",&u,&v,&c);
 94             g[u].push_back({v,c});
 95             g[v].push_back({u,c});
 96         }
 97         //根结点为0,无父结点,根结点到父结点的花费也为0
 98         dfs(0,0,0);
 99         printf("%d\n",dp[0][0]);
100     }
101     return 0;
102  }
View Code

注意vector开始要resize.....orz

【WA】

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<string>
 5 #include<algorithm>
 6 #include<cmath>
 7 
 8 using namespace std;
 9 int n,k;
10 const int maxn=2e5+3; 
11 struct edge
12 {
13     int to;
14     int nxt;
15     int c;    
16 }e[maxn];
17 int head[maxn];
18 int tot;
19 struct node
20 {
21     int x;
22     int id;
23 }sz[maxn];
24 int rk[maxn];
25 bool cmp(node a,node b)
26 {
27     return a.x>b.x;
28 }
29 void init()
30 {
31     memset(head,-1,sizeof(head));
32     tot=0;
33 }
34 
35 void add(int u,int v,int c)
36 {
37     e[tot].to=v;
38     e[tot].c=c;
39     e[tot].nxt=head[u];
40     head[u]=tot++;
41 }
42 int dp[maxn][2];
43 
44 int dfs(int u,int pa,int c)
45 {
46     dp[u][1]=c;
47     dp[u][0]=c;
48     int cnt=0;
49     for(int i=head[u];i!=-1;i=e[i].nxt)
50     {
51         int v=e[i].to;
52         int c=e[i].c;
53         if(v==pa) continue;
54         dfs(v,u,c);
55         sz[cnt].x=dp[v][1];
56         sz[cnt++].id=v;
57      } 
58      sort(sz,sz+cnt,cmp);
59      for(int i=0;i<min(cnt,k-1);i++)
60      {
61          dp[u][1]+=sz[i].x;
62      }
63      int sum=0;
64      for(int i=0;i<min(cnt,k);i++)
65      {
66          sum+=sz[i].x;
67      }
68      int ans=0;
69     for(int i=0;i<cnt;i++)
70     {
71         if(i<k)
72         {
73             ans=max(ans,sum-sz[i].x+dp[sz[i].id][0]);
74         }
75         else
76         {
77             ans=max(ans,sum-sz[k-1].x+dp[sz[i].id][0]);
78         }
79     }
80      dp[u][0]+=ans;
81 }
82 int main()
83 {
84     while(~scanf("%d%d",&n,&k))
85     {
86         init();
87         memset(dp,0,sizeof(dp));
88         for(int i=0;i<n-1;i++)
89         {
90             int u,v,c;
91             scanf("%d%d%d",&u,&v,&c);
92             add(u,v,c);
93             add(v,u,c);
94         }
95         dfs(0,-1,0);
96         cout<<dp[0][0]<<endl; 
97     }
98     return 0;    
99 } 
Wrong Answer

终于弄清楚了这个为什么WA!因为我在dfs里用了一个全局变量sz来保存{dp[v][1],v}。然而这是一个全局变量,所以一层里的正确值会被另一层修改!比如当我递归到0时已经有了正确值sz[0].w=5,sz[0].v=2;然而再递归到0的另一分枝1的时候,会修改sz[0],最后再回溯到0时sz[0]已经不是当年的sz[0]了!

所以还是用vector临时申请吧!

【AC(一个更优美的代码)】

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<string>
  5 #include<algorithm>
  6 #include<cmath>
  7 
  8 using namespace std;
  9 int n,k;
 10 const int maxn=2e5+3; 
 11 struct edge
 12 {
 13     int to;
 14     int nxt;
 15     int c;    
 16 }e[maxn];
 17 int head[maxn];
 18 int tot;
 19 int dp[maxn][2];
 20 
 21 struct node
 22 {
 23     int x;
 24     int id;
 25     node(){}
 26     node(int _x,int _id):x(_x),id(_id){}
 27     bool operator<(const node & nd) const
 28     {
 29         return x>nd.x;
 30     }
 31 };
 32 
 33 void init()
 34 {
 35     memset(head,-1,sizeof(head));
 36     tot=0;
 37 }
 38 
 39 void add(int u,int v,int c)
 40 {
 41     e[tot].to=v;
 42     e[tot].c=c;
 43     e[tot].nxt=head[u];
 44     head[u]=tot++;
 45 }
 46 
 47 int dfs(int u,int pa,int c)
 48 {
 49     dp[u][1]=c;
 50     dp[u][0]=c;
 51     vector<node> s;
 52     for(int i=head[u];i!=-1;i=e[i].nxt)
 53     {
 54         int v=e[i].to;
 55         int c=e[i].c;
 56         if(v==pa) continue;
 57         dfs(v,u,c);
 58         s.push_back(node(dp[v][1],v));
 59      } 
 60      sort(s.begin(),s.end());
 61      int sz=s.size();
 62      for(int i=0;i<min(sz,k-1);i++)
 63      {
 64          dp[u][1]+=s[i].x;
 65      }
 66      int sum=0;
 67      for(int i=0;i<min(sz,k);i++)
 68      {
 69          sum+=s[i].x;
 70      }
 71      int ans=0;
 72     for(int i=0;i<sz;i++)
 73     {
 74         if(i<k)
 75         {
 76             ans=max(ans,sum-s[i].x+dp[s[i].id][0]);
 77         }
 78         else
 79         {
 80             ans=max(ans,sum-s[k-1].x+dp[s[i].id][0]);
 81         }
 82     }
 83      dp[u][0]+=ans;
 84 }
 85 int main()
 86 {
 87     while(~scanf("%d%d",&n,&k))
 88     {
 89         init();
 90         memset(dp,0,sizeof(dp));
 91         for(int i=0;i<n-1;i++)
 92         {
 93             int u,v,c;
 94             scanf("%d%d%d",&u,&v,&c);
 95             add(u,v,c);
 96             add(v,u,c);
 97         }
 98         dfs(0,-1,0);
 99         cout<<dp[0][0]<<endl; 
100     }
101     return 0;    
102 } 
View Code

如果是vector<pair<int,int>> 要从大到小排序,可以先sort(s.begin(),s.end()),再reverse(s.begin(),s.end())

 

posted @ 2017-06-01 14:00  shulin15  阅读(333)  评论(0编辑  收藏  举报