洛谷 P8294 - [省选联考 2022] 最大权独立集问题(DP)
巨大精神污染题。
首先考虑 \(n^3\) 做法:\(dp_{i,j,k}\) 表示目前已经断掉了 \(i\) 子树内所有边以及 \(i\) 与父亲的边,假设初始时 \(k\) 被换到了 \(i\) 的父亲的位置,要将 \(j\) 换出 \(i\) 的子树,\(k\) 换入 \(i\) 的子树所需的最小代价。列出状态后,转移应该是比较显然的:
其中 \(son_i\) 表示 \(i\) 点的儿子个数,上面转移方程中可能有一些对称的转移没有列出来,这是无关紧要的。
直接这样写可以获得 44 分,代码如下:
const int MAXN = 500;
int n, d[MAXN + 5], fa[MAXN + 5], sub[MAXN + 5][MAXN + 5];
int ls[MAXN + 5], rs[MAXN + 5];
ll dp[MAXN + 5][MAXN + 5][MAXN + 5];
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
for (int i = 1; i <= n; i++) sub[i][i] = 1;
for (int i = 2; i <= n; i++) {
if (!ls[fa[i]]) ls[fa[i]] = i;
else rs[fa[i]] = i;
}
memset(dp, 63, sizeof(dp));
for (int i = n; i; i--) {
for (int j = 1; j <= n; j++) sub[fa[i]][j] |= sub[i][j];
for (int j = 1; j <= n; j++) if (sub[i][j])
for (int k = 1; k <= n; k++) if (!sub[i][k]) {
if (!ls[i]) dp[i][j][k] = d[j] + d[k];
else if (!rs[i]) {
int a = ls[i];
if (j == i) {
for (int l = 1; l <= n; l++) if (sub[a][l])
chkmin(dp[i][j][k], d[j] + d[k] + dp[a][l][k]);
} else {
dp[i][j][k] = dp[a][j][i] + d[j] + d[k];
}
} else {
int a = ls[i], b = rs[i];
if (j == i) {
for (int x = 1; x <= n; x++) if (sub[a][x])
for (int y = 0; y <= n; y++) if (sub[b][y]) {
chkmin(dp[i][j][k], d[j] + d[k] + dp[a][x][k] + dp[b][y][x]);
chkmin(dp[i][j][k], d[j] + d[k] + dp[b][y][k] + dp[a][x][y]);
}
} else if (sub[a][j]) {
for (int x = 1; x <= n; x++) if (sub[b][x]) {
chkmin(dp[i][j][k], dp[a][j][i] + d[j] + d[k] + dp[b][x][k]);
chkmin(dp[i][j][k], dp[b][x][i] + dp[a][j][x] + d[j] + d[k]);
}
} else {
for (int x = 1; x <= n; x++) if (sub[a][x]) {
chkmin(dp[i][j][k], dp[b][j][i] + d[j] + d[k] + dp[a][x][k]);
chkmin(dp[i][j][k], dp[a][x][i] + dp[b][j][x] + d[j] + d[k]);
}
}
}
// printf("dp[%d][%d][%d] = %lld\n", i, j, k, dp[i][j][k]);
}
}
ll mn = 1e18;
if (!rs[1]) {
int a = ls[1];
for (int i = 2; i <= n; i++) chkmin(mn, dp[a][i][1]);
} else {
int a = ls[1], b = rs[1];
for (int i = 2; i <= n; i++) if (sub[a][i])
for (int j = 2; j <= n; j++) if (sub[b][j]) {
chkmin(mn, dp[a][i][1] + dp[b][j][i]);
chkmin(mn, dp[b][j][1] + dp[a][i][j]);
}
}
printf("%lld\n", mn);
return 0;
}
考虑优化,我们考虑在每一轮 DP 的时候处理出以下数组:
- \(f0_{i,j}\) 表示根节点为 \(i\),把 \(i\) 换出去,\(j\) 换进来的最小代价,即 \(dp_{i,i,j}\)。
- \(f1_{i,j}\) 表示根节点为 \(i\),把 \(j\) 换出去,\(fa_i\) 换进来的最小代价,即 \(dp_{i,j,fa_i}\)。
- \(f2_{i,j}\) 表示把 \(i\) 子树内任何一个点换出去,\(j\) 换进来的最小代价,即 \(\min\limits_{x\in\text{subtree}(i)}dp_{i,x,j}\)。
- \(f3_{i,j}\) 表示先交换 \(j\) 相对于 \(i\) 的另一侧子树内某个节点 \(x\),并把 \(x\) 换出去、把 \(i\) 换进另一侧子树,再将 \(x\) 换进 \(j\) 这一侧子树,\(j\) 换出去的最小代价。这句话读起来稍微有点拗口,写成式子的形式就是如果 \(j\) 在 \(x\) 左儿子子树内,那么 \(f3_{i,j}=\min\limits_{x\in\text{subtree}(rs_i)}dp_{rs,x,i}+dp_{ls,j,x}\),对于右子树的情况也类似。
求出这四个数组有什么好处呢?不难发现有了这四个数组后我们就可以 \(\mathcal O(1)\) 地求出 \(dp_{i,j,k}\),即
这样空间复杂度就降到了 \(n^2\),但是时间复杂度还是 \(n^3\),大概取决于实现可以通过部分 \(n\le 1000\) 的测试点。
附:这部分的代码:
const int MAXN = 5000;
int n, d[MAXN + 5], fa[MAXN + 5], sub[MAXN + 5][MAXN + 5];
int ls[MAXN + 5], rs[MAXN + 5];
vector<int> vec[MAXN + 5];
ll f0[MAXN + 5][MAXN + 5], f1[MAXN + 5][MAXN + 5], f2[MAXN + 5][MAXN + 5], f3[MAXN + 5][MAXN + 5];
ll calc(int i, int j, int k) {
if (j == i) return f0[i][k];
else if (!rs[i]) return f1[ls[i]][j] + d[j] + d[k];
else {
int a = ls[i], b = rs[i];
if (sub[a][j]) return min(f1[a][j] + f2[b][k], f3[i][j]) + d[j] + d[k];
else return min(f1[b][j] + f2[a][k], f3[i][j]) + d[j] + d[k];
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
for (int i = 1; i <= n; i++) sub[i][i] = 1, vec[i].pb(i);
for (int i = 2; i <= n; i++) {
if (!ls[fa[i]]) ls[fa[i]] = i;
else rs[fa[i]] = i;
}
for (int i = n; i > 1; i--) {
for (int j = 1; j <= n; j++) sub[fa[i]][j] |= sub[i][j];
for (int x : vec[i]) vec[fa[i]].pb(x);
for (int k = 1; k <= n; k++) if (!sub[i][k]) {
if (!ls[i]) f0[i][k] = d[i] + d[k];
else if (!rs[i]) f0[i][k] = d[i] + d[k] + f2[ls[i]][k];
else {
int a = ls[i], b = rs[i];
f0[i][k] = 1e18;
for (int x : vec[a]) for (int y : vec[b]) {
chkmin(f0[i][k], d[i] + d[k] + calc(a, x, k) + calc(b, y, x));
chkmin(f0[i][k], d[i] + d[k] + calc(b, y, k) + calc(a, x, y));
}
}
}
if (rs[i]) {
int a = ls[i], b = rs[i];
for (int x : vec[a]) f3[i][x] = 1e18;
for (int x : vec[b]) f3[i][x] = 1e18;
for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][x], calc(b, y, i) + calc(a, x, y));
for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][y], calc(a, x, i) + calc(b, y, x));
}
for (int j : vec[i]) f1[i][j] = calc(i, j, fa[i]);
for (int j = 1; j <= n; j++) if (!sub[i][j]) {
f2[i][j] = 1e18;
for (int x : vec[i]) chkmin(f2[i][j], calc(i, x, j));
}
}
ll mn = 1e18;
if (!rs[1]) {
int a = ls[1];
for (int i = 2; i <= n; i++) chkmin(mn, calc(a, i, 1));
} else {
int a = ls[1], b = rs[1];
for (int i : vec[a]) for (int j : vec[b]) {
chkmin(mn, calc(a, i, 1) + calc(b, j, i));
chkmin(mn, calc(b, j, 1) + calc(a, i, j));
}
}
printf("%lld\n", mn);
return 0;
}
/*
10
9 5 6 3 8 7 1 4 1 3
1 2 2 4 1 4 3 8 5
*/
接下来就是略微有点 dirty work 的地方了,注意到在上面的 DP 过程中,求解 \(f1\) 的部分复杂度肯定不会出错,\(f3\) 部分的复杂度分析类似于树形 DP,也是 \(\mathcal O(n^2)\) 的,复杂度瓶颈在于 \(f0\) 和 \(f2\),因此我们需要着手去优化它。以 \(f2\) 为例,我们固定了 \(i\),要对于每个 \(j\),求出 \(\min\limits_{x\in\text{subtree}(i)}dp_{i,x,j}\),不妨先特判 \(x=i\),对于 \(x\ne i\),按照上面的公式展开 \(dp_{i,x,j}\),分情况讨论:
- 如果 \(i\) 没有右儿子,那么 \(dp_{i,x,j}=d_x+d_j+f1_{ls,x}\),因此我们遍历所有 \(x\) 并维护 \(d_x+f1_{ls,x}\) 的最小值。
- 如果 \(i\) 有右儿子,那么这里不妨考虑 \(x\) 在 \(i\) 左子树内情况,另一部分镜像一下即可。则 \(dp_{i,x,j}=\min(f1_{ls,x}+f2_{rs,j},f3_{i,x})+d_x+d_j\),我们将两部分分别求 \(\min\),即求 \(f1_{ls,x}+f2_{rs,j}+d_x+d_j\) 和 \(f3_{i,x}+d_x+d_j\) 的最小值,这样我们只用遍历 \(x\) 的同时维护 \(f1_{ls,x}+d_x\) 和 \(f3_{i,x}+d_x\) 的最小值即可。
\(f0\) 的处理也是类似的,类似于一个参变分离的思想。这样时空复杂度均降到了 \(\mathcal O(n^2)\)。细节很多。
const int MAXN = 5000;
int n, d[MAXN + 5], fa[MAXN + 5], sub[MAXN + 5][MAXN + 5];
int ls[MAXN + 5], rs[MAXN + 5];
vector<int> vec[MAXN + 5];
ll f0[MAXN + 5][MAXN + 5]; // f0[i][j] 表示根节点为 i,把 i 换出去,j 换进来的最小代价。
ll f1[MAXN + 5][MAXN + 5]; // f1[i][j] 表示根节点为 i,把 j 换出去,fa[i] 换进来的最小代价
ll f2[MAXN + 5][MAXN + 5]; // f2[i][j] 表示可以把 i 子树内任何一个点换出去,j 换进来的最小代价
ll f3[MAXN + 5][MAXN + 5];
// f3[i][j] 表示先交换 j 相对于 i 的另一侧子树内某个节点 x,并把 x 换出去、把 i 换进另一侧子树,再将 x 换进 j 这一侧子树,j 换出去的最小代价。
ll calc(int i, int j, int k) {
if (j == i) return f0[i][k];
else if (!rs[i]) return f1[ls[i]][j] + d[j] + d[k];
else {
int a = ls[i], b = rs[i];
if (sub[a][j]) return min(f1[a][j] + f2[b][k], f3[i][j]) + d[j] + d[k];
else return min(f1[b][j] + f2[a][k], f3[i][j]) + d[j] + d[k];
}
}
int main() {
freopen("mis.in", "r", stdin);
freopen("mis.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
for (int i = 2; i <= n; i++) scanf("%d", &fa[i]);
for (int i = 1; i <= n; i++) sub[i][i] = 1, vec[i].pb(i);
for (int i = 2; i <= n; i++) {
if (!ls[fa[i]]) ls[fa[i]] = i;
else rs[fa[i]] = i;
}
for (int i = n; i > 1; i--) {
for (int j = 1; j <= n; j++) sub[fa[i]][j] |= sub[i][j];
for (int x : vec[i]) vec[fa[i]].pb(x);
if (!ls[i]) {
for (int k = 1; k <= n; k++) if (!sub[i][k])
f0[i][k] = d[i] + d[k];
}
else if (!rs[i]) {
for (int k = 1; k <= n; k++) if (!sub[i][k])
f0[i][k] = d[i] + d[k] + f2[ls[i]][k];
}
else {
int a = ls[i], b = rs[i];
for (int k = 1; k <= n; k++) if (!sub[i][k]) f0[i][k] = 1e18;
ll mn1 = 1e18, mn2 = 1e18, mn3 = 1e18;
for (int x : vec[a]) {
ll mn = 1e18;
for (int y : vec[b]) chkmin(mn, calc(b, y, x));
if (x == a) {
for (int k = 1; k <= n; k++) if (!sub[i][k])
chkmin(f0[i][k], d[i] + d[k] + calc(a, x, k) + mn);
continue;
}
if (!rs[a]) chkmin(mn1, mn + d[i] + f1[ls[a]][x] + d[x]);
else {
int aa = ls[a], ab = rs[a];
if (sub[aa][x]) {
// calc(a, x, k) = min(f1[aa][x] + f2[ab][k], f3[a][x]) + d[x] + d[k]
chkmin(mn2, f1[aa][x] + d[x] + mn + d[i]);
chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
} else {
chkmin(mn3, f1[ab][x] + d[x] + mn + d[i]);
chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
}
}
}
if (!rs[a]) {
for (int k = 1; k <= n; k++) if (!sub[i][k])
chkmin(f0[i][k], mn1 + d[k] + d[k]);
} else {
int aa = ls[a], ab = rs[a];
for (int k = 1; k <= n; k++) if (!sub[i][k]) {
chkmin(f0[i][k], mn1 + d[k] + d[k]);
chkmin(f0[i][k], mn2 + f2[ab][k] + d[k] + d[k]);
chkmin(f0[i][k], mn3 + f2[aa][k] + d[k] + d[k]);
}
}
swap(a, b);
mn1 = 1e18, mn2 = 1e18, mn3 = 1e18;
for (int x : vec[a]) {
ll mn = 1e18;
for (int y : vec[b]) chkmin(mn, calc(b, y, x));
if (x == a) {
for (int k = 1; k <= n; k++) if (!sub[i][k])
chkmin(f0[i][k], d[i] + d[k] + calc(a, x, k) + mn);
continue;
}
if (!rs[a]) chkmin(mn1, mn + d[i] + f1[ls[a]][x] + d[x]);
else {
int aa = ls[a], ab = rs[a];
if (sub[aa][x]) {
// calc(a, x, k) = min(f1[aa][x] + f2[ab][k], f3[a][x]) + d[x] + d[k]
chkmin(mn2, f1[aa][x] + d[x] + mn + d[i]);
chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
} else {
chkmin(mn3, f1[ab][x] + d[x] + mn + d[i]);
chkmin(mn1, d[x] + mn + d[i] + f3[a][x]);
}
}
}
if (!rs[a]) {
for (int k = 1; k <= n; k++) if (!sub[i][k])
chkmin(f0[i][k], mn1 + d[k] + d[k]);
} else {
int aa = ls[a], ab = rs[a];
for (int k = 1; k <= n; k++) if (!sub[i][k]) {
chkmin(f0[i][k], mn1 + d[k] + d[k]);
chkmin(f0[i][k], mn2 + f2[ab][k] + d[k] + d[k]);
chkmin(f0[i][k], mn3 + f2[aa][k] + d[k] + d[k]);
}
}
}
if (rs[i]) {
int a = ls[i], b = rs[i];
for (int x : vec[a]) f3[i][x] = 1e18;
for (int x : vec[b]) f3[i][x] = 1e18;
for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][x], calc(b, y, i) + calc(a, x, y));
for (int x : vec[a]) for (int y : vec[b]) chkmin(f3[i][y], calc(a, x, i) + calc(b, y, x));
}
for (int j : vec[i]) f1[i][j] = calc(i, j, fa[i]);
for (int j = 1; j <= n; j++) if (!sub[i][j]) f2[i][j] = calc(i, i, j);
if (!rs[i]) {
ll mn = 1e18;
for (int v : vec[ls[i]]) chkmin(mn, d[v] + f1[ls[i]][v]);
for (int k = 1; k <= n; k++) if (!sub[i][k]) chkmin(f2[i][k], mn + d[k]);
} else {
ll mn1 = 1e18, mn2 = 1e18, mn3 = 1e18;
int a = ls[i], b = rs[i];
for (int v : vec[a]) chkmin(mn1, f1[a][v] + d[v]), chkmin(mn3, f3[i][v] + d[v]);
for (int v : vec[b]) chkmin(mn2, f1[b][v] + d[v]), chkmin(mn3, f3[i][v] + d[v]);
for (int k = 1; k <= n; k++) if (!sub[i][k]) {
chkmin(f2[i][k], mn1 + f2[b][k] + d[k]);
chkmin(f2[i][k], mn2 + f2[a][k] + d[k]);
chkmin(f2[i][k], mn3 + d[k]);
}
}
}
ll mn = 1e18;
if (!rs[1]) {
int a = ls[1];
for (int i = 2; i <= n; i++) chkmin(mn, calc(a, i, 1));
} else {
int a = ls[1], b = rs[1];
for (int i : vec[a]) for (int j : vec[b]) {
chkmin(mn, calc(a, i, 1) + calc(b, j, i));
chkmin(mn, calc(b, j, 1) + calc(a, i, j));
}
}
printf("%lld\n", mn);
return 0;
}
/*
10
9 5 6 3 8 7 1 4 1 3
1 2 2 4 1 4 3 8 5
*/

浙公网安备 33010602011771号