虚树

虚树

虚树一般是用来优化树上一些算法的。

一般情况下,如果我们对一棵树的算法只和某些已知关键点有关,而其他点没有用,就可以建立一颗只含关键点及它们两两之间的 LCA 的树。
其实相当于在原树上抽出一棵含有关键点的树,因为 LCA 起到连接关键点的作用,所以把lca以外的点压缩掉,最后剩下的就是原树的虚数。
建出虚树之后就可以通过减少树的大小(最大到关键点数2倍)来优化时间复杂度。

这篇里学到了 \(\color{black}{\texttt p}\color{red}{\texttt {_b_p_b}}\) 建虚树的方法。

把关键点按 dfn 排序,相邻点的 LCA 丢进去,再排序去重,每个点在虚树上的父亲就是自己和前驱的 LCA。虚树树根是第一个点。

于是就可以做到 \(m(\log n+\log m)\) 建虚树(\(n,m\) 分别是原树点数和关键点数)。

简单的实现

点击查看代码
	dfs(1,0);
	for(int i=1;i<=m;i++)q[i]=in;
	sort(q+1,q+1+m,cmp);
	for(int i=m;i>=2;i--)
		q[++m]=lca(q[i],q[i-1]);			
	sort(q+1,q+1+m,cmp);
	int t=0;
	for(int i=1;i<=m;i++)
		if(q[i]!=q[i-1])q[++t]=q[i];
	m=t;
	for(int i=2;i<=m;i++)
		Xinsert(lca(q[i],q[i-1]),q[i]);
	dfs1(q[1],0);

记得删虚树的时候是把添加的边一个一个删,不能直接清空 head,因为复杂度没有保证。

然后你需要做的就是想出在树上的做法了,是不是简单了很多呢?

因为还是不简单,所以来看几道虚树的题吧。


CF613D Kingdom and its Cities

CF613D

虽然虚树一般用来优化点数,但有时把虚树抽出来之后可以把问题简化,从而更容易想到做法。

对于这道题可以容易地想到虚树,因为每次给出的就相当于关键点。

那么首先判断无解,如果存在两个关键点之间没有点就是无解,否则一定有解。

然后我们尝试把虚树建出来,发现建完之后有个比较明显的 dp,就是在虚树上 dfs 一遍。
\(g[u]\) 表示 \(u\) 子树中有没有与 \(u\) 联通的关键点,\(f[u]\) 表示到 \(u\) 时的需要断的点数。

到了一个点,首先 \(f[u]=\sum f[v]\)

如果它是关键点,那么就把所有没断的儿子断了 \(f[u]+=\sum g[v]\)\(g[u]=[1]\)

如果它不是关键点,那么看它如果连接了多于 \(1\) 个的没断的儿子,那就把自己断了,\(g[u]=0,f[u]++\);否则如果只连了一个儿子,那就留到祖先再断,说不定可以一次断更多的,\(g[u]=1\);如果一个关键点儿子都没有,就不用管了,\(g[u]=0\)

时间复杂度 \(\sum m(\log n+\log m)\)\(\sum m=n\),可以把 \(\log m\) 放大到 \(\log n\) 得到时间复杂度可视为 \(O(n\log n)\)

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-48;c=getchar();}
	return p*f;
}
const int N=1e5+5;
struct edge{
	int v,nxt;
}e[N<<1],X[N<<1];
int head[N],en;
int hd[N],xn;
inline void insert(int u,int v){
	e[++en].v=v;
	e[en].nxt=head[u];
	head[u]=en;
}
inline void Xinsert(int u,int v){
	X[++xn].v=v;
	X[xn].nxt=hd[u];
	hd[u]=xn;
}
inline void Xdel(int u){hd[u]=X[xn--].nxt;}
int fa[19][N],dep[N],dfn[N],sign;
inline void dfs(int u,int f){
	fa[0][u]=f,dep[u]=dep[f]+1,dfn[u]=++sign;
	for(int i=1;i<=18;i++)
		fa[i][u]=fa[i-1][fa[i-1][u]];
	for(int i=head[u],v=e[i].v;i;i=e[i].nxt,v=e[i].v)
		if(v^f)dfs(v,u);
}
inline int lca(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	for(int i=18;i>=0;i--)
		if(dep[fa[i][x]]>=dep[y])
			x=fa[i][x];
	if(x==y)return y;
	for(int i=18;i>=0;i--)
		if(fa[i][x]!=fa[i][y])
			x=fa[i][x],y=fa[i][y];
	return fa[0][x];
}
int ans;

