【学习笔记】虚树

虚树用来处理一些树上的多次询问,对于每个询问,只考虑那些和询问有关的点,将无关的点都缩成边或直接剪掉。对于 \(Q\) 个询问,如果每次询问涉及的点数为 \(k_i\),那么总的时间复杂度就是 \((O\sum k_i)\) 的。

一、建树

在建树过程中,我们维护一个栈 \(stk\),表示当前在树上走到的最右链。为了方便,一开始我们先把根节点加入栈中。然后按照 dfs 序遍历询问涉及到的点。遍历到 \(x\) 点时,找到 \(x\) 和栈顶的 lca,将栈中的点替换成根 \(\to\) lca \(\to x\),同时将弹栈的点加边建立父子关系。遍历完所有点后再将栈弹空并一一加边,我们就得到了一棵只含询问涉及的点和它们的 lca 的树,大小是 \(O(k)\) 的。

二、例题

1.CF613D Kingdom and its Cities

建出虚树后在树上进行贪心即可。对于一个被标记的点,如果它有儿子被标记那么一定得断一条边。对于没有被标记的点,如果只有一个儿子被标记那么可以不断,否则就要将自己断掉。

Code
#include<bits/stdc++.h>
#define ll long long
#define il inline
#define pb push_back
using namespace std;
namespace asbt{
namespace cplx{bool begin;}
const int maxn=1e5+5;
int n,m,fa[maxn],dep[maxn],dfn[maxn];
int cnt,idx[maxn<<1],oula[maxn<<1];
int enm,hd[maxn],a[maxn],sz[maxn];
int stk[maxn],top,ans;
vector<int> e[maxn];
struct{
	int v,nxt;
}E[maxn];
il void addedge(int u,int v){
	E[++enm].v=v;
	E[enm].nxt=hd[u];
	hd[u]=enm;
}
il void dfs1(int u){
	dfn[u]=++cnt;
	oula[cnt]=cnt;
	idx[cnt]=u;
	for(int v:e[u]){
		if(v!=fa[u]){
			fa[v]=u;
			dep[v]=dep[u]+1;
			dfs1(v);
			oula[++cnt]=dfn[u];
		}
	}
} 
struct{
	int Log[maxn<<1],st[maxn<<1][22];
	il void build(){
		for(int i=2;i<=cnt;i++){
			Log[i]=Log[i>>1]+1;
		}
		for(int i=1;i<=cnt;i++){
			st[i][0]=oula[i];
		}
		for(int j=1;j<=Log[cnt];j++){
			for(int i=1;i+(1<<j)-1<=cnt;i++){
				st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
			}
		}
	}
	il int query(int l,int r){
		int p=Log[r-l+1];
		return min(st[l][p],st[r-(1<<p)+1][p]);
	}
}ST;
il int lca(int u,int v){
	if(dfn[u]>dfn[v]){
		swap(u,v);
	}
	return idx[ST.query(dfn[u],dfn[v])];
}
il void dfs2(int u){
//	cout<<u<<"\n";
	if(sz[u]){
		for(int i=hd[u];i;i=E[i].nxt){
			int v=E[i].v;
			dfs2(v);
			if(sz[v]){
				sz[v]=0;
				ans++;
			}
		}
	}
	else{
		for(int i=hd[u];i;i=E[i].nxt){
			int v=E[i].v;
			dfs2(v);
			sz[u]+=sz[v];
			sz[v]=0;
		}
		if(sz[u]>1){
			ans++,sz[u]=0;
		}
	}
	hd[u]=0;
}
namespace cplx{
	bool end;
	il double usdmem(){return (&begin-&end)/1048576.0;}
}
int main(){
	ios::sync_with_stdio(0),cin.tie(0);
	cin>>n;
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		e[u].pb(v),e[v].pb(u);
	}
	dfs1(1),ST.build();
	cin>>m;
	while(m--){
		int k;
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>a[i];
			sz[a[i]]=1;
		}
		for(int i=1;i<=k;i++){
			if(sz[fa[a[i]]]){
				for(int j=1;j<=k;j++){
					sz[a[j]]=0;
				}
				cout<<"-1\n";
				goto togo;
			}
		}
		sort(a+1,a+k+1,[](const int &x,const int &y){return dfn[x]<dfn[y];});
		top=enm=0;
		if(a[1]!=1){
			stk[++top]=1;
		}
		for(int i=1;i<=k;i++){
			int x=a[i];
			if(!top){
				stk[++top]=x;
				continue;
			}
			int y=lca(x,stk[top]);
			while(top>1&&dep[y]<dep[stk[top-1]]){
				addedge(stk[top-1],stk[top]);
				top--;
			}
			if(dep[y]<dep[stk[top]]){
				addedge(y,stk[top--]);
			}
			if(!top||y!=stk[top]){
				stk[++top]=y;
			}
			stk[++top]=x;
		}
		while(top>1){
			addedge(stk[top-1],stk[top]);
			top--;
		}
		ans=0;
		dfs2(1);
		sz[1]=0;
		cout<<ans<<"\n";
		togo:;
	}
	return 0;
}
}
int main(){return asbt::main();}

