关于树形背包的上下界枚举优化
关于树形背包的上下界枚举优化
树形背包的 dp 做法大致说来就是这么简单:
- 枚举 \(u\) 的每个儿子 \(v\)。
- 枚举 \(u\) 使用容量 \(i\) 的最优值。
- 枚举 \(v\) 使用容量 \(j\) 的最优值,状态转移方程一般类似于 \(f_{u,i}\gets\min(f_{u,i},f_{u,i-j}+f_{v,j})\)。
我们拿经典题目举例子:
[CTSC1997] 选课
朴素来说,这题的 DFS 就是这么写的:
void dfs(int u) {
f[u][0] = 0;
f[u][1] = c[u];
for (int i = head[p];i;i = e[i].nxt) {
int v = e[i].to;
dfs(v);
for (int j = m + 1; j >= 1; j--) {//m+1是因为我们把无前置课的课程把0当作前置课来把它变成一棵树,这样要给课程0一份时间。
for (int k = 1; k < j; k++) {
f[u][j] = max(f[u][j],f[u][j-k]+f[v][k]);
}
}
}
}
但是这么做很显然,复杂度是 \(O(nm^2)\) 的。如果 \(n,m\) 都是 \(10^3\) 量级,那么就一定会 TLE。
我们发现上面这个枚举过程有很多被枚举到的状态都是没用的。
首先,外层循环枚举 \(j\) 时,我们根本没必要从 \(m+1\) 枚举,因为此时我们考虑过的节点可能根本没有 \(m+1\) 个,我们可以动态更新以 \(u\) 为根的子树大小 \(siz_u\),只从 \(siz_u+siz_v\) 枚举。
其次,\(k\) 从 \(0\) 枚举也没必要,因为此时 \(siz_u\) 存储着 \(u\) 已经考虑过的节点数,至少也会留下 \(j-siz_u\) 的时间给 \(v\) 这棵子树,所以 \(k\) 直接从 \(\max(1,j-siz_u)\) 枚举即可。
最后,\(k\) 枚举到 \(j\) 也是没有意义的,因为子树 \(v\) 的大小也可能没有 \(m\) 个,我们只用枚举到 \(\min(siz_v,j-1)\) 即可。
这样,我们的代码变为:
void dfs(int u) {
f[u][0] = 0;
f[u][1] = c[u];
siz[u] += 1;
for (int i = head[p];i;i = e[i].nxt) {
int v = e[i].to;
dfs(v);
for (int j = min(siz[u]+siz[v],m+1); j >= 1; j--) {
for (int k = max(1,j-siz[u]); k <= min(siz[v],j-1); k++) {
f[u][j] = max(f[u][j],f[u][j-k]+f[v][k]);
}
}
siz[u] += siz[v];//因为枚举k时用到的siz[u]是考虑子树v之前的,所以在这里才能更新siz[u]。
}
}
看似我们这只是加了一些无关紧要的剪枝优化,但是复杂度就降到了 \(O(nm)\)!
为什么呢?我不会证但确实就是。
PS:树上背包的上下界优化 - ouuan - 博客园 的证明是错误的,经过我的胡乱理解后发现问题,和 Kaito 讨论一下午后无果,最终被 Kenma 宣判死刑。
但是在 Kenma 指导下,我又会了。

我们考虑一棵树有 \(n\) 个结点,那么对于根节点 \(u\) 来说,我们设统计单点所需的时间是 \(t_u\),计算一下 \(t_u\):
我们看看这个东西有没有什么特殊意义。它的意思是,当我遍历到 \(v_i\) 这棵子树时,我用它的大小乘上之前统计过的子树的大小。这其实相当于,我统计了树上以 \(u\) 为 lca 的点对个数。
而程序运行的时间就是 \(t_1+t_2+\cdots+t_n\),即树上以每个点为 lca 的点对个数之和,这不就相当于树上任取两点,求点对个数么?显然,这是 \(O(n^2)\) 的。
然后注意到因为有了 \(m\) 的存在,程序的运行时间会变为 \(O(nm)\),因为点对个数有了限制,不能全部统计完毕。
为什么上文博客的证明是错的?
在他的博客中,\(t_u\) 是这么计算的:
多出了 \(siz_{v_1}^2+siz{v_2}^2+\cdots+siz_{v_k}^2\) 这一项,导致 \(t_u\) 为 \(O(n^2)\) 量级。
但是看代码,我们发现:
for (int j = min(siz[u]+siz[v],m+1); j >= 1; j--) {
for (int k = max(1,j-siz[u]); k <= min(siz[v],j-1); k++) {
f[u][j] = max(f[u][j],f[u][j-k]+f[v][k]);
}
}
如果内层想要跑满 \(siz_v\) 层,外层的 \(j\le siz_u\)。当 \(j>siz_u\) 时,内层跑不满 \(siz_v\)。所以没有了 \(siz_v^2\) 的存在。
并且,这么直接统计一个节点的时间复杂度加起来统计是错误的,这个 \(n^2\) 应当从总体计算得到,而不是计算个体然后再加起来。
如上的证明也可能存在错误,如果你发现了请指出,我会尽力修改。

浙公网安备 33010602011771号