CF1263F Economic Difficulties
分析
考虑 dp。设 \(w[0/1][l][r]\) 为区间 \([l, r]\) 内的设备不与上 / 下树连通且不影响其他设备与上 / 下树连通性时可以删除的最多边数。可以枚举 \(l\),递推 \(r\)。假设我们现在要从区间 \([l, r)\) 得到区间 \([l, r]\) 的答案,考虑可以多删哪些边。发现多删去的边一定是 \(x[r]\) 及其若干祖先组成的一条链。于是可以观察这条链最多可以延伸到哪个点。由于我们要保证 \(x[r + 1]\) 与 \(x[l - 1]\) 与根的连通性,于是这两个点的祖先都不能删。于是这条链最长就是到 \(LCA(x[r], x[r + 1]\) 与 \(LCA(x[r], x[l - 1])\) 中较深的那个。由于 \([l, r)\) 这个状态已经保证了 \(x[r]\) 与根的连通性,所以这条链里的边都可以随便删,不用考虑删重复的问题。
接下来定义 \(f[i]\) 为只考虑 \([1, i]\) 这个区间里的设备时,使其中每台设备都与任意一根连通时所能删去的最多边数。于是有转移 \(f[i] = \max_{0 \leq j < i} \{ f[j] + \max \{ w[0][j + 1][i], w[1][j + 1][i] \} \}\)。这个式子应该比较好理解。
代码
#include <iostream>
using namespace std;
int n, a, b;
int x[1005][2];
class Tree {
public:
int head[20005], nxt[20005], to[20005], ecnt;
void add(int u, int v) { to[++ecnt] = v, nxt[ecnt] = head[u], head[u] = ecnt; }
int n, lf;
int top[10005], son[10005], dep[10005], sz[10005], f[10005];
int dp[1005][1005];
void dfs1(int x, int fa, int d) {
dep[x] = d;
sz[x] = 1;
f[x] = fa;
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if (v != fa) {
dfs1(v, x, d + 1);
sz[x] += sz[v];
if (sz[son[x]] < sz[v])
son[x] = v;
}
}
}
void dfs2(int x, int t) {
top[x] = t;
if (!son[x])
return;
dfs2(son[x], t);
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if (v != f[x] && v != son[x])
dfs2(v, v);
}
}
void ini() { dfs1(1, 0, 1), dfs2(1, 1); }
int LCA(int x, int y) {
if (!x || !y)
return 1;
while (top[x] ^ top[y]) (dep[top[x]] < dep[top[y]]) ? (y = f[top[y]]) : (x = f[top[x]]);
return (dep[x] < dep[y] ? x : y);
}
void main() {
cin >> n;
for (int i = 2, x; i <= n; i++) {
cin >> x;
add(x, i);
add(i, x);
}
ini();
}
void Solve(int a) {
for (int i = 1; i <= lf; i++) {
for (int j = i; j <= lf; j++) {
int l1 = LCA(x[i - 1][a], x[j][a]), l2 = LCA(x[j][a], x[j + 1][a]);
if (dep[l1] < dep[l2])
swap(l1, l2);
dp[i][j] = dp[i][j - 1] + dep[x[j][a]] - dep[l1];
}
}
}
} T[2];
int dp[1005];
int main() {
cin >> n;
T[0].main();
for (int i = 1; i <= n; i++) cin >> x[i][0];
T[1].main();
for (int i = 1; i <= n; i++) cin >> x[i][1];
T[0].lf = T[1].lf = n;
T[0].Solve(0);
T[1].Solve(1);
for (int i = 1; i <= n; i++) {
for (int j = 0; j < i; j++)
dp[i] = max(dp[i], dp[j] + max(T[0].dp[j + 1][i], T[1].dp[j + 1][i]));
}
cout << dp[n];
return 0;
}

浙公网安备 33010602011771号