洛谷 P8294 - [省选联考 2022] 最大权独立集问题(DP)

洛谷题面传送门

巨大精神污染题。

首先考虑 \(n^3\) 做法:\(dp_{i,j,k}\) 表示目前已经断掉了 \(i\) 子树内所有边以及 \(i\) 与父亲的边,假设初始时 \(k\) 被换到了 \(i\) 的父亲的位置,要将 \(j\) 换出 \(i\) 的子树,\(k\) 换入 \(i\) 的子树所需的最小代价。列出状态后,转移应该是比较显然的:

\[dp_{i,j,k}= \begin{cases} d_j+d_k&(\text{leftson}=\varnothing,\text{rightson}=\varnothing)\\ d_j+d_k+\min\limits_{l\in\text{subtree}(i),l\ne i}dp_{ls,l,k}&(\text{leftson}\ne\varnothing,\text{rightson}=\varnothing,j=i)\\ d_j+d_k+dp_{ls,j,i}&(\text{leftson}\ne\varnothing,\text{rightson}=\varnothing,j\ne i)\\ d_j+d_k+\min\limits_{x\in\text{subtree}(ls_i)}\min\limits_{y\in\text{subtree}(rs_i)}\min(dp_{ls,x,k}+dp_{rs,y,x},dp_{rs,y,k}+dp_{ls,x,y})&(\text{leftson}\ne\varnothing,\text{rightson}\ne\varnothing,j=i)\\ d_j+d_k+\min\limits_{x\in\text{subtree}(rs_i)}\min(dp_{ls,j,i}+dp_{rs,x,k},dp_{rs,x,i}+dp_{ls,j,x})&(\text{leftson}\ne\varnothing,\text{rightson}\ne\varnothing,j\in\text{subtree}(ls))\\ \end{cases} \]

其中 \(son_i\) 表示 \(i\) 点的儿子个数,上面转移方程中可能有一些对称的转移没有列出来,这是无关紧要的。

直接这样写可以获得 44 分,代码如下:

const int MAXN = 500;
int n, d[MAXN + 5], fa[MAXN + 5], sub[MAXN + 5][MAXN + 5];
int ls[MAXN + 5], rs[MAXN + 5];
ll dp[MAXN + 5][MAXN + 5][MAXN + 5];
int main() {
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
	for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
	for (int i = 1; i <= n; i++) sub[i][i] = 1;
	for (int i = 2; i <= n; i++) {
		if (!ls[fa[i]]) ls[fa[i]] = i;
		else rs[fa[i]] = i;
	}
	memset(dp, 63, sizeof(dp));
	for (int i = n; i; i--) {
		for (int j = 1; j <= n; j++) sub[fa[i]][j] |= sub[i][j];
		for (int j = 1; j <= n; j++) if (sub[i][j])
			for (int k = 1; k <= n; k++) if (!sub[i][k]) {
				if (!ls[i]) dp[i][j][k] = d[j] + d[k];
				else if (!rs[i]) {
					int a = ls[i];
					if (j == i) {
						for (int l = 1; l <= n; l++) if (sub[a][l])
							chkmin(dp[i][j][k], d[j] + d[k] + dp[a][l][k]);
					} else {
						dp[i][j][k] = dp[a][j][i] + d[j] + d[k];
					}
				} else {
					int a = ls[i], b = rs[i];
					if (j == i) {
						for (int x = 1; x <= n; x++) if (sub[a][x])
							for (int y = 0; y <= n; y++) if (sub[b][y]) {
								chkmin(dp[i][j][k], d[j] + d[k] + dp[a][x][k] + dp[b][y][x]);
								chkmin(dp[i][j][k], d[j] + d[k] + dp[b][y][k] + dp[a][x][y]);
							}
					} else if (sub[a][j]) {
						for (int x = 1; x <= n; x++) if (sub[b][x]) {
							chkmin(dp[i][j][k], dp[a][j][i] + d[j] + d[k] + dp[b][x][k]);
							chkmin(dp[i][j][k], dp[b][x][i] + dp[a][j][x] + d[j] + d[k]);
						}
					} else {
						for (int x = 1; x <= n; x++) if (sub[a][x]) {
							chkmin(dp[i][j][k], dp[b][j][i] + d[j] + d[k] + dp[a][x][k]);
							chkmin(dp[i][j][k], dp[a][x][i] + dp[b][j][x] + d[j] + d[k]);
						}
					}
				}
//				printf("dp[%d][%d][%d] = %lld\n", i, j, k, dp[i][j][k]);
			}
	}
	ll mn = 1e18;
	if (!rs[1]) {
		int a = ls[1];
		for (int i = 2; i <= n; i++) chkmin(mn, dp[a][i][1]);
	} else {
		int a = ls[1], b = rs[1];
		for (int i = 2; i <= n; i++) if (sub[a][i])
			for (int j = 2; j <= n; j++) if (sub[b][j]) {
				chkmin(mn, dp[a][i][1] + dp[b][j][i]);
				chkmin(mn, dp[b][j][1] + dp[a][i][j]);
			}
	}
	printf("%lld\n", mn);
	return 0;
}

