[luoguP2495] [SDOI2011]消耗战(DP + 虚树)
明显虚树。
别的题解里都是这样说的。
先不考虑虚树,假设只有一组询问,该如何dp?
f[u]表示把子树u中所有的有资源的节点都切掉的最优解
如果节点u需要切掉了话,$f[u]=val[u]$
否则如果u的子树中有需要切除的点的话,$f[u] = min(val[u], \sum\limits_{v是u的儿子}f[v])$
val[u]表示是根到u的路径上最小的边的权值。
最后转移到虚树上即可。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 1000000
#define LL long long
using namespace std;
int n, m, cnt, rp, top, T;
int head[N], to[N], nex[N], dfn[N], f[N][21], q[N], deep[N], s[N];
LL ans[N], dp[N], val[N];
bool flag[N];
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -1;
for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
return x * f;
}
inline void add(int x, int y, int z)
{
to[cnt] = y;
val[cnt] = z;
nex[cnt] = head[x];
head[x] = cnt++;
}
inline void dfs1(int u)
{
int i, v;
dfn[u] = ++rp;
deep[u] = deep[f[u][0]] + 1;
for(i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i];
for(i = head[u]; ~i; i = nex[i])
{
v = to[i];
if(!dfn[v])
{
f[v][0] = u;
dp[v] = min(dp[u], val[i]);
dfs1(v);
}
}
head[u] = -1;
}
inline int calc_lca(int x, int y)
{
int i, j;
if(deep[x] < deep[y]) swap(x, y);
for(i = 20; i >= 0; i--)
if(deep[f[x][i]] >= deep[y]) x = f[x][i];
if(x == y) return x;
for(i = 20; i >= 0; i--)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
inline bool cmp(int x, int y)
{
return dfn[x] < dfn[y];
}
inline void dfs2(int u)
{
LL sum = 0;
int i, v;
ans[u] = dp[u];
for(i = head[u]; ~i; i = nex[i])
{
v = to[i];
dfs2(v);
sum += ans[v];
}
if(sum && !flag[u]) ans[u] = min(ans[u], sum);
head[u] = -1;
}
inline void solve()
{
int i, lca;
m = read();
top = cnt = 0;
for(i = 1; i <= m; i++) q[i] = read(), flag[q[i]] = 1;
sort(q + 1, q + m + 1, cmp);
for(i = 1; i <= m; i++)
{
if(!top)
{
s[++top] = q[i];
continue;
}
lca = calc_lca(q[i], s[top]);
while(dfn[lca] < dfn[s[top]])
{
if(dfn[lca] >= dfn[s[top - 1]])
{
add(lca, s[top], 0);
if(s[--top] != lca) s[++top] = lca;
break;
}
add(s[top - 1], s[top], 0), top--;
}
s[++top] = q[i];
}
while(top > 1) add(s[top - 1], s[top], 0), top--;
dfs2(s[1]);
printf("%lld\n", ans[s[1]]);
for(i = 1; i <= m; i++) flag[q[i]] = 0;
}
int main()
{
int i, x, y, z;
n = read();
memset(head, -1, sizeof(head));
for(i = 1; i < n; i++)
{
x = read();
y = read();
z = read();
add(x, y, z);
add(y, x, z);
}
dp[1] = 1ll * 1e9 * 1e9;
dfs1(1);
T = read();
while(T--) solve();
return 0;
}

浙公网安备 33010602011771号