点分治 学习笔记

点分治是一种解决树上问题的方式,比如处理带权树上两点的距离等于 \(k\) 的个数为多少。

流程

当我们想要计算一棵树 \(t\) 的时候,我们可以选定一个点 \(rt\) 作为这棵树临时的根,处理和 \(rt\) 有关的操作(如经过 \(rt\) 且距离为 \(k\) 的二元组个数),然后对于他的所有子树继续进行同样的操作,直到遍历完子树为止。

那么第一个问题:\(rt\) 应该选哪个?
很明显,我们应该选择的是这颗子树的重心。(重心的定义:如果在树中选择某个节点并删除,这棵树将分为若干棵子树,统计子树节点数并记录最大值。取遍树上所有节点,使此最大值取到最小的节点被称为整个树的重心。-- OIwiki)因为重心的所有子节点为根的子树大小都不会超过 \(\lfloor\frac{n}{2}\rfloor\),所以这样我们最多向下递归 \(O(\log n)\) 次,达到最优,所以我们知道点分治的时间复杂度:\(O(n\log n)\)

我们设 T_Div(u,siz) 表示当前子树的树根为 \(u\),大小为 \(siz\)。同时,我们用 \(sz_u\) 表示此时以 \(u\) 为根的子树的大小,\(mxs_u\) 表示以 \(u\) 的子节点为根的子树的大小的最大值,\(del_u\) 表示 \(u\) 点是否被删除。
首先我们先找到重心:

	int mxsiz=1e9,rt=-1;
	function<void(int,int)> dfs1=[&](int u,int fa) {
		sz[u]=1; mxs[u]=0;
		for(auto [v,d]:G[u]) if(v!=fa&&!del[v]) {
			dfs1(v,u);
			sz[u]+=sz[v];
			mxs[u]=max(mxs[u],sz[v]);
		}
		mxs[u]=max(mxs[u],siz-sz[u]);//因为 u 为根时,从当前的根走到 u 的路径也会被计算到子节点的子树的大小中
		if(mxs[u]<mxsiz) mxsiz=mxs[u],rt=u;
	};
	dfs1(u,0);

然后我们将重心设为树根,重新统计每个节点的子树的大小,同时统计此时的树根到节点的权值(如记录根节点到当前节点的距离)。

	for(auto [v,d]:G[rt]) if(!del[v]){
		function<void(int,int,int)> dfs2=[&](int u,int fa,int val) {
			sz[u]=1;
			v1.push_back(val),v2.push_back(val);
			for(auto [v,d]:G[u]) if(!del[v]&&v!=fa) {
				dfs2(v,u,val+d); 
				sz[u]+=sz[v];
			}
		};
		dfs2(v,rt,d);
	}

在做完这些事情之后,我们考虑统计答案。因为在子树中我们只需要计算通过该子树根节点的答案,经过简单的容斥我们可以知道相当于求经过所有点的答案减去只在子树中的答案。在 dfs2 时分别统计即可。

代码

综上,代码会长成以下这样(以洛谷 P3806 【模板】点分治 1为例):

#include<bits/stdc++.h>

#define int ll
#define ll long long
#define i128 __int128

#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define fst first 
#define scd second 
#define dbg puts("IAKIOI")

using namespace std;

int read() {
	int x=0,f=1; char c=getchar();
	for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1); 
	for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }

const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }

#define maxn 10050

int n,m;
vector<pair<int,int> > G[maxn];
int sz[maxn],mxs[maxn];
bool del[maxn];

int qu[maxn];
ll ans[maxn];
int vis[10000010];

