两棵树问题的一种点分治做法
简述题面:
你有两棵树,\(T_1\) ,\(T_2\) ,然后你需要对于每个点求出 \(\min_{j\not=i}(dist(T_1,i,j)+dist(T_2,i,j))\)
要求时间复杂度 \(O(n\log^2 n)\) 或更优
做法:
考虑点分治,假如在 \(T_1\) 固定 \(i,j\) 一定要经过某个 \(x\) ,然后把 \(x\) 作为分治点,那么实际上 \(val[i,j]=dist(T_2,i,j)+dep(T_1,i)+dep(T_1,j)\)
其中 \(dep\) 表示距离分治点的距离,然后发现这个就是一个只关于 \(T_2\) 上距离的加和的东西,类似于换根dp的问题,具体的就是把 \(T_1\) 的分治点接管的所有 \(T_1\) 上的节点在 \(T_2\) 上的虚树建出来,然后再虚树上跑 换根dp,然后每层分治点的并是 \(O(n)\) 的,所以时间复杂度是 \(O(n\log n) \cdot O(\log n)\) (建虚树)的
赛时代码
#pragma GCC optimize("Ofast")
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#define int long long
#define fi first
#define se second
#define pb push_back
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch))f^=ch=='-',ch=getchar();
	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
	return f?x:-x;
}
const int N=1e5+5,inf=1e18;
int ans[N],n;
vector<pair<int,int>> T[N];
int top[N],siz[N],son[N],up[N],dfn[N],dep[N],dfntot;
void dfs1(int u,int fa){
	dfn[u]=++dfntot;
	up[u]=fa;
	siz[u]=1;
	for(pair<int,int> e:T[u]){
		int v=e.fi,w=e.se;
		if(v==fa)continue;
		dep[v]=dep[u]+w;
		dfs1(v,u);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v])
			son[u]=v;
	}
}
void dfs2(int u,int tp){
	top[u]=tp;
	if(son[u])dfs2(son[u],tp);
	for(pair<int,int> e:T[u]){
		int v=e.fi,w=e.se;
		if(v==up[u]||v==son[u])continue;
		dfs2(v,v);
	}
}
int lca(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		x=up[top[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	return x;
}
/*虚树part*/
vector<pair<int,int>> vt[N];
int sta[N],stop,vis[N];
void insert(int x){
	if(stop<=1){
		sta[++stop]=x;
		return;
	}
	int t=lca(x,sta[stop]);
	if(t==sta[stop]){
		sta[++stop]=x;
		return;
	}
	while(stop>1&&dfn[sta[stop-1]]>=dfn[t]){
		int u=sta[stop-1];
		int v=sta[stop--];
		vt[u].pb({v,dep[v]-dep[u]});
	}
	if(sta[stop]!=t){
		vt[t].pb({sta[stop],dep[sta[stop]]-dep[t]});
		sta[stop]=t;
	}
	sta[++stop]=x;
}
vector<int> sss;
void build(vector<int> st){
	sss=st;
	int len=sss.size();
	sort(sss.begin(),sss.end(),[](int x,int y){
		return dfn[x]<dfn[y];
	});
	stop=0;
	for(int v:sss)vis[v]=1;
	if(!vis[1])insert(1);
	for(int v:sss)insert(v);
	while(stop>1){
		int u=sta[stop-1];
		int v=sta[stop--];
		vt[u].pb({v,dep[v]-dep[u]});
	}
}
int f[N],g[N],d[N];
void pmin(int &x,int &y,int z){
	if(z<=x){
		y=x;
		x=z;
	}
	else y=min(y,z);
}
void dp1(int u,int fa){
	f[u]=g[u]=inf;
	for(pair<int,int> e:vt[u]){
		int v=e.fi,w=e.se;
		if(v==fa)continue;
		dp1(v,u);
		pmin(f[u],g[u],f[v]+w);
	}
	if(vis[u])pmin(f[u],g[u],d[u]);
}
void dp2(int u,int fa){
	for(pair<int,int> e:vt[u]){
		int v=e.fi,w=e.se;
		if(v==fa)continue;
		if(f[v]+w==f[u]){
			pmin(f[v],g[v],g[u]+w);
		}
		else{
			pmin(f[v],g[v],f[u]+w);
		}
		dp2(v,u);
	}
}
void dp3(int u,int fa){
	if(vis[u]){
		int tmpu=0;
		if(d[u]==f[u])tmpu=g[u];
		else tmpu=f[u];
		ans[u]=min(ans[u],d[u]+tmpu);
	}
	for(pair<int,int> e:vt[u]){
		int v=e.fi,w=e.se;
		if(v==fa)continue;
		dp3(v,u);
	}
	vis[u]=0;
	vt[u].clear();
}
/*虚树part-end*/
struct tree2{
	vector<pair<int,int>> T[N];
	int vis[N],siz[N],f[N],root,sum,dep[N];
	vector<int> nodes;
	void getsize(int u,int fa){
		siz[u]=1;
		for(pair<int,int> e:T[u]){
			int v=e.fi;
			if(v==fa||vis[v])continue;
			getsize(v,u);
			siz[u]+=siz[v];
		}
	}
	void getroot(int u,int fa){
		f[u]=0;
		for(pair<int,int> e:T[u]){
			int v=e.fi;
			if(v==fa||vis[v])continue;
			getroot(v,u);
			f[u]=max(f[u],siz[v]);
		}
		f[u]=max(f[u],sum-siz[u]);
		if(f[u]<f[root])root=u;
	}
	void dfs1(int u,int fa){
		for(pair<int,int> e:T[u]){
			int v=e.fi,w=e.se;
			if(vis[v]||v==fa)continue;
			dep[v]=dep[u]+w;
			dfs1(v,u);
		}
		nodes.pb(u);
	}
	void solve(int u){
		nodes.clear();
		dep[u]=0;
		dfs1(u,0);
		for(int x:nodes){
			d[x]=dep[x];
		}
		build(nodes);
		dp1(1,0);
		dp2(1,0);
		dp3(1,0);
		vis[u]=1;
		for(pair<int,int> e:T[u]){
			int v=e.fi,w=e.se;
			if(vis[v])continue;
			getsize(v,u);
			sum=siz[v];
			root=0;
			f[root]=inf;
			getroot(v,u);
			solve(root);
		}
	}
	void work(){
		sum=n;
		root=0;
		f[0]=inf;
		getsize(1,0);
		getroot(1,0);
		solve(root);
	}
}T2;
signed main(){
//	freopen("sample2.in","r",stdin);
//	freopen("tester.out","w",stdout);
	n=read();
	for(int i=1;i<=n;++i)ans[i]=inf;
	for(int i=1;i<n;++i){
		int x=read(),y=read(),z=read();
		T[x].pb({y,z});
		T[y].pb({x,z});
	}
	dfs1(1,0);
	dfs2(1,0);
	for(int i=1;i<n;++i){
		int x=read(),y=read(),z=read();
		T2.T[x].pb({y,z});
		T2.T[y].pb({x,z});
	}
	T2.work();
	for(int i=1;i<=n;++i){
		printf("%lld\n",ans[i]);
	}
	return 0;
}
一些补充
实际上我虽然固定了 \((i,j)\) 经过 \(x\) 但是建出虚树后不一定能够保证 \((i,j)\) 一定在 \(x\) 的不同子树中,但是没关系,因为如果 \((i,j)\) 的 \(lca\) 是 \(x\) 的后代,那么我会在 \(lca\) 处再一次统计 (i,j) ,而我在 \(x\) 处统计出的 \((i,j)\) 的答案一定不优于在 \(lca\) 处统计的答案
                    
                
                
            
        
浙公网安备 33010602011771号