虚树

引入

P2495 [SDOI2011] 消耗战
我们容易发现我们可以有一个 \(O(nq)\) 的树形 Dp.
我们称有资源的点为关键点。有 \(k\) 个。

\(dp_u\) 表示 \(u\) 子树内不连接任何关键点的最小费用。
\(dp_u=dp_u+\min(dp_v,w)\),当 \(v\) 非关键点。
\(dp_u=dp_u+w\),当 \(v\) 为关键点。

那我们发现每次询问只用到一部分点。
那么我们把树浓缩。

发现对 dp 有用的只有关键点和关键点的两两 \(LCA\).
我们发现这样最多只有 \(2k-1\) 个点。

如何构造呢?
将关键点按 DFS 序排序;
遍历一遍,任意两个相邻的关键点求一下 \(LCA\),并且判重;
然后根据原树中的祖先后代关系建树。

具体实现上,在关键点序列上,枚举相邻的两个数,两两求得 \(lca\) 并且加入序列 \(A\) 中。
序列 \(A\) 按 dfn 排序后去重。
然后,在序列 \(A\) 上,枚举相邻两个数 \(x,y\),求出 \(lca\),然后连接 \(lca,y\) 就完成了。

为了方便,我们将 \(1\) 也加入 \(A\) 中。

code
#include<bits/stdc++.h>
#define st first
#define nd second
#define pi pair<int,int>
#define mp make_pair
using namespace std;
const int N=500050,logn=18;
vector<pi> e[N],vt[N];
int n,q,f[N][logn],g[N][logn],depth[N],dfn[N],num,val[N];
long long dp[N];
void dfs(int u,int fa) {
	dfn[u]=++num;
	f[u][0]=fa; depth[u]=depth[fa]+1;
	for(int i=1; i<logn; i++) f[u][i]=f[f[u][i-1]][i-1];
	for(int i=1; i<logn; i++) g[u][i]=min(g[f[u][i-1]][i-1],g[u][i-1]);
	for(auto it:e[u]) {
		int v=it.st,w=it.nd;
		if(v==fa) continue;
		g[v][0]=w;
		dfs(v,u);
	}
}
pi Lca(int u,int v) {
	int res=1e9;
	if(depth[u]>depth[v]) swap(u,v);
	for(int i=logn-1; i>=0; i--) {
		if(depth[f[v][i]]>=depth[u]) 
			res=min(res,g[v][i]),v=f[v][i];
	}
	if(u==v) return mp(u,res);
	for(int i=logn-1; i>=0; i--) {
		if(f[u][i]!=f[v][i]) 
			res=min(res,min(g[u][i],g[v][i])),u=f[u][i],v=f[v][i];
	}
	res=min(res,min(g[u][0],g[v][0]));
	return mp(f[u][0],res);
}
int p[N],A[N],tot,tc;
bool cmp(int i,int j) {
	return dfn[i]<dfn[j];
}
void buildvt() {
	tc=0;
	p[++tot]=1;
	sort(p+1,p+1+tot,cmp);
	for(int i=1; i<tot; i++) {
		auto lc=Lca(p[i],p[i+1]);
		A[++tc]=p[i]; A[++tc]=lc.st;
	}
	A[++tc]=p[tot];
	sort(A+1,A+1+tc,cmp);
	tc=unique(A+1,A+1+tc)-A-1;
	for(int i=1; i<tc; i++) {
		auto lc=Lca(A[i],A[i+1]);
		auto cc=Lca(lc.st,A[i+1]);
		vt[lc.st].push_back(mp(A[i+1],cc.nd));
	}
}
void solve(int u) {
	dp[u]=0;
	for(auto it:vt[u]) {
		int v=it.st,w=it.nd;
		solve(v);
		if(!val[v]) dp[u]=dp[u]+min(dp[v],1ll*w);
		else dp[u]=dp[u]+w;
	}
} 
int main() {
	scanf("%d",&n);
	for(int i=1,u,v,w; i<n; i++) {
		scanf("%d%d%d",&u,&v,&w);
		e[u].push_back(mp(v,w));
		e[v].push_back(mp(u,w));
	}
	g[1][0]=2e9,dfs(1,1);
	scanf("%d",&q);
	for(int qq=1,m,u; qq<=q; qq++) {
		scanf("%d",&m);
		tot=0;
		for(int i=1; i<=m; i++)
			scanf("%d",&u),p[++tot]=u,val[p[tot]]=1;
		buildvt();
		solve(1);
		printf("%lld\n",dp[1]);
		for(int i=1; i<=tc; i++) val[A[i]]=0;
		for(int i=1; i<=tc; i++) dp[A[i]]=0,vt[A[i]].clear();
	}
	return 0;
}
posted @ 2023-05-11 21:33  s1monG  阅读(16)  评论(0)    收藏  举报