void T_Div(int u,int siz){
	int mxsiz=1e9,rt=-1;
	function<void(int,int)> dfs1=[&](int u,int fa) {
		sz[u]=1; mxs[u]=0;
		for(auto [v,d]:G[u]) if(v!=fa&&!del[v]) {
			dfs1(v,u);
			sz[u]+=sz[v];
			mxs[u]=max(mxs[u],sz[v]);
		}
		mxs[u]=max(mxs[u],siz-sz[u]);
		if(mxs[u]<mxsiz) mxsiz=mxs[u],rt=u;
	};
	dfs1(u,0);
	auto calc =[&](vector<int> v,int flg) {
		For(i,1,m) {
			for(auto x:v) {
				if(qu[i]-x>=0) ans[i]+=flg*vis[qu[i]-x];
				if(x<=1e7) vis[x]++;
			}
			for(auto x:v) if(x<=1e7) vis[x]=0;
		}
	};
	vector<int> v1,v2;
	v1.clear(); v2.clear();
	v1.push_back(0);
	for(auto [v,d]:G[rt]) if(!del[v]){
		v2.clear();
		function<void(int,int,int)> dfs2=[&](int u,int fa,int val) {
			sz[u]=1;
			v1.push_back(val),v2.push_back(val);
			for(auto [v,d]:G[u]) if(!del[v]&&v!=fa) {
				dfs2(v,u,val+d); 
				sz[u]+=sz[v];
			}
		};
		dfs2(v,rt,d);
		calc(v2,-1);
	}
	calc(v1,1);
	del[rt]=1;
	for(auto [v,d]:G[u]) if(!del[v]) T_Div(v,sz[v]);
}

void work() {
	in2(n,m);
	For(i,1,n-1) {
		int u,v,d;
		in3(u,v,d);
		G[u].push_back({v,d});
		G[v].push_back({u,d});
	}For(i,1,m) in1(qu[i]);
	T_Div(1,n);
	For(i,1,m){ 
//		cout<<ans[i]<<'\n';
		if(ans[i]>=1) cout<<"AYE"<<'\n';
		else cout<<"NAY"<<'\n';
	}
}

signed main() {
//	freopen("data.in","r",stdin);
//	freopen("myans.out","w",stdout);
//	ios::sync_with_stdio(false); 
//	cin.tie(0); cout.tie(0);
	double stt=clock();
	int _=1;
//	_=read();
//	cin>>_;
	For(i,1,_) {
		work();
	}
	cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
	return 0;
}

练习

那么做完了上面这些,再拿两道题练练手吧!

P4178 Tree

给定一棵 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量。

Code
#include<bits/stdc++.h>

#define int ll
#define ll long long
#define i128 __int128

#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define fst first 
#define scd second 
#define dbg puts("IAKIOI")

using namespace std;

int read() {
	int x=0,f=1; char c=getchar();
	for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1); 
	for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }

const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }

#define maxn 40040

int n,k;
int sz[maxn];
int mxs[maxn];
bool del[maxn];
vector<pair<int,int> > G[maxn];
int ans=0;

struct BIT {
	int tr[maxn];
	int lowb(int x) { return x&-x; }
	void add(int x,int val) { for(;x<maxn;x+=lowb(x)) tr[x]+=val; }
	int qup(int x) { int res=0; for(;x;x-=lowb(x)) res+=tr[x]; return res;}
	int qulr(int l,int r) {	return qup(r+1)-qup(l); } 
};

void T_Div(int u,int siz) {
	int mnsiz=1e9,rt=-1;
	function<void(int,int)> dfs1=[&](int u,int fa) {
		mxs[u]=0; sz[u]=1;
		for(auto [v,w]:G[u]) if(v!=fa&&!del[v]) {
			dfs1(v,u);
			sz[u]+=sz[v];
			mxs[u]=max(mxs[u],sz[v]);
		}
		mxs[u]=max(mxs[u],siz-sz[u]);
		if(mxs[u]<mnsiz) mnsiz=mxs[u],rt=u;
	};
	dfs1(u,-1);
//	cout<<u<<' '<<siz<<' '<<rt<<'\n';
	vector<int> v1,v2;
	v1.clear(); v2.clear();
	v1.push_back(0);
	auto calc = [&](vector<int> v,int flg) {
		int res=0;
		int n=v.size();
		sort(v.begin(),v.end());
//		cout<<"    "<<flg<<' ' ;for(auto x:v) cout<<x<<' '; puts("");
		For(i,0,n-1) {
			int l=i,r=n-1;
			while(l<r) {
				int mid=l+r+1>>1;
				if(v[mid]+v[i]>k) r=mid-1;
				else l=mid;
			}
			res+=l-i;
		}
//		cout<<"    calc ans:"<<res<<'\n';
		return res*flg;
	};
	for(auto [v,w]:G[rt]) if(!del[v]) {
		v2.clear();
		function<void(int,int,int)> dfs2=[&](int u,int fa,int val) {
			sz[u]=1;
			v1.push_back(val); v2.push_back(val);
			for(auto [v,w]:G[u]) if(v!=fa&&!del[v]) {
				dfs2(v,u,val+w);
				sz[u]+=sz[v];
			}
		};
		dfs2(v,rt,w);
		ans+=calc(v2,-1);
	}
	ans+=calc(v1,1);
	del[rt]=1;
	for(auto [v,w]:G[rt]) if(!del[v]) T_Div(v,sz[v]);
}

