Luogu P7276 送给好友的礼物 题解 [ 蓝 ] [ 树上背包 DP ] [ 交换维度 ]
送给好友的礼物:众所周知,当两个 trick 被强行揉在了一起,它就成了一道新题。
观察路径,好像没有什么特别厉害的性质。因为数据范围较小,所以不难想到暴力 DP:定义 \(dp_{u, i, j}\) 表示 \(u\) 的子树内,小 M 走 \(i\) 步,小 B 走 \(j\) 步能否将整颗子树全部拿完。转移的时候需要做双层树上背包,复杂度直接爆炸,显然不可行。
考虑对 DP 进行优化,注意到 DP 内的值只能是 \(0/1\),值域极小,且当 \(i\) 确定、有合法方案时,\(j\) 更小一定更优,所以考虑“交换 DP 维度”的经典 trick。具体地,交换 \(j\) 与 \(dp\),新状态就是 \(dp_{u,i}\) 表示 \(u\) 子树内,小 M 走了 \(i\) 步,在将整颗子树拿完的前提下,小 B 最少要走的步数。
接下来就是个很板的树形背包了,运用上下界优化可以做到 \(O(n^2)\)。合并背包的转移方程如下:
\[dp_{u, i + j + \left[j > 0\right ] \times 2} \overset{\min}{\leftarrow} f_{u,i}+dp_{v, j} + \left [ dp_{v, j}>0 \right ]\times 2
\]
其中,\(f_{u,i}\) 表示 \(dp\) 数组的前一个版本,目的是防止用这一轮合并得到的 \(dp\) 更新这一轮的值。
注意要特判 \(u\) 为子树内唯一一个苹果的 corner,将他标记为叶子并对其特殊转移。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 420, inf = 0x3f3f3f3f;
int n, m, a[N], dp[N][2 * N], sz[N], f[2 * N], ans = inf;
bool leaf[N];
vector<int> g[N];
void merge(int u, int v)
{
memcpy(f, dp[u], sizeof(f));
memset(dp[u], 0x3f, sizeof(dp[u]));
if(leaf[v])
{
for(int i = 0; i <= (sz[u] - 1) * 2; i += 2)
{
dp[u][i + 2] = min(dp[u][i + 2], f[i]);
dp[u][i] = min(dp[u][i], f[i] + 2);
}
}
else
{
for(int i = 0; i <= (sz[u] - 1) * 2; i += 2)
for(int j = 0; j <= (sz[v] - 1) * 2; j += 2)
dp[u][i + j + (j > 0) * 2] = min(dp[u][i + j + (j > 0) * 2],
f[i] + dp[v][j] + (dp[v][j] > 0) * 2);
}
sz[u] += sz[v];
a[u] += a[v];
}
void dfs(int u, int fa)
{
bool flag = a[u];
sz[u] = 1;
dp[u][0] = 0;
for(auto v : g[u])
{
if(v == fa) continue;
dfs(v, u);
merge(u, v);
}
if(a[u] == 0) dp[u][0] = 0;
else if(a[u] == 1 && flag) leaf[u] = 1;
}
int main()
{
//freopen("sample.in", "r", stdin);
//freopen("sample.out", "w", stdout);
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> m;
for(int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
for(int i = 1; i <= m; i++)
{
int tmp;
cin >> tmp;
a[tmp]++;
}
memset(dp, 0x3f, sizeof(dp));
dfs(1, 0);
for(int i = 0; i <= 2 * (n - 1); i++)
ans = min(ans, max(dp[1][i], i));
cout << ans;
return 0;
}

浙公网安备 33010602011771号