int vis[N],flag=0,g[N],h[N],nump[N]; 
inline void dfs1(int u,int f){
	h[u]=nump[u]=0;
	for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
		if(v^f){
			if(dep[u]+1==dep[v]&&vis[u]&&vis[v]){flag=1;break;}
			dfs1(v,u),nump[u]+=g[v],h[u]+=h[v];
		}
	if(vis[u])h[u]+=nump[u],g[u]=1;
	else{
		if(nump[u]>1)h[u]++,g[u]=0;
		else if(nump[u]==1)g[u]=1;
		else g[u]=0;
	}
}
int n,m,T,q[N<<1],t;
int tbd[N];
bool cmp(int a,int b){return dfn[a]<dfn[b];}
signed main(){
	n=in;
	for(int i=1,u,v;i<n;i++)
		u=in,v=in,insert(u,v),insert(v,u);
	dfs(1,0);
	T=in;
	while(T--){
		m=in,ans=0,flag=0,t=0;
		for(int i=1;i<=m;i++)q[i]=in,vis[q[i]]=1;
		sort(q+1,q+1+m,cmp);
		for(int i=m;i>=2;i--)q[++m]=lca(q[i],q[i-1]);			
		sort(q+1,q+1+m,cmp);
		for(int i=1;i<=m;i++)if(q[i]!=q[i-1])q[++t]=q[i];m=t;
		for(int i=2;i<=m;i++)Xinsert(tbd[i]=lca(q[i],q[i-1]),q[i]);				
		dfs1(q[1],0);
		if(flag)cout<<-1<<'\n';
		else cout<<h[q[1]]<<'\n';
		for(int i=1;i<=m;i++)vis[q[i]]=0;
		for(int i=m;i>=2;i--)Xdel(tbd[i]);
	}
	return 0;
}

当然,你完全可以一眼秒出这个做法然后发现 \(f,g\) 的变化都在关键点及其LCA,然后再建虚树。


P2495 [SDOI2011]消耗战

链接

一样建出虚树,发现这个 dp 也比较简单。

\(dp[u]\) 表示切断 \(u\) 子树最小消耗,我们在第一次原树上 dfs 上维护一个从根到 \(u\) 的边权最小值 \(minn[u]\)

若当前点是关键点,则 \(dp[u]=minn[u]\)
否则,\(dp[u]=\min\{minn[u],\sum dp[v]\}\)

这里可能有一点疑问,就是可能会出现儿子节点的 \(dp[v]=minn[v]\) 比 LCA 节点父亲 \(u\) 更浅,但这样其实不会影响答案,因为这种情况下 \(\sum dp[v]\ge minn[u]\),dp不会从这种不合法情况转移。

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-48;c=getchar();}
	return p*f;
}
const int N=25e4+5;
struct edge{
	int v,w,nxt;
}e[N<<1],X[N<<1];
int head[N],en;
int hd[N],xn;
inline void insert(int u,int v,int w){
	e[++en].v=v;
	e[en].w=w;
	e[en].nxt=head[u];
	head[u]=en;
}
inline void Xinsert(int u,int v){
	X[++xn].v=v;
	X[xn].nxt=hd[u];
	hd[u]=xn;
}
inline void Xdel(int u){hd[u]=X[xn--].nxt;}
int fa[19][N],dep[N],dfn[N],sign,minn[N];
inline void dfs(int u,int f){
	fa[0][u]=f,dep[u]=dep[f]+1,dfn[u]=++sign;
	for(int i=1;i<=18;i++)
		fa[i][u]=fa[i-1][fa[i-1][u]];
	for(int i=head[u],v=e[i].v,w=e[i].w;i;i=e[i].nxt,v=e[i].v,w=e[i].w)
		if(v^f)minn[v]=min(minn[u],w),dfs(v,u);
}
inline int lca(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	for(int i=18;i>=0;i--)
		if(dep[fa[i][x]]>=dep[y])
			x=fa[i][x];
	if(x==y)return y;
	for(int i=18;i>=0;i--)
		if(fa[i][x]!=fa[i][y])
			x=fa[i][x],y=fa[i][y];
	return fa[0][x];
}
int vis[N],ans[N]; 
inline int dfs1(int u,int f){
	if(vis[u])return ans[u]=minn[u];
	else{
		int res=0;
		for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
			if(v^f)res+=dfs1(v,u);
		return min(res,minn[u]);		
	}
}
int n,m,T,q[N<<1],t;
int tbd[N];
bool cmp(int a,int b){return dfn[a]<dfn[b];}
signed main(){
	n=in;
	for(int i=1,u,v,w;i<n;i++)
		u=in,v=in,w=in,insert(u,v,w),insert(v,u,w);
	minn[1]=0x7fffffffffffffff,dfs(1,0);
	T=in;
	while(T--){
		m=in,t=0;
		for(int i=1;i<=m;i++)q[i]=in,vis[q[i]]=1;
		sort(q+1,q+1+m,cmp);
		for(int i=m;i>=2;i--)q[++m]=lca(q[i],q[i-1]);			
		sort(q+1,q+1+m,cmp);
		for(int i=1;i<=m;i++)if(q[i]!=q[i-1])q[++t]=q[i];m=t;
		for(int i=2;i<=m;i++)Xinsert(tbd[i]=lca(q[i],q[i-1]),q[i]);		
		cout<<dfs1(q[1],0)<<'\n';		
		for(int i=1;i<=m;i++)vis[q[i]]=0;
		for(int i=m;i>=2;i--)Xdel(tbd[i]);
	}
	return 0;
}