考虑优化,我们考虑在每一轮 DP 的时候处理出以下数组:

  • \(f0_{i,j}\) 表示根节点为 \(i\),把 \(i\) 换出去,\(j\) 换进来的最小代价,即 \(dp_{i,i,j}\)
  • \(f1_{i,j}\) 表示根节点为 \(i\),把 \(j\) 换出去,\(fa_i\) 换进来的最小代价,即 \(dp_{i,j,fa_i}\)
  • \(f2_{i,j}\) 表示把 \(i\) 子树内任何一个点换出去,\(j\) 换进来的最小代价,即 \(\min\limits_{x\in\text{subtree}(i)}dp_{i,x,j}\)
  • \(f3_{i,j}\) 表示先交换 \(j\) 相对于 \(i\) 的另一侧子树内某个节点 \(x\),并把 \(x\) 换出去、把 \(i\) 换进另一侧子树,再将 \(x\) 换进 \(j\) 这一侧子树,\(j\) 换出去的最小代价。这句话读起来稍微有点拗口,写成式子的形式就是如果 \(j\)\(x\) 左儿子子树内,那么 \(f3_{i,j}=\min\limits_{x\in\text{subtree}(rs_i)}dp_{rs,x,i}+dp_{ls,j,x}\),对于右子树的情况也类似。

求出这四个数组有什么好处呢?不难发现有了这四个数组后我们就可以 \(\mathcal O(1)\) 地求出 \(dp_{i,j,k}\),即

\[dp_{i,j,k}= \begin{cases} f0_{i,k}&(j=i)\\ f1_{ls,j}+d_j+d_k&(\text{leftson}\ne\varnothing,\text{rightson}=\varnothing)\\ \min(f1_{ls,j}+f2_{rs,k},f3_{ls,j})+d_j+d_k&(\text{leftson}\ne\varnothing,\text{rightson}=\varnothing,j\in\text{subtree(ls)})\\ \end{cases} \]

这样空间复杂度就降到了 \(n^2\),但是时间复杂度还是 \(n^3\),大概取决于实现可以通过部分 \(n\le 1000\) 的测试点。

附:这部分的代码:

