关于树形背包的上下界枚举优化

关于树形背包的上下界枚举优化

树形背包的 dp 做法大致说来就是这么简单:

  • 枚举 \(u\) 的每个儿子 \(v\)
  • 枚举 \(u\) 使用容量 \(i\) 的最优值。
  • 枚举 \(v\) 使用容量 \(j\) 的最优值,状态转移方程一般类似于 \(f_{u,i}\gets\min(f_{u,i},f_{u,i-j}+f_{v,j})\)

我们拿经典题目举例子:

[CTSC1997] 选课

朴素来说,这题的 DFS 就是这么写的:

void dfs(int u) {
	f[u][0] = 0;
	f[u][1] = c[u];
	for (int i = head[p];i;i = e[i].nxt) {
		int v = e[i].to;
		dfs(v);
		for (int j = m + 1; j >= 1; j--) {//m+1是因为我们把无前置课的课程把0当作前置课来把它变成一棵树,这样要给课程0一份时间。
			for (int k = 1; k < j; k++) {
				f[u][j] = max(f[u][j],f[u][j-k]+f[v][k]);
			}
		}
    }
}

但是这么做很显然,复杂度是 \(O(nm^2)\) 的。如果 \(n,m\) 都是 \(10^3\) 量级,那么就一定会 TLE。

我们发现上面这个枚举过程有很多被枚举到的状态都是没用的。

首先,外层循环枚举 \(j\) 时,我们根本没必要从 \(m+1\) 枚举,因为此时我们考虑过的节点可能根本没有 \(m+1\) 个,我们可以动态更新以 \(u\) 为根的子树大小 \(siz_u\),只从 \(siz_u+siz_v\) 枚举。

其次,\(k\)\(0\) 枚举也没必要,因为此时 \(siz_u\) 存储着 \(u\) 已经考虑过的节点数,至少也会留下 \(j-siz_u\) 的时间给 \(v\) 这棵子树,所以 \(k\) 直接从 \(\max(1,j-siz_u)\) 枚举即可。

最后,\(k\) 枚举到 \(j\) 也是没有意义的,因为子树 \(v\) 的大小也可能没有 \(m\) 个,我们只用枚举到 \(\min(siz_v,j-1)\) 即可。

这样,我们的代码变为:

void dfs(int u) {
	f[u][0] = 0;
	f[u][1] = c[u];
	siz[u] += 1;
	for (int i = head[p];i;i = e[i].nxt) {
		int v = e[i].to;
		dfs(v);
		for (int j = min(siz[u]+siz[v],m+1); j >= 1; j--) {
			for (int k = max(1,j-siz[u]); k <= min(siz[v],j-1); k++) {
				f[u][j] = max(f[u][j],f[u][j-k]+f[v][k]);
			}
		}
        siz[u] += siz[v];//因为枚举k时用到的siz[u]是考虑子树v之前的,所以在这里才能更新siz[u]。
    }
}

看似我们这只是加了一些无关紧要的剪枝优化,但是复杂度就降到了 \(O(nm)\)

为什么呢?我不会证但确实就是。

PS:树上背包的上下界优化 - ouuan - 博客园 的证明是错误的,经过我的胡乱理解后发现问题,和 Kaito 讨论一下午后无果,最终被 Kenma 宣判死刑。

但是在 Kenma 指导下,我又会了。

我们考虑一棵树有 \(n\) 个结点,那么对于根节点 \(u\) 来说,我们设统计单点所需的时间是 \(t_u\),计算一下 \(t_u\)

\[\begin{aligned}t_u&=1+siz_{v_1}+(1+siz_{v_1})siz_{v_2}+(1+siz_{v_1}+siz_{v_2})siz_{v_3}+\cdots+(1+siz_{v_1}+siz_{v_2}+\cdots+siz_{v_{k-1}})siz_{v_k}\\\end{aligned} \]

我们看看这个东西有没有什么特殊意义。它的意思是,当我遍历到 \(v_i\) 这棵子树时,我用它的大小乘上之前统计过的子树的大小。这其实相当于,我统计了树上以 \(u\) 为 lca 的点对个数。

而程序运行的时间就是 \(t_1+t_2+\cdots+t_n\),即树上以每个点为 lca 的点对个数之和,这不就相当于树上任取两点,求点对个数么?显然,这是 \(O(n^2)\) 的。

然后注意到因为有了 \(m\) 的存在,程序的运行时间会变为 \(O(nm)\),因为点对个数有了限制,不能全部统计完毕。

为什么上文博客的证明是错的?

在他的博客中,\(t_u\) 是这么计算的:

\[\begin{aligned}t_u&=1+(1+siz_{v_1})siz_{v_1}+(1+siz_{v_1}+siz_{v_2})siz_{v_2}+(1+siz_{v_1}+siz_{v_2}+siz_{v_3})siz_{v_3}+\cdots+(1+siz_{v_1}+siz_{v_2}+\cdots+siz_{v_{k-1}}+siz_{v_k})siz_{v_k}\\\end{aligned} \]

多出了 \(siz_{v_1}^2+siz{v_2}^2+\cdots+siz_{v_k}^2\) 这一项,导致 \(t_u\)\(O(n^2)\) 量级。

但是看代码,我们发现:

for (int j = min(siz[u]+siz[v],m+1); j >= 1; j--) {
	for (int k = max(1,j-siz[u]); k <= min(siz[v],j-1); k++) {
		f[u][j] = max(f[u][j],f[u][j-k]+f[v][k]);
	}
}

如果内层想要跑满 \(siz_v\) 层,外层的 \(j\le siz_u\)。当 \(j>siz_u\) 时,内层跑不满 \(siz_v\)。所以没有了 \(siz_v^2\) 的存在。

并且,这么直接统计一个节点的时间复杂度加起来统计是错误的,这个 \(n^2\) 应当从总体计算得到,而不是计算个体然后再加起来。

如上的证明也可能存在错误,如果你发现了请指出,我会尽力修改。

posted @ 2025-05-13 19:38  Ascnbeta  阅读(30)  评论(2)    收藏  举报