2.[bzoj2286][Sdoi2011]消耗战

显然又要建虚树。考虑如果要在 \(u\)\(v\) 的路径上删掉一条边,那么一定是删掉最小的那条最划算。因此虚树上的边权就是路径上的最小边权。可以用倍增求出。然后考虑对于虚树上的一个点 \(u\) 求出使它的子树(不含自身)中的关键点都与根断开联系的最小花费 \(f_u\),考虑它的一个儿子 \(v\),如果 \(v\) 是关键点那么一定要断开 \((u,v)\) 这条边,否则可以在 \(f_v\)\((u,v)\) 的边权中去取 \(\min\)。答案即为 \(f_1\)。记得清空。

Code
#include<bits/stdc++.h>
#define int long long
#define il inline
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define fir first
#define sec second
using namespace std;
namespace asbt{
namespace cplx{bool begin;}
namespace IO{
	const int bufsz=1<<20;
	char buf[bufsz],*p1=buf,*p2=buf;
	#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,bufsz,stdin),p1==p2)?EOF:*p1++)
	il int read(){
		int ch=getchar();
		while(ch<'0'||ch>'9'){
			ch=getchar();
		}
		int x=0;
		while(ch>='0'&&ch<='9'){
			x=(x<<1)+(x<<3)+(ch^48);
			ch=getchar();
		}
		return x;
	}
}
using IO::read;
const int maxn=2.5e5+5,inf=1e18;
int n,m,cnt,dfn[maxn],dep[maxn];
int idx[maxn<<1],oula[maxn<<1];
int anc[maxn][22],mnw[maxn][22];
int a[maxn],stk[maxn],top,sz[maxn];
int hd[maxn],enm,f[maxn];
vector<pii> e[maxn];
struct{
	int v,w,nxt;
}E[maxn];
il void addedge(int u,int v,int w){
	E[++enm].v=v;
	E[enm].w=w;
	E[enm].nxt=hd[u];
	hd[u]=enm;
}
il void dfs1(int u){
	for(int i=1;i<=20;i++){
		anc[u][i]=anc[anc[u][i-1]][i-1];
		mnw[u][i]=min(mnw[u][i-1],mnw[anc[u][i-1]][i-1]);
	}
	dfn[u]=++cnt,dep[u]=dep[anc[u][0]]+1;
	idx[cnt]=u,oula[cnt]=cnt;
	for(pii i:e[u]){
		int v=i.fir,w=i.sec;
		if(v==anc[u][0]){
			continue;
		}
		anc[v][0]=u,mnw[v][0]=w;
		dfs1(v);
		oula[++cnt]=dfn[u];
	}
}
struct{
	int st[maxn<<1][22],Log[maxn<<1];
	il void build(){
		for(int i=2;i<=cnt;i++){
			Log[i]=Log[i>>1]+1;
		}
		for(int i=1;i<=cnt;i++){
			st[i][0]=oula[i];
		}
		for(int j=1;j<=Log[cnt];j++){
			for(int i=1;i+(1<<j)-1<=cnt;i++){
				st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
			}
		}
	}
	il int query(int l,int r){
		int p=Log[r-l+1];
		return min(st[l][p],st[r-(1<<p)+1][p]);
	}
}ST;
il int lca(int u,int v){
	if(dfn[u]>dfn[v]){
		swap(u,v);
	}
	return idx[ST.query(dfn[u],dfn[v])];
}
il int Mnv(int u,int p){
	int res=inf,tmp=0;
	while(p){
		if(p&1){
			res=min(res,mnw[u][tmp]);
			u=anc[u][tmp];
		}
		p>>=1,tmp++;
	}
	return res;
}
il void dfs2(int u){
	if(!hd[u]){
		f[u]=inf;
		return ;
	}
	for(int i=hd[u];i;i=E[i].nxt){
		int v=E[i].v,w=E[i].w;
		dfs2(v);
		if(sz[v]){
			f[u]+=w;
		}
		else{
			f[u]+=min(w,f[v]);
		}
		f[v]=sz[v]=0;
	}
	hd[u]=0;
}
namespace cplx{
	bool end;
	il double usdmem(){return (&begin-&end)/1048576.0;}
}
int main(){
	n=read();
	for(int i=1,u,v,w;i<n;i++){
		u=read(),v=read(),w=read();
		e[u].pb(mp(v,w));
		e[v].pb(mp(u,w));
	}
	dfs1(1),ST.build();
	m=read();
	while(m--){
		int k;
		k=read();
		for(int i=1;i<=k;i++){
			a[i]=read();
			sz[a[i]]=1;
		}
		sort(a+1,a+k+1,[](const int &x,const int &y){return dfn[x]<dfn[y];});
		enm=top=0;
		if(a[1]!=1){
			stk[++top]=1;
		}
		for(int i=1;i<=k;i++){
			int x=a[i];
			if(!top){
				stk[++top]=x;
				continue;
			}
			int y=lca(x,stk[top]);
			while(top>1&&dep[stk[top-1]]>dep[y]){
				addedge(stk[top-1],stk[top],Mnv(stk[top],dep[stk[top]]-dep[stk[top-1]]));
				top--;
			}
			if(dep[y]<dep[stk[top]]){
				addedge(y,stk[top],Mnv(stk[top],dep[stk[top]]-dep[y]));
				top--;
			}
			if(!top||y!=stk[top]){
				stk[++top]=y;
			}
			stk[++top]=x;
		}
		while(top>1){
			addedge(stk[top-1],stk[top],Mnv(stk[top],dep[stk[top]]-dep[stk[top-1]]));
			top--;
		}
		dfs2(1);
		cout<<f[1]<<"\n";
		f[1]=sz[1]=0;
	}
	return 0;
}
}
signed main(){return asbt::main();}