P4103 [HEOI2014]大工程

也比较简单,对于最大路径和最小路径只需要在每个子树维护一个 \(u\) 子树中的关键点到 \(u\) 的最大距离和最小距离,然后每次用儿子中的最大和次大,最小和次小来更新答案。
对于路径距离和,可以把答案拆开,单个查询是 \(dep[u]+dep[v]-2*dep[lca]\),所以就对关键点单点加 \(dep[u]*(k-1)\),然后减自己作为 LCA 时的 \(2*dep[lca]\) 只需要统计下儿子子树的关键点数即可。

时间复杂度仍然是 \(O(n\log n)\)

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-48;c=getchar();}
	return p*f;
}
const int N=1e6+5;
struct edge{
	int v,nxt;
}e[N<<1],X[N<<1];
int head[N],en;
int hd[N],xn;
inline void insert(int u,int v){
	e[++en].v=v;
	e[en].nxt=head[u];
	head[u]=en;
}
inline void Xinsert(int u,int v){
	X[++xn].v=v;
	X[xn].nxt=hd[u];
	hd[u]=xn;
}
inline void Xdel(int u){hd[u]=X[xn--].nxt;}
int fa[19][N],dep[N],dfn[N],sign;
inline void dfs(int u,int f){
	fa[0][u]=f,dep[u]=dep[f]+1,dfn[u]=++sign;
	for(int i=1;i<=18;i++)
		fa[i][u]=fa[i-1][fa[i-1][u]];
	for(int i=head[u],v=e[i].v;i;i=e[i].nxt,v=e[i].v)
		if(v^f)dfs(v,u);
}
inline int lca(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	for(int i=18;i>=0;i--)
		if(dep[fa[i][x]]>=dep[y])
			x=fa[i][x];
	if(x==y)return y;
	for(int i=18;i>=0;i--)
		if(fa[i][x]!=fa[i][y])
			x=fa[i][x],y=fa[i][y];
	return fa[0][x];
}
int vis[N],ans[N];
int ans1,ans2,ans3,minn[N],maxn[N],siz[N];
int n,m,T,q[N<<1],t1,t2;
inline void dfs1(int u,int f){
	int res1=0,res2=0,mx=0,cmx=0,mn=0x7fffffff,cmn=0x7fffffff;
	siz[u]=vis[u];
	for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
		if(v^f){
			dfs1(v,u),res1+=siz[v],res2+=siz[v]*siz[v],siz[u]+=siz[v];
			if(maxn[v]+dep[v]-dep[u]>=mx)cmx=mx,mx=maxn[v]+dep[v]-dep[u];
			else if(maxn[v]+dep[v]-dep[u]>cmx)cmx=maxn[v]+dep[v]-dep[u];
			if(minn[v]+dep[v]-dep[u]<=mn)cmn=mn,mn=minn[v]+dep[v]-dep[u];
			else if(minn[v]+dep[v]-dep[u]<cmn)cmn=minn[v]+dep[v]-dep[u];
		}
	maxn[u]=mx,minn[u]=mn;
	ans1+=dep[u]*vis[u]*(m-1)-2*dep[u]*((res1*res1-res2)/2+vis[u]*res1);
	ans2=min(ans2,mn+cmn);
	if(vis[u])ans2=min(ans2,mn),minn[u]=0;
	ans3=max(ans3,mx+cmx);
}
int tbd[N];
bool cmp(int a,int b){return dfn[a]<dfn[b];}
signed main(){
	n=in;
	for(int i=1,u,v;i<n;i++)
		u=in,v=in,insert(u,v),insert(v,u);
	minn[1]=0x7fffffff,dfs(1,0);
	T=in;
	while(T--){
		m=in;
		for(int i=1;i<=m;i++)q[i]=in,vis[q[i]]=1;
		sort(q+1,q+1+m,cmp);t1=m,t2=0;
		for(int i=m;i>=2;i--)q[++t1]=lca(q[i],q[i-1]);			
		sort(q+1,q+1+t1,cmp);
		for(int i=1;i<=t1;i++)if(q[i]!=q[i-1])q[++t2]=q[i];
		for(int i=2;i<=t2;i++)Xinsert(tbd[i]=lca(q[i],q[i-1]),q[i]);
		ans1=0,ans2=0x7fffffff,ans3=-0x7fffffff,dfs1(q[1],0);
		cout<<ans1<<' '<<ans2<<' '<<ans3<<'\n';
		for(int i=1;i<=t2;i++)vis[q[i]]=0;
		for(int i=t2;i>=2;i--)Xdel(tbd[i]);
	}
	return 0;
}