const int MAXN = 5000;
int n, d[MAXN + 5], fa[MAXN + 5], sub[MAXN + 5][MAXN + 5];
int ls[MAXN + 5], rs[MAXN + 5];
vector<int> vec[MAXN + 5];
ll f0[MAXN + 5][MAXN + 5], f1[MAXN + 5][MAXN + 5], f2[MAXN + 5][MAXN + 5], f3[MAXN + 5][MAXN + 5];
ll calc(int i, int j, int k) {
	if (j == i) return f0[i][k];
	else if (!rs[i]) return f1[ls[i]][j] + d[j] + d[k];
	else {
		int a = ls[i], b = rs[i];
		if (sub[a][j]) return min(f1[a][j] + f2[b][k], f3[i][j]) + d[j] + d[k];
		else return min(f1[b][j] + f2[a][k], f3[i][j]) + d[j] + d[k];
	}
}
int main() {
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
	for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
	for (int i = 1; i <= n; i++) sub[i][i] = 1, vec[i].pb(i);
	for (int i = 2; i <= n; i++) {
		if (!ls[fa[i]]) ls[fa[i]] = i;
		else rs[fa[i]] = i;
	}
	for (int i = n; i > 1; i--) {
		for (int j = 1; j <= n; j++) sub[fa[i]][j] |= sub[i][j];
		for (int x : vec[i]) vec[fa[i]].pb(x);
		for (int k = 1; k <= n; k++) if (!sub[i][k]) {
			if (!ls[i]) f0[i][k] = d[i] + d[k];
			else if (!rs[i]) f0[i][k] = d[i] + d[k] + f2[ls[i]][k];
			else {
				int a = ls[i], b = rs[i];
				f0[i][k] = 1e18;
				for (int x : vec[a]) for (int y : vec[b]) {
					chkmin(f0[i][k], d[i] + d[k] + calc(a, x, k) + calc(b, y, x));
					chkmin(f0[i][k], d[i] + d[k] + calc(b, y, k) + calc(a, x, y));
				}
			}
		}
		if (rs[i]) {
			int a = ls[i], b = rs[i];
			for (int x : vec[a]) f3[i][x] = 1e18;
			for (int x : vec[b]) f3[i][x] = 1e18;
			for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][x], calc(b, y, i) + calc(a, x, y));
			for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][y], calc(a, x, i) + calc(b, y, x));
		}
		for (int j : vec[i]) f1[i][j] = calc(i, j, fa[i]);
		for (int j = 1; j <= n; j++) if (!sub[i][j]) {
			f2[i][j] = 1e18;
			for (int x : vec[i]) chkmin(f2[i][j], calc(i, x, j));
		}
	}
	ll mn = 1e18;
	if (!rs[1]) {
		int a = ls[1];
		for (int i = 2; i <= n; i++) chkmin(mn, calc(a, i, 1));
	} else {
		int a = ls[1], b = rs[1];
		for (int i : vec[a])  for (int j : vec[b]) {
			chkmin(mn, calc(a, i, 1) + calc(b, j, i));
			chkmin(mn, calc(b, j, 1) + calc(a, i, j));
		}
	}
	printf("%lld\n", mn);
	return 0;
}
/*
10
9 5 6 3 8 7 1 4 1 3
1 2 2 4 1 4 3 8 5
*/

接下来就是略微有点 dirty work 的地方了,注意到在上面的 DP 过程中,求解 \(f1\) 的部分复杂度肯定不会出错,\(f3\) 部分的复杂度分析类似于树形 DP,也是 \(\mathcal O(n^2)\) 的,复杂度瓶颈在于 \(f0\)\(f2\),因此我们需要着手去优化它。以 \(f2\) 为例,我们固定了 \(i\),要对于每个 \(j\),求出 \(\min\limits_{x\in\text{subtree}(i)}dp_{i,x,j}\),不妨先特判 \(x=i\),对于 \(x\ne i\),按照上面的公式展开 \(dp_{i,x,j}\),分情况讨论:

  • 如果 \(i\) 没有右儿子,那么 \(dp_{i,x,j}=d_x+d_j+f1_{ls,x}\),因此我们遍历所有 \(x\) 并维护 \(d_x+f1_{ls,x}\) 的最小值。
  • 如果 \(i\) 有右儿子,那么这里不妨考虑 \(x\)\(i\) 左子树内情况,另一部分镜像一下即可。则 \(dp_{i,x,j}=\min(f1_{ls,x}+f2_{rs,j},f3_{i,x})+d_x+d_j\),我们将两部分分别求 \(\min\),即求 \(f1_{ls,x}+f2_{rs,j}+d_x+d_j\)\(f3_{i,x}+d_x+d_j\) 的最小值,这样我们只用遍历 \(x\) 的同时维护 \(f1_{ls,x}+d_x\)\(f3_{i,x}+d_x\) 的最小值即可。