3.「HEOI2014」大工程

建立虚树,然后在树上 dp 就行了。记得清空。

学到了一个非常简单的建立虚树的方法,就是先按 dfn 排序,再求相邻 lca,放到一起排序去重再相邻连边。

然后树剖 lca 好写好调省时省空间还是太棒了。

Code
#include<bits/stdc++.h>
#define ll long long
#define il inline
#define pb push_back
using namespace std;
namespace asbt{
namespace cplx{bool begin;}
const int maxn=1e6+5;
const ll inf=1e18;
int n,m,cnt,fa[maxn],dep[maxn];
int top[maxn],sz[maxn],hes[maxn];
int a[maxn<<1],dfn[maxn];
int enm,hd[maxn];
ll ans1,ans2,ans3;
ll f1[maxn],f2[maxn],f3[maxn];
bool fch[maxn];
vector<int> e[maxn];
struct{
	int v,w,nxt;
}E[maxn];
il void addedge(int u,int v,int w){
	E[++enm].v=v;
	E[enm].w=w;
	E[enm].nxt=hd[u];
	hd[u]=enm;
}
il void dfs1(int u){
	sz[u]=1;
	int mxs=0;
	for(int v:e[u]){
		if(v==fa[u]){
			continue;
		}
		fa[v]=u;
		dep[v]=dep[u]+1;
		dfs1(v);
		sz[u]+=sz[v];
		if(mxs<sz[v]){
			mxs=sz[v];
			hes[u]=v;
		}
	}
}
il void dfs2(int u){
	dfn[u]=++cnt;
	if(!top[u]){
		top[u]=u;
	}
	if(hes[u]){
		top[hes[u]]=top[u];
		dfs2(hes[u]);
	}
	for(int v:e[u]){
		if(v!=fa[u]&&v!=hes[u]){
			dfs2(v);
		}
	}
}
il int lca(int u,int v){
	while(top[u]!=top[v]){
		if(dep[top[u]]>dep[top[v]]){
			u=fa[top[u]];
		}
		else{
			v=fa[top[v]];
		}
	}
	return dep[u]<dep[v]?u:v;
}
il bool cmp(const int &x,const int &y){
	return dfn[x]<dfn[y];
}
il void dfs3(int u){
//	cout<<u<<"\n";
	sz[u]=fch[u],f1[u]=0;
	if(fch[u]){
		f2[u]=f3[u]=0;
	}
	else{
		f2[u]=inf,f3[u]=-inf;
	}
	for(int i=hd[u];i;i=E[i].nxt){
		int v=E[i].v,w=E[i].w;
		dfs3(v);
		ans1+=f1[u]*sz[v]+(f1[v]+w*sz[v])*sz[u];
		ans2=min(ans2,f2[u]+w+f2[v]);
		ans3=max(ans3,f3[u]+w+f3[v]);
		f1[u]+=f1[v]+w*sz[v];
		f2[u]=min(f2[u],f2[v]+w);
		f3[u]=max(f3[u],f3[v]+w);
		sz[u]+=sz[v];
	}
}
namespace cplx{
	bool end;
	il double usdmem(){return (&begin-&end)/1048576.0;}
}
int main(){
	ios::sync_with_stdio(0),cin.tie(0);
	cin>>n;
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		e[u].pb(v),e[v].pb(u);
	}
	dfs1(1),dfs2(1);
	cin>>m;
	while(m--){
		int k;
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>a[i];
			fch[a[i]]=1;
		}
		sort(a+1,a+k+1,cmp);
		int tot=k;
		for(int i=1;i<k;i++){
			a[++tot]=lca(a[i],a[i+1]);
		}
		a[++tot]=1;
		sort(a+1,a+tot+1,cmp);
		tot=unique(a+1,a+tot+1)-a-1;
		enm=0;
		for(int i=1;i<tot;i++){
			int x=lca(a[i],a[i+1]),y=a[i+1];
			addedge(x,y,dep[y]-dep[x]);
		}
		ans1=0,ans2=inf,ans3=-inf;
		dfs3(1);
		cout<<ans1<<" "<<ans2<<" "<<ans3<<"\n";
		for(int i=1;i<=tot;i++){
//			cout<<a[i]<<" ";
			hd[a[i]]=fch[a[i]]=0;
		}
//		puts("");
	}
	return 0;
}
}
int main(){return asbt::main();}

