点分治学习笔记

点分治

采用分治思想。对树上路径问题进行查询时,把路径分成两部分,一部分是经过根节点的路径,一部分是不经过根节点的路径。

而在处理不经过根节点的路径时,可以才有分治思想,递归到左右子树进行求解。

这样复杂度是 \(O(n^2)\) 的,但是若我们每次选取的根节点都是要求解的子树的重心,则复杂度可以优化到 \(O(nlogn)\)

实现思路

需要实现以下函数:

  • solve:分治过程,不断取重心分治。
  • getzx:求重心。
  • calc:计算经过当前点的路径对答案的贡献。

例题

洛谷 P3806 【模板】点分治1

传送门
开一个数组记录当前的路径长度有哪些。
先读入所有询问,然后到达一个点就更新一下答案。
总复杂度 \(O(mnlogn)\)

AC代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
const int maxn=10010,bign=10001000;

int n,m,tmp[bign],judge[bign];
int sz[maxn],vis[maxn];
int head[maxn],q[maxn];
int size,maxp[maxn];
int tot,rt,dis[maxn];
int qqq[maxn],ynn[maxn],cnt,p[maxn];

struct node{
	int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
	cnt++;
	e[cnt].v=v;
	e[cnt].w=w;
	e[cnt].next=p[u];
	p[u]=cnt;
}

void getzx(int u,int fa){
	maxp[u]=0;
	sz[u]=1;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]||v==fa) continue;
		getzx(v,u);
		sz[u]+=sz[v];
		maxp[u]=max(maxp[u],sz[v]); 
	}
	maxp[u]=max(maxp[u],tot-sz[u]);
	if(maxp[u]<maxp[rt]) rt=u;
}
inline void getdis(int u,int fa){
	tmp[++tmp[0]]=dis[u];
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]||v==fa) continue;
		dis[v]=dis[u]+e[i].w;
		getdis(v,u);
	}
}
inline void calc(int u){
	int ppp=0;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		tmp[0]=0;
		dis[v]=e[i].w;
		getdis(v,u);
		for(int j=1;j<=tmp[0];j++){
			for(int k=1;k<=m;k++){
				if(q[k]>=tmp[j]) ynn[k]|=judge[q[k]-tmp[j]];
			}
		}
		for(int j=1;j<=tmp[0];j++){
			if(tmp[j]>=bign) continue;
			qqq[++ppp]=tmp[j];
			judge[tmp[j]]=1;
		}
	}
	for(int i=1;i<=ppp;i++) judge[qqq[i]]=0;
}
inline void solve(int u){
	vis[u]=judge[0]=1; calc(u);
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		tot=sz[v];
		maxp[rt=0]=sz[v];
		getzx(v,0);
		solve(rt);
	}
}
int main(){
	ios::sync_with_stdio(false);
	memset(p,-1,sizeof(p));
	cin>>n>>m;
	for(int i=1;i<n;i++){
		int u,v,w;
		cin>>u>>v>>w;
		insert(u,v,w);
		insert(v,u,w); 
	}
	for(int i=1;i<=m;i++) cin>>q[i];
	maxp[rt=0]=n;
	tot=n;
	getzx(1,0);
	solve(rt);
	for(int i=1;i<=m;i++){
		if(ynn[i]) cout<<"AYE"<<endl;
		else cout<<"NAY"<<endl;
	}
	return 0;
}

CF161D Distance in Tree

传送门
数组记录的内容变成当前长度为x的路径的数量。
其他和板子基本相同。

AC代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
    x=0;register char c=getchar();register bool f=0;
    while(!isdigit(c))f^=c=='-',c=getchar();
    while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
    if(f)x=-x;
}
template<class T>inline void print(T x)
{
    if(x<0)putchar('-'),x=-x;
    if(x>9)print(x/10);
    putchar('0'+x%10);
}

