树形依赖背包的两种做法

  今天才发现自己根本不会树形背包,我太菜了。

  一般的树形背包是这样做的:

  看上去,它的复杂度是 $O(nk^2)$ 的。

 

第一种优化:

  这里,如果第二维的大小和子树大小有关,同时又不超过一个常数 $k$ 。例如:第二维表示子树内选了多少个点,那么通过一些精妙的分析和上界优化,复杂度就可以变成 $O(nk)$ 了。

  以下的 $siz_x$ 表示合并 $son$ 这个子树前 $x$ 子树的大小(注意:不是 $x$ 的真实子树大小,这里很重要)。

  这样分析出来的复杂度就是 $O(nk)$ .


  证明:摘自这里

  首先,定义 $T(n)$ 为处理 $n$ 这棵子树时所用的时间,$f(n)$ 为处理 $n$ 这个点时所用的时间。

  $T(x)=\left(\sum_{f_y=x} T_{y}\right)+f(x)\\f(x)=\min(m,siz(y_1))\times \min(m,siz(y_1))+\min(m,siz(y_1)+siz(y_2))\times \min(m,siz(y_1))\\ ~~~~~~~~~~~+\cdots+\min(m,siz(x))\times \min(m,siz(y_n))$

  现在进行一番放缩,把每个乘法的前一项统一变成 $\min(m,siz(x))$ ,这样显然只会使答案变大,所以分析出来的复杂度上界就应该是正确的。

  $f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} \min(m,siz(y))\right)$

  再次放缩,把后面括号里的 $min$ 直接扔掉,得:

  $f(x)=\min(m,siz(x))\times \left(\sum\limits_{f_y=x} siz(y)\right)\\~~~~~~~~=\min(m,siz(x))\times siz(x)$

  对于 $siz(x)<m$ 的点,首先考虑他的子树都是叶子的情况:

  $T(x)=siz(x)^2+\sum 1$

  对于任意 $siz(x)<m$ 的点,递归证明,由于 “平方和小于和的平方” ,所以 $T(x)$ 与 $siz(x)^2$ 同阶;

  对于 $siz(x)>m$ 的点,首先考虑它的所有子树都小于 $m$ 的情况:

  $T(x)=m\times siz(x)+\sum siz(j)^2$

  接着放缩可得,$T(x)$ 与 $m\times siz(x)$ 同阶;

  继续使用递归证明的技巧,考虑某一层出现了子树大于 $m$ 的情况:

  $T(x)=m\times siz(x)+\sum siz(j)^2+\sum siz(j)\times m$

  所以,$T(x)$ 还是与 $m\times siz(x)$ 同阶;

  综上所述,这种做法的复杂度是 $n\times k$ 。


  选课加强版:https://www.luogu.org/problem/U53204

  
 1 # include <cstdio>
 2 # include <iostream>
 3 # include <cstring>
 4 # include <vector>
 5 # define R register int
 6 
 7 using namespace std;
 8 
 9 const int N=100005;
10 struct edge
11 {
12     int to,nex;
13 };
14 int si,h=0,n,m;
15 edge g[N<<1];
16 int firs[N],a[N],siz[N];
17 bool vis[N];
18 int dp[100000100];
19 
20 void add(int u,int v)
21 {
22     g[++h].to=v;
23     g[h].nex=firs[u];
24     firs[u]=h;
25 }
26 
27 void dfs(int x)
28 {
29     dp[x*(m+1)+1]=a[x];
30     siz[x]=1;
31     vis[x]=true;
32     int j;
33     for (R i=firs[x];i;i=g[i].nex)
34     {
35         j=g[i].to;
36         if(vis[j]) continue;
37         dfs(j);
38         for (int k=min(siz[x]+siz[j],m);k>=1;--k)
39             for (int z=max(1,k-siz[x]);z<=min(siz[j],k-1);++z)
40                 dp[x*(m+1)+k]=max(dp[x*(m+1)+k],dp[x*(m+1)+k-z]+dp[j*(m+1)+z]);
41         siz[x]+=siz[j];
42     }
43 }
44 
45 int read()
46 {
47     int x=0;
48     char c=getchar();
49     while (!isdigit(c)) c=getchar();
50     while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
51     return x;
52 }
53 
54 int main()
55 {
56     scanf("%d%d",&n,&m); m++;
57     memset(g,0,sizeof(g));
58     for (R i=1;i<=n;i++)
59     {
60         si=read(),a[i]=read();
61         add(i,si);
62         add(si,i);
63     }
64     dfs(0);
65     printf("%d",dp[m]);
66     return 0;
67 }
选课

  这种做法比较好写,而且还有一个优点,就是它事实上求出了每棵子树的 $dp$ 值,换句话说,它可以统计到每个连通块的答案。当然,它也有一定的局限性,那就是第二维必须和子树的大小有关,否则复杂度就不对了。下面,再来介绍另一种不要求第二维大小的做法。

    