\(f0\) 的处理也是类似的,类似于一个参变分离的思想。这样时空复杂度均降到了 \(\mathcal O(n^2)\)。细节很多。

const int MAXN = 5000;
int n, d[MAXN + 5], fa[MAXN + 5], sub[MAXN + 5][MAXN + 5];
int ls[MAXN + 5], rs[MAXN + 5];
vector<int> vec[MAXN + 5];
ll f0[MAXN + 5][MAXN + 5]; // f0[i][j] 表示根节点为 i,把 i 换出去,j 换进来的最小代价。 
ll f1[MAXN + 5][MAXN + 5]; // f1[i][j] 表示根节点为 i,把 j 换出去,fa[i] 换进来的最小代价 
ll f2[MAXN + 5][MAXN + 5]; // f2[i][j] 表示可以把 i 子树内任何一个点换出去,j 换进来的最小代价 
ll f3[MAXN + 5][MAXN + 5];
// f3[i][j] 表示先交换 j 相对于 i 的另一侧子树内某个节点 x,并把 x 换出去、把 i 换进另一侧子树,再将 x 换进 j 这一侧子树,j 换出去的最小代价。
ll calc(int i, int j, int k) {
	if (j == i) return f0[i][k];
	else if (!rs[i]) return f1[ls[i]][j] + d[j] + d[k];
	else {
		int a = ls[i], b = rs[i];
		if (sub[a][j]) return min(f1[a][j] + f2[b][k], f3[i][j]) + d[j] + d[k];
		else return min(f1[b][j] + f2[a][k], f3[i][j]) + d[j] + d[k];
	}
}
int main() {
	freopen("mis.in", "r", stdin);
	freopen("mis.out", "w", stdout);
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
	for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
	for (int i = 1; i <= n; i++) sub[i][i] = 1, vec[i].pb(i);
	for (int i = 2; i <= n; i++) {
		if (!ls[fa[i]]) ls[fa[i]] = i;
		else rs[fa[i]] = i;
	}
	for (int i = n; i > 1; i--) {
		for (int j = 1; j <= n; j++) sub[fa[i]][j] |= sub[i][j];
		for (int x : vec[i]) vec[fa[i]].pb(x);
		if (!ls[i]) {
			for (int k = 1; k <= n; k++) if (!sub[i][k])
				f0[i][k] = d[i] + d[k];
		}
		else if (!rs[i]) {
			for (int k = 1; k <= n; k++) if (!sub[i][k])
				f0[i][k] = d[i] + d[k] + f2[ls[i]][k];
		}
		else {
			int a = ls[i], b = rs[i];
			for (int k = 1; k <= n; k++) if (!sub[i][k]) f0[i][k] = 1e18;
			ll mn1 = 1e18, mn2 = 1e18, mn3 = 1e18;
			for (int x : vec[a]) {
				ll mn = 1e18;
				for (int y : vec[b]) chkmin(mn, calc(b, y, x));
				if (x == a) {
					for (int k = 1; k <= n; k++) if (!sub[i][k])
						chkmin(f0[i][k], d[i] + d[k] + calc(a, x, k) + mn);
					continue;
				}
				if (!rs[a]) chkmin(mn1, mn + d[i] + f1[ls[a]][x] + d[x]);
				else {
					int aa = ls[a], ab = rs[a];
					if (sub[aa][x]) {
						// calc(a, x, k) = min(f1[aa][x] + f2[ab][k], f3[a][x]) + d[x] + d[k]
						chkmin(mn2, f1[aa][x] + d[x] + mn + d[i]);
						chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
					} else {
						chkmin(mn3, f1[ab][x] + d[x] + mn + d[i]);
						chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
					}
				}
			}
			if (!rs[a]) {
				for (int k = 1; k <= n; k++) if (!sub[i][k])
					chkmin(f0[i][k], mn1 + d[k] + d[k]);
			} else {
				int aa = ls[a], ab = rs[a];
				for (int k = 1; k <= n; k++) if (!sub[i][k]) {
					chkmin(f0[i][k], mn1 + d[k] + d[k]);
					chkmin(f0[i][k], mn2 + f2[ab][k] + d[k] + d[k]);
					chkmin(f0[i][k], mn3 + f2[aa][k] + d[k] + d[k]);
				}
			}
			swap(a, b);
			mn1 = 1e18, mn2 = 1e18, mn3 = 1e18;
			for (int x : vec[a]) {
				ll mn = 1e18;
				for (int y : vec[b]) chkmin(mn, calc(b, y, x));
				if (x == a) {
					for (int k = 1; k <= n; k++) if (!sub[i][k])
						chkmin(f0[i][k], d[i] + d[k] + calc(a, x, k) + mn);
					continue;
				}
				if (!rs[a]) chkmin(mn1, mn + d[i] + f1[ls[a]][x] + d[x]);
				else {
					int aa = ls[a], ab = rs[a];
					if (sub[aa][x]) {
						// calc(a, x, k) = min(f1[aa][x] + f2[ab][k], f3[a][x]) + d[x] + d[k]
						chkmin(mn2, f1[aa][x] + d[x] + mn + d[i]);
						chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
					} else {
						chkmin(mn3, f1[ab][x] + d[x] + mn + d[i]);
						chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
					}
				}
			}
			if (!rs[a]) {
				for (int k = 1; k <= n; k++) if (!sub[i][k])
					chkmin(f0[i][k], mn1 + d[k] + d[k]);
			} else {
				int aa = ls[a], ab = rs[a];
				for (int k = 1; k <= n; k++) if (!sub[i][k]) {
					chkmin(f0[i][k], mn1 + d[k] + d[k]);
					chkmin(f0[i][k], mn2 + f2[ab][k] + d[k] + d[k]);
					chkmin(f0[i][k], mn3 + f2[aa][k] + d[k] + d[k]);
				}
			}
		}
		if (rs[i]) {
			int a = ls[i], b = rs[i];
			for (int x : vec[a]) f3[i][x] = 1e18;
			for (int x : vec[b]) f3[i][x] = 1e18;
			for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][x], calc(b, y, i) + calc(a, x, y));
			for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][y], calc(a, x, i) + calc(b, y, x));
		}
		for (int j : vec[i]) f1[i][j] = calc(i, j, fa[i]);
		for (int j = 1; j <= n; j++) if (!sub[i][j]) f2[i][j] = calc(i, i, j);
		if (!rs[i]) {
			ll mn = 1e18;
			for (int v : vec[ls[i]]) chkmin(mn, d[v] + f1[ls[i]][v]);
			for (int k = 1; k <= n; k++) if (!sub[i][k]) chkmin(f2[i][k], mn + d[k]);
		} else {
			ll mn1 = 1e18, mn2 = 1e18, mn3 = 1e18;
			int a = ls[i], b = rs[i];
			for (int v : vec[a]) chkmin(mn1, f1[a][v] + d[v]), chkmin(mn3, f3[i][v] + d[v]);
			for (int v : vec[b]) chkmin(mn2, f1[b][v] + d[v]), chkmin(mn3, f3[i][v] + d[v]);
			for (int k = 1; k <= n; k++) if (!sub[i][k]) {
				chkmin(f2[i][k], mn1 + f2[b][k] + d[k]);
				chkmin(f2[i][k], mn2 + f2[a][k] + d[k]);
				chkmin(f2[i][k], mn3 + d[k]);
			}
		}
	}
	ll mn = 1e18;
	if (!rs[1]) {
		int a = ls[1];
		for (int i = 2; i <= n; i++) chkmin(mn, calc(a, i, 1));
	} else {
		int a = ls[1], b = rs[1];
		for (int i : vec[a])  for (int j : vec[b]) {
			chkmin(mn, calc(a, i, 1) + calc(b, j, i));
			chkmin(mn, calc(b, j, 1) + calc(a, i, j));
		}
	}
	printf("%lld\n", mn);
	return 0;
}
/*
10
9 5 6 3 8 7 1 4 1 3
1 2 2 4 1 4 3 8 5
*/
posted @ 2022-06-06 14:41  tzc_wk  阅读(84)  评论(0)    收藏  举报