P3233 [HNOI2014]世界树

好不容易感觉自己独立做出来一次,你却让我WA的这么彻底,焯!

其实是我傻了。

首先把虚树建出来。

把一个点接受管辖看成被染色。

然后你发现如果父子都被染了色,那么中间路径上的东西就一定可以二分出一个断点来统计。所以考虑先把虚树上没有染色的点也就是 LCA 点染色。

那么有一个比较明显的做法,就是 dfs 第一遍用儿子染自己,dfs 第二遍用父亲染自己。
然后因为你染的时候取的是最近的,所以一样会排掉一些不合法情况,也就不需要维护什么次小值了(
实现就很简单

点击查看代码
inline void dfs1(int u,int f){
	siz[u]=vis[u];
	if(vis[u])minn[u]=0,belo[u]=u;
	else minn[u]=0x7fffffffffffffff;
	for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
		if(v^f){
			dfs1(v,u),siz[u]+=siz[v];
			if(minn[v]+dep[v]-dep[u]<minn[u]||(minn[v]+dep[v]-dep[u]==minn[u]&&belo[v]<belo[u]))	
				minn[u]=minn[v]+dep[v]-dep[u],belo[u]=belo[v];
		}
}
inline void dfs2(int u,int f){
	for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
		if(v^f){
			if(minn[u]+dep[v]-dep[u]<minn[v]||(minn[u]+dep[v]-dep[u]==minn[v]&&belo[u]<belo[v]))
				minn[v]=minn[u]+dep[v]-dep[u],belo[v]=belo[u];
			dfs2(v,u);
		}
}

belo 是染的颜色。

然后来求答案,我们发现除了虚树上压缩的路径要用二分统计,还有一些没有加入虚树的子树,所以我们考虑用原树的 \(siz\) 来容斥得答案。

一次 dfs,每个点对答案的贡献刚开始是 \(siz[u]\) 然后如果儿子是同色就减掉 \(siz[v]\),否则二分出断点 \(t\),断点及其下面贡献给儿子的颜色,断点上面贡献给自己,所以自己的贡献减去 \(siz[t]\),儿子的贡献加上 \(siz[t]-siz[v]\)

然后这里注意断点可以根据深度倍增跳,注意是根据 \(u,v\) 的被染上色的那个颜色点来算断点的深度,而不是 \(u,v\) 自己。我因为这个【】错误调了一天。

然后可能你虚树的根的上面还有一些点,所以记得给虚树根答案加上 \(siz[1]-siz[rt]\)

于是这道题就做完了,但细节很多,建议写个对拍找些小样例来做。

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define in read()
inline int read(){
	int p=0,f=1;
	char c=getchar();
	while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
	while(isdigit(c)){p=p*10+c-48;c=getchar();}
	return p*f;
}
const int N=3e5+5;
struct edge{int v,nxt;}e[N<<1],X[N<<1];
int head[N],en,hd[N],xn;
inline void insert(int u,int v){e[++en].v=v;e[en].nxt=head[u];head[u]=en;}
inline void Xinsert(int u,int v){X[++xn].v=v,X[xn].nxt=hd[u],hd[u]=xn;}
inline void Xdel(int u){hd[u]=0,xn--;}
int fa[21][N],dep[N],dfn[N],sign,tsiz[N];
inline void dfs(int u,int f){
	fa[0][u]=f,dep[u]=dep[f]+1,dfn[u]=++sign,tsiz[u]=1;
	for(int i=1;i<=20;i++)
		fa[i][u]=fa[i-1][fa[i-1][u]];
	for(int i=head[u],v=e[i].v;i;i=e[i].nxt,v=e[i].v)
		if(v^f)dfs(v,u),tsiz[u]+=tsiz[v];
}
inline int lca(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	for(int i=20;i>=0;i--)
		if(dep[fa[i][x]]>=dep[y])
			x=fa[i][x];
	if(x==y)return y;
	for(int i=20;i>=0;i--)
		if(fa[i][x]!=fa[i][y])
			x=fa[i][x],y=fa[i][y];
	return fa[0][x];
}
bool cmp(int a,int b){return dfn[a]<dfn[b];}
int vis[N],n,m,T,q[N<<1],tq[N<<1],t1,t2,tbd[N];
int minn[N],belo[N],siz[N],ans[N];
inline void dfs1(int u,int f){
	siz[u]=vis[u];
	if(vis[u])minn[u]=0,belo[u]=u;
	else minn[u]=0x7fffffffffffffff;
	for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
		if(v^f){
			dfs1(v,u),siz[u]+=siz[v];
			if(minn[v]+dep[v]-dep[u]<minn[u]||(minn[v]+dep[v]-dep[u]==minn[u]&&belo[v]<belo[u]))	
				minn[u]=minn[v]+dep[v]-dep[u],belo[u]=belo[v];
		}
}
inline void dfs2(int u,int f){
	for(int i=hd[u],v=X[i].v;i;i=X[i].nxt,v=X[i].v)
		if(v^f){
			if(minn[u]+dep[v]-dep[u]<minn[v]||(minn[u]+dep[v]-dep[u]==minn[v]&&belo[u]<belo[v]))
				minn[v]=minn[u]+dep[v]-dep[u],belo[v]=belo[u];
			dfs2(v,u);
		}
}
inline void getans(int u,int f){
	int res1=tsiz[u];
	for(int i=hd[u],v=X[i].v,t,h;i;i=X[i].nxt,v=X[i].v){
		if(v==f)continue;
		if(belo[v]^belo[u]){
			h=dep[u]-minn[u]+dep[belo[v]],t=v;
			h=h&1?(h+1)>>1:belo[u]<belo[v]?(h>>1)+1:h>>1;
			for(int i=20;i>=0;i--)if(dep[fa[i][t]]>=h)t=fa[i][t];
			res1-=tsiz[t];
			ans[belo[v]]+=tsiz[t]-tsiz[v];
		}
		else res1-=tsiz[v];
		getans(v,u);
	}
	ans[belo[u]]+=res1;
}
int tn[N],up[N];
signed main(){
	n=in;
	for(int i=1,u,v;i<n;i++)
		u=in,v=in,insert(u,v),insert(v,u);
	dfs(1,0),T=in;
	while(T--){
		m=in;
		for(int i=1;i<=m;i++)tq[i]=q[i]=in,vis[q[i]]=1;
		sort(q+1,q+1+m,cmp);t1=m,t2=0;
		for(int i=t1;i>=2;i--)q[++m]=lca(q[i],q[i-1]);			
		sort(q+1,q+1+m,cmp);
		for(int i=1;i<=m;i++)if(q[i]!=q[i-1])q[++t2]=q[i];
		for(int i=2;i<=t2;i++)Xinsert(tbd[i]=lca(q[i],q[i-1]),q[i]);	
		dfs1(q[1],0),dfs2(q[1],0),getans(q[1],0);
		ans[belo[q[1]]]+=tsiz[1]-tsiz[q[1]];
		for(int i=1;i<=t1;i++)cout<<ans[tq[i]]<<' ',ans[tq[i]]=0;cout<<'\n';	
		for(int i=1;i<=t1;i++)vis[tq[i]]=0;
		for(int i=t2;i>=2;i--)Xdel(tbd[i]);
	}
	return 0;
} 

posted @ 2022-02-23 17:12  llmmkk  阅读(79)  评论(0)    收藏  举报