点分治学习笔记

点分治学习笔记

基本思想

现在假设我们有一颗以\(root\)为根的树

考虑所有树中的路径。

我们发现:要么该路径经过\(root\)(包括端点在\(root\)上的路径),要么该路径在\(root\)的子树

那么,对于这些路径,我们只要先处理经过\(root\)的路径,对于剩下的路径,我们分别以它的孩子为根进行同样的操作即可(因为剩下的路径一定在以\(root\)为根的子树里)。这便是点分治

![img](file:///C:\Users\ybwowen\Documents\Tencent Files\1804820207\Image\Group$TDA\(38[Q4O25VR[%MPO[\)U.png)

放个图充实版面

例题:POJ 1741 Tree

题目大意:

给你一棵树,求点对\((u,v)\)的数量,使得\(u\)\(v\)之间的距离不超过\(k\)

Solution:

考虑点分治。

对于当前树的每个节点,我们记录深度\(dep_i\)(根节点深度为\(0\)),以及除根节点以外最远的祖先\(fa_i\)

我们现在只考虑经过根节点的路径数。

于是我们得到:

\[\sum[fa_i\ne fa_j \land dep_i+dep_j\le k] \]

其中:

\([]\)表示当表达式为真是取值为\(1\),否则为\(0\)

\(\land\) 表示逻辑与

该式等于:

\[\sum [dep_i+dep_j \le k]-\sum[fa_i=fa_j \land dep_i+dep_j \le k] \]

显然,式子的后边可以递归分治地解决(相当于以\(fa_i\)为根)

我们只要考虑前边怎么算了。

而前面只需要一个\(two\) \(pointers\) 就解决了:

我们把所有的\(dep\)排个序,让\(l\)指向最小的,让\(r\)指向最大的那个

从左到右遍历\(l\),如果\(dep_l+dep_r\le k\),则说明对于任意满足\(l < t <= r\)\(t\),都有\(dep_l+dep_t\le k\)

即此时所有点对\((l,t)\)都符合调价。于是我们只需将答案加上\(r-l\)即可

注意:为防止题目出一条链来卡点分治,使复杂度退化为\(O(n^2)\) ,我们需要将树的重心作为根节点,以保证复杂度为\(O(n log n)\)

Q:这做法复杂度看上去像\(O(n^2)\)的啊

A:不是的。如果我们以树的重心为根节点,我们至多下降\(logn\)层,而每一层的操作是\(O(n)\)的,所以总体的时间复杂度是\(O(n log n)\)

Code:

#include<bits/stdc++.h>
using namespace std;
const int maxn=4e4+5;
struct Node{
	int to;
	int next;
	int w;
}edge[maxn<<1];
int head[maxn],cnt;
inline void add(int x,int y,int w){
	edge[++cnt].next=head[x];
	edge[cnt].to=y;
	edge[cnt].w=w;
	head[x]=cnt;
}
int n;
int K;
int size[maxn];
int f[maxn];
int ans;
int dis[maxn];
vector<int>dep;
int v[maxn];
int tot;
int root;
inline void dfs1(int x,int fa){
	size[x]=1; f[x]=0;
	for(int i=head[x];i!=0;i=edge[i].next){
		int k=edge[i].to;
		if(k==fa) continue;
		if(v[k]) continue;
		dfs1(k,x);
		size[x]+=size[k];
		f[x]=max(f[x],size[k]); 
	}
	f[x]=max(f[x],tot-f[x]);
	if(f[x]<f[root]) root=x;
}
inline void dfs2(int x,int fa){
	dep.push_back(dis[x]);
	size[x]=1;
	for(int i=head[x];i!=0;i=edge[i].next){
		int k=edge[i].to;
		int w=edge[i].w;
		if(v[k]) continue;
		if(k==fa) continue;
		dis[k]=dis[x]+w;
		dfs2(k,x);
		size[x]+=size[k];
	}
}
inline int calc(int x,int tmp){
	int res=0;
	dep.clear();
	dis[x]=tmp;
	dfs2(x,0);
	sort(dep.begin(),dep.end());
	int l=0; int r=dep.size()-1;
	while(l<r){
		if(dep[l]+dep[r]<=K) res+=r-l,l++;
		else r--;
	}
	return res;
}
inline void dfs3(int x){
	ans+=calc(x,0); 
	v[x]=1;
	for(int i=head[x];i!=0;i=edge[i].next){
		int k=edge[i].to;
		int w=edge[i].w;
		if(v[k]) continue;
		ans-=calc(k,w);
		f[0]=tot=size[k];
		dfs1(k,0);
		dfs3(k);	
	}
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n-1;i++){
		int x,y,w;
		scanf("%d%d%d",&x,&y,&w);
		add(x,y,w);
		add(y,x,w);
	}
	scanf("%d",&K);
	f[0]=n; tot=n;
	dfs1(1,0);
	dfs3(root);
	printf("%d\n",ans);
	return 0;
}
posted @ 2019-06-10 21:53  ybwowen  阅读(192)  评论(2编辑  收藏  举报