4.「SDOI2015」寻宝游戏

不难发现,将当前的关键点按 \(dfn\) 排序后,答案即为 \(dis(a_1,a_2)+dis(a_2,a_3)+\dots+dis(a_{k-1},a_k)+dis(a_k,a_1)\),出发点就是 \(a_1\)。那么拿个 set 维护一下即可。

Code
#include<bits/stdc++.h>
#define ll long long
#define il inline
#define pii pair<int,int>
#define fir first
#define sec second
#define pb push_back
#define mp make_pair
#define it set<int>::iterator
using namespace std;
namespace asbt{
namespace cplx{bool begin;}
const int maxn=1e5+5;
int n,m,cnt,dfn[maxn],oula[maxn<<1],idx[maxn<<1];
ll dep[maxn];
vector<pii> e[maxn];
il void dfs(int u,int fa){
	dfn[u]=++cnt;
	oula[cnt]=cnt;
	idx[cnt]=u;
	for(pii i:e[u]){
		int v=i.fir,w=i.sec;
		if(v==fa){
			continue;
		}
		dep[v]=dep[u]+w;
		dfs(v,u);
		oula[++cnt]=dfn[u];
	}
}
struct{
	int st[maxn<<1][22],Log[maxn<<1];
	il void build(){
		for(int i=2;i<=cnt;i++){
			Log[i]=Log[i>>1]+1;
		}
		for(int i=1;i<=cnt;i++){
			st[i][0]=oula[i];
		}
		for(int j=1;j<=Log[cnt];j++){
			for(int i=1;i+(1<<j)-1<=cnt;i++){
				st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
			}
		}
	}
	il int query(int l,int r){
		int p=Log[r-l+1];
		return min(st[l][p],st[r-(1<<p)+1][p]);
	}
}ST;
il int lca(int u,int v){
	if(dfn[u]>dfn[v]){
		swap(u,v);
	}
	return idx[ST.query(dfn[u],dfn[v])];
}
il ll dis(int u,int v){
	return dep[u]+dep[v]-dep[lca(u,v)]*2;
}
struct cmp{
	il bool operator()(const int x,const int y)const{
		return dfn[x]<dfn[y];
	}
};
set<int,cmp> S;
namespace cplx{
	bool end;
	il double usdmem(){return (&begin-&end)/1048576.0;}
}
int main(){
	ios::sync_with_stdio(0),cin.tie(0);
	cin>>n>>m;
	for(int i=1,u,v,w;i<n;i++){
		cin>>u>>v>>w;
		e[u].pb(mp(v,w));
		e[v].pb(mp(u,w));
	}
	dfs(1,0);
	ST.build();
	ll ans=0;
	while(m--){
		int u;
		cin>>u;
		if(S.count(u)){
			if(S.size()<=2){
				ans=0,S.erase(u);
			}
			else{
				it cur=S.find(u);
				int pre=cur==S.begin()?*S.rbegin():*prev(cur);
				int nxt=cur==prev(S.end())?*S.begin():*next(cur);
				ans-=dis(pre,u)+dis(u,nxt);
				ans+=dis(pre,nxt);
				S.erase(cur);
			}
		}
		else{
			it cur=S.insert(u).fir;
			if(S.size()==1){
				ans=0;
			}
			else{
				int pre=cur==S.begin()?*S.rbegin():*prev(cur);
				int nxt=cur==prev(S.end())?*S.begin():*next(cur);
				ans-=dis(pre,nxt);
				ans+=dis(pre,u)+dis(u,nxt);
			}
		}
		cout<<ans<<"\n";
	}
	return 0;
}
}
int main(){return asbt::main();}