const int maxn=50005;
long long ans;
int n,k,judge[maxn],tmp[maxn],siz[maxn],vis[maxn],p[maxn],cnt,tot,maxp[maxn],rt,dis[maxn],q[maxn];
struct node{
	int v,next;
}e[maxn*2];
void insert(int u,int v){
	cnt++;
	e[cnt].v=v;
	e[cnt].next=p[u];
	p[u]=cnt;
}
void getzx(int u,int fa){
	maxp[u]=0;
	siz[u]=1;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]||v==fa) continue;
		getzx(v,u);
		siz[u]+=siz[v];
		maxp[u]=max(maxp[u],siz[v]);
	}
	maxp[u]=max(maxp[u],tot-siz[u]);
	if(maxp[u]<=maxp[rt]) rt=u;
}
void getdis(int u,int fa){
	tmp[++tmp[0]]=dis[u];
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(v==fa||vis[v]) continue;
		dis[v]=dis[u]+1;
		getdis(v,u); 
	}
}
void calc(int u){
	int cntq=0;
	dis[u]=0;
	judge[0]=1;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		tmp[0]=0;
		dis[v]=1;
		getdis(v,u);
		for(int j=1;j<=tmp[0];j++) if(k>=tmp[j]) ans+=judge[k-tmp[j]];
		for(int j=1;j<=tmp[0];j++) judge[tmp[j]]++,q[++cntq]=tmp[j];
	}
	for(int i=1;i<=cntq;i++) judge[q[i]]--;
}
void solve(int u){
	vis[u]=1;calc(u);
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		maxp[rt=0]=tot=siz[v];
		getzx(v,-1);
		solve(rt);
	}
}
int main(){
	memset(p,-1,sizeof(p));
	read(n);read(k);
	for(int i=1;i<n;i++){
		int u,v;
		read(u);read(v);
		insert(u,v);
		insert(v,u);
	}
	maxp[rt=0]=tot=n;
	getzx(1,-1);
	solve(rt);
	print(ans);
	return 0;
}

洛谷 P4149 [IOI2011]Race

开一个数组记录当前路径权值;
开一个数组记录当前权值和为x的路径的最少的边数。
两个数组同时求、清空。

AC代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
    x=0;register char c=getchar();register bool f=0;
    while(!isdigit(c))f^=c=='-',c=getchar();
    while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
    if(f)x=-x;
}
template<class T>inline void print(T x)
{
    if(x<0)putchar('-'),x=-x;
    if(x>9)print(x/10);
    putchar('0'+x%10);
}

const int maxn=2e5+5;
const int maxm=1e6+5;
int n,k,judge[maxm],tmp[maxn],siz[maxn],vis[maxn],p[maxn],cnt,tot,maxp[maxn],rt,dis[maxn],q[maxn],dep[maxn],tmp2[maxn],anss[maxm],ans=0x3f3f3f3f;
struct node{
	int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
	cnt++;
	e[cnt].v=v;
	e[cnt].w=w;
	e[cnt].next=p[u];
	p[u]=cnt;
}
void getzx(int u,int fa){
	maxp[u]=0;
	siz[u]=1;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]||v==fa) continue;
		getzx(v,u);
		siz[u]+=siz[v];
		maxp[u]=max(maxp[u],siz[v]);
	}
	maxp[u]=max(maxp[u],tot-siz[u]);
	if(maxp[u]<=maxp[rt]) rt=u;
}
void getdis(int u,int fa){
	tmp[++tmp[0]]=dis[u];tmp2[tmp[0]]=dep[u];
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(v==fa||vis[v]) continue;
		dis[v]=dis[u]+e[i].w;
		dep[v]=dep[u]+1;
		getdis(v,u); 
	}
}
void calc(int u){
	int cntq=0;
	dis[u]=0;
	anss[0]=0;
	judge[0]=1;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		tmp[0]=0;
		dep[v]=1;
		dis[v]=e[i].w;
		getdis(v,u);
		for(int j=1;j<=tmp[0];j++) if(k>=tmp[j]&&judge[k-tmp[j]]) ans=min(ans,anss[k-tmp[j]]+tmp2[j]);
		for(int j=1;j<=tmp[0];j++){
			if(tmp[j]>k) continue;
			judge[tmp[j]]=1;
			anss[tmp[j]]=min(anss[tmp[j]],tmp2[j]);
			q[++cntq]=tmp[j];
		}
	}
	for(int i=1;i<=cntq;i++) judge[q[i]]=0,anss[q[i]]=0x3f3f3f3f;
}
void solve(int u){
	vis[u]=1;calc(u);
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		maxp[rt=0]=tot=siz[v];
		getzx(v,-1);
		solve(rt);
	}
}
int main(){
	memset(p,-1,sizeof(p));
	memset(anss,0x3f,sizeof(anss));
	read(n);read(k);
	for(int i=1;i<n;i++){
		int u,v,w;
		read(u);read(v);read(w);
		u++;v++;
		insert(u,v,w);
		insert(v,u,w);
	}
	maxp[rt=0]=tot=n;
	getzx(1,-1);
	solve(rt);
	print((ans==0x3f3f3f3f?-1:ans));
	return 0;
}

