虚树
引入
P2495 [SDOI2011] 消耗战
我们容易发现我们可以有一个 \(O(nq)\) 的树形 Dp.
我们称有资源的点为关键点。有 \(k\) 个。
设 \(dp_u\) 表示 \(u\) 子树内不连接任何关键点的最小费用。
\(dp_u=dp_u+\min(dp_v,w)\),当 \(v\) 非关键点。
\(dp_u=dp_u+w\),当 \(v\) 为关键点。
那我们发现每次询问只用到一部分点。
那么我们把树浓缩。
发现对 dp 有用的只有关键点和关键点的两两 \(LCA\).
我们发现这样最多只有 \(2k-1\) 个点。
如何构造呢?
将关键点按 DFS 序排序;
遍历一遍,任意两个相邻的关键点求一下 \(LCA\),并且判重;
然后根据原树中的祖先后代关系建树。
具体实现上,在关键点序列上,枚举相邻的两个数,两两求得 \(lca\) 并且加入序列 \(A\) 中。
序列 \(A\) 按 dfn 排序后去重。
然后,在序列 \(A\) 上,枚举相邻两个数 \(x,y\),求出 \(lca\),然后连接 \(lca,y\) 就完成了。
为了方便,我们将 \(1\) 也加入 \(A\) 中。
code
#include<bits/stdc++.h>
#define st first
#define nd second
#define pi pair<int,int>
#define mp make_pair
using namespace std;
const int N=500050,logn=18;
vector<pi> e[N],vt[N];
int n,q,f[N][logn],g[N][logn],depth[N],dfn[N],num,val[N];
long long dp[N];
void dfs(int u,int fa) {
dfn[u]=++num;
f[u][0]=fa; depth[u]=depth[fa]+1;
for(int i=1; i<logn; i++) f[u][i]=f[f[u][i-1]][i-1];
for(int i=1; i<logn; i++) g[u][i]=min(g[f[u][i-1]][i-1],g[u][i-1]);
for(auto it:e[u]) {
int v=it.st,w=it.nd;
if(v==fa) continue;
g[v][0]=w;
dfs(v,u);
}
}
pi Lca(int u,int v) {
int res=1e9;
if(depth[u]>depth[v]) swap(u,v);
for(int i=logn-1; i>=0; i--) {
if(depth[f[v][i]]>=depth[u])
res=min(res,g[v][i]),v=f[v][i];
}
if(u==v) return mp(u,res);
for(int i=logn-1; i>=0; i--) {
if(f[u][i]!=f[v][i])
res=min(res,min(g[u][i],g[v][i])),u=f[u][i],v=f[v][i];
}
res=min(res,min(g[u][0],g[v][0]));
return mp(f[u][0],res);
}
int p[N],A[N],tot,tc;
bool cmp(int i,int j) {
return dfn[i]<dfn[j];
}
void buildvt() {
tc=0;
p[++tot]=1;
sort(p+1,p+1+tot,cmp);
for(int i=1; i<tot; i++) {
auto lc=Lca(p[i],p[i+1]);
A[++tc]=p[i]; A[++tc]=lc.st;
}
A[++tc]=p[tot];
sort(A+1,A+1+tc,cmp);
tc=unique(A+1,A+1+tc)-A-1;
for(int i=1; i<tc; i++) {
auto lc=Lca(A[i],A[i+1]);
auto cc=Lca(lc.st,A[i+1]);
vt[lc.st].push_back(mp(A[i+1],cc.nd));
}
}
void solve(int u) {
dp[u]=0;
for(auto it:vt[u]) {
int v=it.st,w=it.nd;
solve(v);
if(!val[v]) dp[u]=dp[u]+min(dp[v],1ll*w);
else dp[u]=dp[u]+w;
}
}
int main() {
scanf("%d",&n);
for(int i=1,u,v,w; i<n; i++) {
scanf("%d%d%d",&u,&v,&w);
e[u].push_back(mp(v,w));
e[v].push_back(mp(u,w));
}
g[1][0]=2e9,dfs(1,1);
scanf("%d",&q);
for(int qq=1,m,u; qq<=q; qq++) {
scanf("%d",&m);
tot=0;
for(int i=1; i<=m; i++)
scanf("%d",&u),p[++tot]=u,val[p[tot]]=1;
buildvt();
solve(1);
printf("%lld\n",dp[1]);
for(int i=1; i<=tc; i++) val[A[i]]=0;
for(int i=1; i<=tc; i++) dp[A[i]]=0,vt[A[i]].clear();
}
return 0;
}