5.「ZJOI2019」语言

将经过某个点的所有路径的并求出来,然后就很简单了。

有这样一个结论:包含 \(a_1,a_2,\dots,a_k\) 的最小的子图的边数即为 \(\sum_{i=1}^{n}dep_{a_i}-\sum_{i=2}^{n}dep_{\operatorname{lca}(a_{i-1},a_i)}-dep_{\operatorname{lca}_{i=1}^{n}a_i}\),其中 \(a\) 按照 \(dfn\) 排序。那么可以用线段树去维护这个东西。将每个路径进行树上差分,然后线段树合并即可。

Code
#include<bits/stdc++.h>
#define ll long long
#define il inline
#define pb push_back
using namespace std;
namespace asbt{
namespace cplx{bool begin;}
const int maxn=1e5+5;
int n,m,fa[maxn],dfn[maxn],idx[maxn<<1],dep[maxn],oula[maxn<<1],cnt;
vector<int> e[maxn],g[maxn];
il void dfs1(int u,int faz){
	dfn[u]=++cnt;
	idx[cnt]=u;
	oula[cnt]=cnt;
	fa[u]=faz;
	for(int v:e[u]){
		if(v==faz){
			continue;
		}
		dep[v]=dep[u]+1;
		dfs1(v,u);
		oula[++cnt]=dfn[u];
	}
}
struct{
	int st[maxn<<1][22],Log[maxn<<1];
	il void build(){
		for(int i=2;i<=cnt;i++){
			Log[i]=Log[i>>1]+1;
		}
		for(int i=1;i<=cnt;i++){
			st[i][0]=oula[i];
		}
		for(int j=1;j<=Log[cnt];j++){
			for(int i=1;i+(1<<j)-1<=cnt;i++){
				st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
			}
		}
	}
	il int query(int l,int r){
		int p=Log[r-l+1];
		return min(st[l][p],st[r-(1<<p)+1][p]);
	}
}ST;
il int lca(int u,int v){
	if(dfn[u]>dfn[v]){
		swap(u,v);
	}
	return idx[ST.query(dfn[u],dfn[v])];
}
int ls[maxn*150],rs[maxn*150],tot,rt[maxn];
ll ans;
struct node{
	int l,r,p,sum,cnt;
	node(int l=0,int r=0,int p=0,int sum=0,int cnt=0):l(l),r(r),p(p),sum(sum),cnt(cnt){}
	il node operator+(const node &x)const{
		if(l==-1){
			return x;
		}
		if(x.l==-1){
			return *this;
		}
		node res;
		res.l=l,res.r=x.r;
		res.p=lca(p,x.p);
		res.sum=sum+x.sum+dep[p]+dep[x.p]-dep[lca(r,x.l)]-dep[res.p];
		return res;
	}
}tr[maxn*150];
il void pushup(int id){
	tr[id]=tr[ls[id]]+tr[rs[id]];
}
il void add(int &id,int l,int r,int p,int v){
	if(!id){
		id=++tot;
	}
	if(l==r){
		int tmp=tr[id].cnt+v;
		if(tmp>0){
			tr[id]=node(idx[p],idx[p],idx[p],0,tmp);
		}
		else{
			tr[id]=node(-1,-1,-1,0,tmp);
		}
		return ;
	}
	int mid=(l+r)>>1;
	if(p<=mid){
		add(ls[id],l,mid,p,v);
	}
	else{
		add(rs[id],mid+1,r,p,v);
	}
	pushup(id);
}
il int merge(int p,int q,int l,int r){
	if(!p||!q){
		return p+q;
	}
	if(l==r){
		int tmp=tr[p].cnt+tr[q].cnt;
		if(tmp>0){
			tr[p]=node(idx[l],idx[l],idx[l],0,tmp);
		}
		else{
			tr[p]=node(-1,-1,-1,0,tmp);
		}
		return p;
	}
	int mid=(l+r)>>1;
	ls[p]=merge(ls[p],ls[q],l,mid);
	rs[p]=merge(rs[p],rs[q],mid+1,r);
	pushup(p);
	return p;
}
il void dfs2(int u,int fa){
	for(int v:e[u]){
		if(v==fa){
			continue;
		}
		dfs2(v,u);
		rt[u]=merge(rt[u],rt[v],1,cnt);
	}
	for(int v:g[u]){
		int x=abs(v),y=v>0?1:-1;
		add(rt[u],1,cnt,x,y);
	}
//	cout<<u<<" "<<tr[rt[u]].sum<<"\n";
	ans+=tr[rt[u]].sum;
}
namespace cplx{
	bool end;
	il double usdmem(){return (&begin-&end)/1048576.0;}
}
int main(){
	ios::sync_with_stdio(0),cin.tie(0);
	cin>>n>>m;
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		e[u].pb(v),e[v].pb(u);
	}
	dfs1(1,0);
	ST.build();
//	for(int i=1;i<=n;i++){
//		cout<<dep[i]<<" ";
//	}
//	puts("");
//	for(int i=1;i<=n;i++){
//		for(int j=1;j<=n;j++){
//			cout<<lca(i,j)<<" ";
//		}
//		puts("");
//	}
	while(m--){
		int u,v;
		cin>>u>>v;
		int x=lca(u,v),y=fa[x];
		g[u].pb(dfn[u]),g[u].pb(dfn[v]);
		g[v].pb(dfn[u]),g[v].pb(dfn[v]);
		g[x].pb(-dfn[u]),g[x].pb(-dfn[v]);
		g[y].pb(-dfn[u]),g[y].pb(-dfn[v]);
	}
	tr[0]=node(-1);
	dfs2(1,0);
	cout<<ans/2;
	return 0;
}
}
int main(){return asbt::main();}