洛谷 P4178 Tree

传送门
一种做法是充斥一下,但是感觉好麻烦而且常数很大,所以我采用树状数组。
加路径相当于单点修改,更新答案时查询前缀和。
注意先更新答案,后更新存路径数量的桶。

AC代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
    x=0;register char c=getchar();register bool f=0;
    while(!isdigit(c))f^=c=='-',c=getchar();
    while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
    if(f)x=-x;
}
template<class T>inline void print(T x)
{
    if(x<0)putchar('-'),x=-x;
    if(x>9)print(x/10);
    putchar('0'+x%10);
}
const int maxn=4e4+5;
const int maxx=2e4+5;
int n,cnt,num,p[maxn],d[maxn],siz[maxn],maxp[maxn],tot,k,dis[maxn],tmp[maxn],ans,vis[maxn],rt;
struct node{
	int v,next,w;
}e[maxn*2];
void insert(int u,int v,int w){
	cnt++;
	e[cnt].v=v;
	e[cnt].w=w;
	e[cnt].next=p[u];
	p[u]=cnt;
}
inline int lowbit(int x){
	return x&-x;
}
void update(int x,int v){
	for(int i=x;i<maxx;i+=lowbit(i)) d[i]+=v; 
}
int query(int x){
	int res=0;
	for(int i=x;i>=1;i-=lowbit(i)) res+=d[i];
	return res;
}
void getzx(int u,int fa){
	siz[u]=1;
	maxp[u]=0;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(v==fa||vis[v]) continue;
		getzx(v,u);
		siz[u]+=siz[v];
		maxp[u]=max(maxp[u],siz[v]);
	}
	maxp[u]=max(maxp[u],tot-siz[u]);
	if(maxp[u]<maxp[rt]) rt=u;
}
void getdis(int u,int fa){
	tmp[++tmp[0]]=dis[u];
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(v==fa||vis[v]) continue;
		dis[v]=dis[u]+e[i].w;
		getdis(v,u);
	}
}
void cal(int u){
	int num=0,q[maxn];
	dis[u]=0;
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		dis[v]=e[i].w;
		tmp[0]=0;
		getdis(v,u);
		for(int j=1;j<=tmp[0];j++){
			if(k>tmp[j]) ans+=query(k-tmp[j]);
			if(k>=tmp[j]) ans++;
		}
		for(int j=1;j<=tmp[0];j++){
			if(k>=tmp[j]){
				q[++num]=tmp[j];
				update(tmp[j],1);
			}
		}
	}
	for(int i=1;i<=num;i++) update(q[i],-1);
}
void solve(int u){
	vis[u]=1;
	cal(u);
	for(int i=p[u];i!=-1;i=e[i].next){
		int v=e[i].v;
		if(vis[v]) continue;
		tot=maxp[rt=0]=siz[v];
		getzx(v,u);
		solve(v);
	}
}
int main(){
	memset(p,-1,sizeof(p));
	read(n);
	for(int i=1;i<n;i++){
		int u,v,w;
		read(u);read(v);read(w);
		insert(u,v,w);
		insert(v,u,w);
	}
	read(k);
	maxp[rt=0]=tot=n;
	getzx(1,-1);
	solve(1);
	print(ans);
	return 0;
}
posted @ 2021-11-01 17:02  尹昱钦  阅读(40)  评论(0编辑  收藏  举报