第二种优化:

  首先对树求出后序遍历序,设 $f[i][j]$ 表示:dfs序编号在i之前的点当前都满足依赖条件时的背包;$j$ 表示什么因题目而异;看上去有点难以理解?解释一下,“当前满足依赖条件”是指,在仅考虑前 $i$ 个点构成的森林的情况下,每个点都满足依赖关系(当前已经出现的祖先都被选了,还没出现的祖先不用考虑)。转移方程十分简单,在往森林里一个点时,如果不选它,那它的子树就都不能选,因为它的子树的dfs序是一段连续的区间,我们直接跳回到还没有考虑过这棵子树时的状态;如果选它,那就从上一个点进行转移即可。复杂度显然为 $n\times m$ 。  

   这种方法比上一种还好写,但是它也有一个问题,那就是只能算出以指定点为根时的答案,而不能做任意联通块。

  
 1 # include <cstdio>
 2 # include <iostream>
 3 # include <cstring>
 4 # include <vector>
 5 # define R register int
 6 
 7 using namespace std;
 8 
 9 const int N=100005;
10 struct edge
11 {
12     int to,nex;
13 };
14 int si,h=0,n,m;
15 edge g[N<<1];
16 int firs[N],a[N],siz[N];
17 bool vis[N];
18 int dp[100000100];
19 int no[N],cnt;
20 
21 void add(int u,int v)
22 {
23     g[++h].to=v;
24     g[h].nex=firs[u];
25     firs[u]=h;
26 }
27 
28 int read()
29 {
30     int x=0;
31     char c=getchar();
32     while (!isdigit(c)) c=getchar();
33     while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
34     return x;
35 }
36 
37 void dfs (int x)
38 {
39     int j;
40     siz[x]=1;
41     for (R i=firs[x];i;i=g[i].nex)
42     {
43         j=g[i].to;
44         if(vis[j]) continue;
45         vis[j]=1;
46         dfs(j);
47         siz[x]+=siz[j];
48     }
49     no[++cnt]=x;
50 }
51 
52 int main()
53 {
54     scanf("%d%d",&n,&m); m++;
55     memset(g,0,sizeof(g));
56     for (R i=1;i<=n;i++)
57     {
58         si=read(),a[i]=read();
59         add(i,si);
60         add(si,i);
61     }
62     vis[0]=1;
63     dfs(0);
64     int x;
65     for (R i=1;i<=cnt;++i)
66     {
67         x=no[i];
68         for (R j=1;j<=m;++j)
69             dp[i*(m+1)+j]=max(dp[(i-1)*(m+1)+j-1]+a[x],dp[(i-siz[x])*(m+1)+j]);
70     }
71     printf("%d",dp[cnt*(m+1)+m]);
72     return 0;
73 }
选课

  

  学习了以上知识后,我们来做一道题?

  Shopping:https://www.lydsy.com/JudgeOnline/problem.php?id=4182

  这好像是个权限题?那我来概述一下题意:

  给定一棵 $n$ 个点的树,每个点上有一种物品 $(w,c,d)$ 表示它的价值是 $w$ ,价格是 $c$ ,有 $d$ 个。你有 $m$ 元钱,并希望它们能买到价值和最大的物品,还有一个限制是买了物品的点必须是树上的一个连通块,求最大价值。$n<=500,m<=4000,d<=100$

  一个显然的思路是直接上树形背包的第一种做法,因为它事实上是在每个连通块最高的点处对这个连通块进行了处理,可以直接求出这道题的答案。

  不过,别忘了第一种优化的前提,如果你以为它任何条件下都适用,那就会 $TLE$ 得很惨。在这道题中,即使是很小的子树也可以有满的 $dp$ 数组,所以复杂度就是 $O(nm^2)$

  看起来第一种做法已经走进死路,让我们来考虑一下第二种做法吧。

  第二种做法可以枚举根,复杂度 $n^2m$ ,感觉已经有了不少改进呢!可以发现,枚举根是一个比较愚蠢的方法,因为在做第一次的时候,就已经把所有与这个根有交的连通块都算过了,接下来只需要对每个子树再做就好了。子树大小有可能不平均?点分治!

  
  1 # include <cstdio>
  2 # include <iostream>
  3 # include <cstring>
  4 # include <vector>
  5 # define R register int
  6 
  7 using namespace std;
  8 
  9 const int N=502;
 10 int T,n,m,h,x,y,cnt,rt,ans,S,d;
 11 int firs[N],siz[N],no[N],vis[N],w[N],c[N],maxs[N];
 12 int dp[N][4005];
 13 struct edge
 14 {
 15     int too,nex;
 16 }g[N<<1];
 17 struct thi
 18 {
 19     int c,w;
 20     thi (int a=0,int b=0) { c=a; w=b; }
 21 };
 22 vector <thi> v[N];
 23 
 24 void add (int x,int y)
 25 {
 26     g[++h].nex=firs[x];
 27     firs[x]=h;
 28     g[h].too=y;
 29 }
 30 
 31 int read()
 32 {
 33     int x=0;
 34     char c=getchar();
 35     while (!isdigit(c)) c=getchar();
 36     while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
 37     return x;
 38 }
 39 
 40 void get_root (int x,int f)
 41 {
 42     siz[x]=1,maxs[x]=0;
 43     int j;
 44     for (R i=firs[x];i;i=g[i].nex)
 45     {
 46         j=g[i].too;
 47         if(vis[j]||f==j) continue;
 48         get_root(j,x);
 49         siz[x]+=siz[j];
 50         maxs[x]=max(maxs[x],siz[j]);
 51     }
 52     maxs[x]=max(maxs[x],S-siz[x]);
 53     if(maxs[x]<maxs[rt]) rt=x;
 54 }
 55 
 56 void dfs (int x,int f)
 57 {
 58     int j; siz[x]=1;
 59     for (R i=firs[x];i;i=g[i].nex)
 60     {
 61         j=g[i].too;
 62         if(vis[j]||j==f) continue;
 63         dfs(j,x);
 64         siz[x]+=siz[j];
 65     }
 66     no[++cnt]=x;
 67 }
 68 
 69 void pdc (int x)
 70 {
 71     cnt=0;
 72     dfs(x,0);
 73     for (R i=0;i<=cnt;++i)
 74         for (R j=0;j<=m;++j)
 75             dp[i][j]=0;
 76     for (R i=1;i<=cnt;++i)
 77     {
 78         int a=no[i],vs=v[a].size();
 79         for (R j=0;j<=m;++j)
 80             dp[i][j]=max(dp[i][j],dp[ i-siz[a] ][j]);
 81         for (R k=0;k<vs;++k)
 82             for (R j=m;j>=v[a][k].c;--j)
 83                 dp[i][j]=max(dp[i][j],max(dp[i-1][ j-v[a][k].c ]+v[a][k].w,dp[i][ j-v[a][k].c ]+v[a][k].w));
 84     }
 85     for (R i=1;i<=m;++i)
 86         ans=max(ans,dp[cnt][i]);
 87 }
 88 
 89 void solve (int x)
 90 {
 91     vis[x]=1;
 92     pdc(x);
 93     int j;
 94     for (R i=firs[x];i;i=g[i].nex)
 95     {
 96         j=g[i].too;
 97         if(vis[j]) continue;
 98         rt=0; maxs[rt]=n; S=siz[j];
 99         get_root(j,0);
100         solve(rt);
101     }
102 }
103 
104 void t4182()
105 {
106     n=read(),m=read();
107     memset(firs,0,sizeof(firs));
108     memset(g,0,sizeof(g));
109     memset(vis,0,sizeof(vis));
110     h=0;
111     for (R i=1;i<=n;++i)
112         v[i].clear();
113     for (R i=1;i<=n;++i)
114         w[i]=read();
115     for (R i=1;i<=n;++i)
116         c[i]=read();
117     for (R i=1;i<=n;++i)
118     {
119         d=read();
120         int x=1;
121         while(x<=d)
122         {
123             d-=x;
124             v[i].push_back(thi(c[i]*x,w[i]*x));
125             x<<=1;
126         }
127         if(d>0) v[i].push_back(thi(c[i]*d,w[i]*d));
128     }
129     for (R i=1;i<n;++i)
130     {
131         x=read(),y=read();
132         add(x,y); add(y,x);
133     }
134     rt=0;
135     S=maxs[rt]=n;
136     get_root(1,0);
137     ans=0;
138     solve(rt);
139     printf("%d\n",ans);
140 }
141 
142 int main()
143 {
144     scanf("%d",&T);
145     while(T--)
146         t4182();
147     return 0;
148 }
shopping

---shzr

posted @ 2019-09-06 17:49  shzr  阅读(912)  评论(0编辑  收藏  举报