6.「HNOI2014」世界树

首先考虑暴力,换根 DP 即可。

于是建虚树。对于虚树上的点,依然是换根 DP。考虑处理不在虚树上的点。

不难发现,对于虚树上一个点 \(u\),如果它在原树上的一个儿子 \(v\) 的子树中没有虚树节点,那么 \(v\) 的子树就一定都归管 \(u\) 的点管。于是只剩下虚树上的边(在原树中是一条链)上的点了。

而这也是好做的,这条链显然一部分归爸爸另一部分归儿子。计算有哪些归爸爸哪些归儿子即可。需要倍增。

时间复杂度 \(O(\sum m\log n)\)

Code
#include<bits/stdc++.h>
#define ll long long
#define il inline
#define pb push_back
#define pii pair<int,int>
#define fir first
#define sec second
#define mp make_pair
using namespace std;
namespace asbt{
const int maxn=3e5+5,inf=1e9;
int n,m,cnt,dep[maxn],sz[maxn];
int hes[maxn],top[maxn],dfn[maxn];
int anc[maxn][22],sum[maxn][22],b[maxn];
int a[maxn<<1],g[maxn],dp[maxn],ans[maxn];
bool f[maxn];
vector<int> e[maxn];
vector<pii> E[maxn];
il void dfs1(int u){
	sz[u]=1;
	for(int i=1;i<=20;i++){
		anc[u][i]=anc[anc[u][i-1]][i-1];
	}
	int mxs=0;
	for(int v:e[u]){
		if(v==anc[u][0]){
			continue;
		}
		anc[v][0]=u;
		dep[v]=dep[u]+1;
		dfs1(v);
		sz[u]+=sz[v];
		if(mxs<sz[v]){
			mxs=sz[v];
			hes[u]=v;
		}
	}
}
il void dfs2(int u){
	if(!top[u]){
		top[u]=u;
	}
	dfn[u]=++cnt;
	for(int i=1;i<=20;i++){
		sum[u][i]=sum[u][i-1]+sum[anc[u][i-1]][i-1];
	}
	if(hes[u]){
		int v=hes[u];
		sum[v][0]=sz[u]-sz[v];
		top[v]=top[u];
		dfs2(v);
	}
	for(int v:e[u]){
		if(v==anc[u][0]||v==hes[u]){
			continue;
		}
		sum[v][0]=sz[u]-sz[v];
		dfs2(v);
	}
}
il int lca(int u,int v){
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]){
			swap(u,v);
		}
		u=anc[top[u]][0];
	}
	return dep[u]<dep[v]?u:v;
}
il int kth(int u,int k){
	int x=0;
	while(k){
		if(k&1){
			u=anc[u][x];
		}
		k>>=1,x++;
	}
	return u;
}
il int ktot(int u,int k){
	int x=0,res=0;
	while(k){
		if(k&1){
			res+=sum[u][x];
			u=anc[u][x];
		}
		k>>=1,x++;
	}
	return res;
}
bool cmp(const int &x,const int &y){
	return dfn[x]<dfn[y];
}
il void dfs3(int u){
	if(f[u]){
		dp[u]=0,g[u]=u;
	}
	else{
		dp[u]=inf;
	}
	for(pii i:E[u]){
		int v=i.fir,w=i.sec;
		dfs3(v);
		if(dp[v]+w<dp[u]||dp[v]+w==dp[u]&&g[v]<g[u]){
			dp[u]=dp[v]+w;
			g[u]=g[v];
		}
	}
}
il void dfs4(int u){
//	cout<<u<<"\n";
	ans[g[u]]+=sz[u];
	for(pii i:E[u]){
		int v=i.fir,w=i.sec;
//		cout<<v<<" "<<w<<"\n";
		if(dp[u]+w<dp[v]||dp[u]+w==dp[v]&&g[u]<g[v]){
			dp[v]=dp[u]+w;
			g[v]=g[u];
		}
		int x=kth(v,dep[v]-dep[u]-1);
		ans[g[u]]-=sz[x];
		if(dep[v]==dep[u]+1){
			goto togo;
		}
		if(g[u]==g[v]){
			int zong=ktot(v,dep[v]-dep[u]-1);
			ans[g[u]]+=zong;
		}
		else{
			int zong=ktot(v,dep[v]-dep[u]-1);
			int len=dp[u]-dep[u]+dep[v]+dp[v];
			if(len&1){
				len>>=1;
			}
			else{
				len>>=1;
				if(g[v]>g[u]){
					len--;
				}
			}
			len-=dp[v];
			int tmp=ktot(v,len);
			ans[g[v]]+=tmp;
			ans[g[u]]+=zong-tmp;
//			cout<<tmp<<" "<<zong<<"\n";
		}
		togo:;
		dfs4(v);
	}
}
int main(){
//	freopen("P3233_1.in","r",stdin);
	ios::sync_with_stdio(0),cin.tie(0);
	cin>>n;
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		e[u].pb(v),e[v].pb(u);
	}
	dfs1(1),dfs2(1);
	cin>>m;
	while(m--){
		int k;
		cin>>k;
		for(int i=1;i<=k;i++){
			cin>>a[i];
			f[a[i]]=1;
			b[i]=a[i];
		}
		int tot=k;
		a[++tot]=1;
		sort(a+1,a+tot+1,cmp);
		for(int i=1;i<=k;i++){
			a[++tot]=lca(a[i],a[i+1]);
		}
		sort(a+1,a+tot+1,cmp);
		tot=unique(a+1,a+tot+1)-a-1;
		for(int i=1;i<tot;i++){
			int x=lca(a[i],a[i+1]),y=a[i+1];
			E[x].pb(mp(y,dep[y]-dep[x]));
		}
//		for(int i=1;i<=tot;i++){
//			cout<<a[i]<<" ";
//		}
//		puts("");
		dfs3(1);
		dfs4(1);
		for(int i=1;i<=k;i++){
			cout<<ans[b[i]]<<" ";
		}
		cout<<"\n";
		for(int i=1;i<=tot;i++){
			g[a[i]]=dp[a[i]]=ans[a[i]]=f[a[i]]=0;
			E[a[i]].clear();
		}
	}
	return 0;
}
}
int main(){return asbt::main();}
posted @ 2025-05-25 17:26  zhangxy__hp  阅读(48)  评论(0)    收藏  举报