void work() {
	in1(n);
	For(i,2,n) {
		int u,v,w;
		in3(u,v,w);
		G[u].push_back({v,w});
		G[v].push_back({u,w});
	}
	in1(k);
	T_Div(1,n);
	cout<<ans<<'\n';
}

signed main() {
//	freopen("data.in","r",stdin);
//	freopen("myans.out","w",stdout);
//	ios::sync_with_stdio(false); 
//	cin.tie(0); cout.tie(0);
	double stt=clock();
	int _=1;
//	_=read();
//	cin>>_;
	For(i,1,_) {
		work();
	}
	cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
	return 0;
}

P4149 [IOI 2011] Race

给一棵树,每条边有权。求一条简单路径,权值和等于 \(k\),且边的数量最小。

Hint

不会计算答案?试着在统计路径长度的同时记录该路径长度的最少边数。统计答案是更新边数就好了!

Code
#include<bits/stdc++.h>

#define int ll
#define ll long long
#define i128 __int128
#define PII pair<int,int>

#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define fst first 
#define scd second 
#define dbg puts("IAKIOI")

using namespace std;

int read() {
	int x=0,f=1; char c=getchar();
	for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1); 
	for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }

const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }

#define maxn 200050

int n,k;
vector<pair<int,int> > G[maxn];
int siz[maxn],mxs[maxn];
bool del[maxn];
int res[1000100],ans=1e18;
void T_Div(int u,int sz) {
	int mnsiz=1e18,rt=-1;
	function<void(int,int)> dfs1=[&](int u,int fa) {
		siz[u]=1,mxs[u]=0;
		for(auto [v,w]:G[u]) if(v!=fa&&!del[v]) {
			dfs1(v,u);
			siz[u]+=siz[v];
			mxs[u]=max(mxs[u],siz[v]);
		}
		mxs[u]=max(mxs[u],sz-siz[u]);
		if(mxs[u]<mnsiz) mnsiz=mxs[u],rt=u;
	};
	dfs1(u,-1);
	res[0]=0;
//	cout<<u<<' '<<sz<<' '<<rt<<'\n';
	auto calc =[&](vector<PII > v) {
//		for(auto [x,qwq]:v) cout<<"{"<<x<<' '<<qwq<<"} "; puts("");
		for(auto [x,qwq]:v) if(k>=x) ans=min(ans,qwq+res[k-x]);
		for(auto [x,qwq]:v) if(x<=k) res[x]=min(res[x],qwq);
	};
	vector<PII > v1,v2;
	v2.clear(); v1.clear();
	for(auto [v,w]:G[rt]) if(!del[v]) {
		v2.clear();
		function<void(int,int,int,int)> dfs2= [&](int u,int fa,int val,int dep) {
			v1.push_back({val,dep});
			v2.push_back({val,dep});
			siz[u]=1;
			for(auto [v,w]:G[u]) if(!del[v]&&v!=fa) {
				dfs2(v,u,val+w,dep+1);
				siz[u]+=siz[v];
			}
		};
		dfs2(v,rt,w,1);
		calc(v2);
	}
	for(auto [x,qwq]:v1) if(x<=k) res[x]=1e18;
	del[rt]=1;
	for(auto [v,w]:G[rt]) if(!del[v]) T_Div(v,siz[v]);
}

void work() {
	For(i,1,1000000) res[i]=1e18;
	in2(n,k);
	For(i,2,n) {
		int u,v,w;
		in3(u,v,w);
		u++,v++;
		G[u].push_back({v,w});
		G[v].push_back({u,w});
	}
	T_Div(1,n);
	if(ans==1e18) cout<<-1<<'\n';
	else cout<<ans<<'\n';
}

signed main() {
//	freopen("data.in","r",stdin);
//	freopen("myans.out","w",stdout);
//	ios::sync_with_stdio(false); 
//	cin.tie(0); cout.tie(0);
	double stt=clock();
	int _=1;
//	_=read();
//	cin>>_;
	For(i,1,_) {
		work();
	}
	cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
	return 0;
}
posted @ 2025-03-21 16:51  coding_goat_qwq  阅读(28)  评论(0)    收藏  举报