点分治

先写静态点分治,带修改的还没学,咕咕咕


点分治是用于处理树上简单路径统计的一种算法,利用分治的思想,对每一课子树统计答案,最后累加(看起来就很暴力

所以我们要对其进行优化,将每一棵树按重心进行分割,再逐个处理子树,整体复杂度在 \(O(nlog_n)\) 左右

求重心

需要 \(dfs\) 一遍,对每一个节点开一个变量记录子树中最大的子树的 \(size\) ,让最大的 \(size\) 最小即可

点击查看代码
void get(int x,int fa)
{
	size[x]=1,wt[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		get(y,x);
		size[x]+=size[y],wt[x]=max(wt[x],size[y]);
	}
	wt[x]=max(wt[x],siz-size[x]);
	if(wt[root]>wt[x])root=x;
}

分治过程

分治的方法大概有两种,一是求整棵树对答案的贡献,再把子树中不合法的去了,类似容斥,二是一个一个子树合并来统计答案

相比而言代码量差不多,但第二个更泛用一些。

板子什么的我就随便一放,毕竟题和题的代码不是完全一样的。。。

方法一,摘自《聪聪可可》
void lsx(int x,int d,int fa)
{
	arr[++cnt]=d;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		lsx(y,d+val[i],x);
	}
}

int calc(int x,int d){
    cnt=0; lsx(x,d,0); int l=1,r=cnt,sum=0;
    sort(arr+1,arr+cnt+1);
    for(;;++l){
        while(r&&arr[l]+arr[r]>k) --r;
        if(r<l) break;
        sum+=r-l+1;
    }
    return sum;
}

void solve(int x){
    ans+=calc(x,0); vis[x]=1;
    for(int i=head[x];i;i=nxt[i])
    {
    	int y=to[i];
    	if(vis[y])continue;
    	ans-=calc(y,val[i]);
        root=0, siz=size[y],get(y,0);
        solve(root);
	}
}
方法二,摘自《Race》
void lsx(int x,int d,int fa,int deep)
{
	if(d>k)return ;
	arr[++cnt]=d;
	c[d]=min(c[d],deep);
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		lsx(y,d+val[i],x,deep+1);
	}
}

void solve(int x,int fa)
{
	vis[x]=1;
	b[0]=0,q[0]++,a[++sum]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(vis[y]||y==fa)continue;
		lsx(y,val[i],x,1);
		for(int j=1;j<=cnt;j++)
		{
			if(q[k-arr[j]]) ans=min(ans,(long long)c[arr[j]]+b[k-arr[j]]);
		}
		for(int j=1;j<=cnt;j++)
		{
			b[arr[j]]=min(b[arr[j]],c[arr[j]]);
			c[arr[j]]=0x7f7f7f;
		}
		for(int j=1;j<=cnt;j++) q[arr[j]]++,a[++sum]=arr[j];
		cnt=0;
	}
	for(int i=1;i<=sum;i++) b[a[i]]=0x7f7f7f,q[a[i]]--;
	sum=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(vis[y]||y==fa)continue;
		root=0,siz=size[y],get(y,0);
		solve(root,0);
	}	
}

题目

鉴于洛谷被封ip了,就不放链接了。。。

  • 1:《聪聪可可》

开桶记录一下模3后为0,1,2,的边的个数,直接算即可

点击查看代码
#include<bits/stdc++.h>
const int maxn=1e5+10; 
using namespace std;
int n,k,ans,root,size[maxn],siz,wt[maxn],arr[maxn],cnt;
int head[maxn],nxt[maxn<<1],to[maxn<<1],val[maxn<<1],tot;
int f[3];
bool vis[maxn];

void add(int x,int y,int z)
{
	to[++tot]=y;
	val[tot]=z;
	nxt[tot]=head[x];
	head[x]=tot;
}
void addm(int x,int y,int z)
{
	add(x,y,z);add(y,x,z);
}

void get(int x,int fa)
{
	size[x]=1,wt[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		get(y,x);
		size[x]+=size[y],wt[x]=max(wt[x],size[y]);
	}
	wt[x]=max(wt[x],siz-size[x]);
	if(wt[root]>wt[x])root=x;
}

void lsx(int x,int d,int fa)
{
	f[d%3]++;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		lsx(y,d+val[i],x);
	}
}

int calc(int x,int d)
{
	memset(f,0,sizeof f);
	lsx(x,d,0);
	return f[0]*(f[0]-1)/2+f[1]*f[2]; 
}

void solve(int x)
{
	ans+=calc(x,0);vis[x]=1;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(vis[y])continue;
		ans-=calc(y,val[i]);
		root=0,siz=size[y],get(y,0);
		solve(root);
	}
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i=1;i<n;i++)
	{
		int x,y,z;
		cin>>x>>y>>z;
		addm(x,y,z);
	}
	wt[root=0]=0x7f7f7f;
	siz=n;
	get(1,0);
	solve(root);
	int a=ans*2+n,b=n*n,p=__gcd(a,b);
	cout<<a/p<<"/"<<b/p<<'\n';

	return 0;
}
  • 2: 《Race》

记录一个每个边权是否出现,所用的最小边数,这里方法一不太适用,所用只能按子树合并,直接把已合并的子树和要合并

的子树的贡献统计即可,记得清空

点击查看代码
#include<bits/stdc++.h>
const int maxn=2e5+10; 
using namespace std;
int n,k,root,size[maxn],siz,wt[maxn],arr[maxn],cnt,b[1000005],c[1000005];
int head[maxn],nxt[maxn<<1],to[maxn<<1],val[maxn<<1],tot,sum,a[1000005],q[1000005];
long long ans;
bool vis[maxn];

void add(int x,int y,int z)
{
	to[++tot]=y;
	val[tot]=z;
	nxt[tot]=head[x];
	head[x]=tot;
}
void addm(int x,int y,int z)
{
	add(x,y,z);add(y,x,z);
}

void get(int x,int fa)
{
	size[x]=1,wt[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		get(y,x);
		size[x]+=size[y],wt[x]=max(wt[x],size[y]);
	}
	wt[x]=max(wt[x],siz-size[x]);
	if(wt[root]>wt[x])root=x;
}

void lsx(int x,int d,int fa,int deep)
{
	if(d>k)return ;
	arr[++cnt]=d;
	c[d]=min(c[d],deep);
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		lsx(y,d+val[i],x,deep+1);
	}
}

void solve(int x,int fa)
{
	vis[x]=1;
	b[0]=0,q[0]++,a[++sum]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(vis[y]||y==fa)continue;
		lsx(y,val[i],x,1);
		for(int j=1;j<=cnt;j++)
		{
//			cout<<arr[j]<<"! "<<q[k-arr[j]]<<endl;
			if(q[k-arr[j]])
			{
//				cout<<ans<<"!";
				ans=min(ans,(long long)c[arr[j]]+b[k-arr[j]]);
			}
		}
		for(int j=1;j<=cnt;j++)
		{
//			cout<<arr[j]<<"!"<<endl;
			b[arr[j]]=min(b[arr[j]],c[arr[j]]);
			c[arr[j]]=0x7f7f7f;
		}
		for(int j=1;j<=cnt;j++)q[arr[j]]++,a[++sum]=arr[j];
		cnt=0;
	}
	for(int i=1;i<=sum;i++) b[a[i]]=0x7f7f7f,q[a[i]]--;
	sum=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(vis[y]||y==fa)continue;
		root=0,siz=size[y],get(y,0);
//		cout<<root<<"!"<<'\n';
		solve(root,0);
	}	
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0),cout.tie(0);
	cin>>n>>k;
	for(int i=1;i<n;i++)
	{
		int x,y,z;
		cin>>x>>y>>z;
		x++,y++;
		addm(x,y,z);
		if(z==k)
		{
			cout<<1<<'\n';
			return 0;
		}
	}
	wt[root=0]=0x7f7f7f;
	memset(b,0x7f,sizeof b);
	memset(c,0x7f,sizeof c);
	siz=n;
	ans=1e17;
	get(1,0);
//	cout<<root<<"!"<<'\n'; 
	solve(root,0);
	cout<<(ans>=n?-1:ans)<<'\n';

	return 0;
}
/*
4 3
0 1 1
1 2 2
2 3 4
*/
  • 3:《tree》

对答案贡献的只有过根的路径,把到子树根的距离都统计,双指针统计即可

点击查看代码
#include<bits/stdc++.h>
const int maxn=4e4+10;
using namespace std;
int n,k,ans,root,size[maxn],siz,wt[maxn],arr[10001],cnt;
int head[maxn],nxt[maxn<<1],to[maxn<<1],val[maxn<<1],tot;
bool vis[maxn];

void add(int x,int y,int z)
{
	to[++tot]=y;
	val[tot]=z;
	nxt[tot]=head[x];
	head[x]=tot;
}

void get(int x,int fa)
{
	size[x]=1;wt[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		get(y,x);
		size[x]+=size[y],wt[x]=max(wt[x],size[y]);
	}
	wt[x]=max(wt[x],siz-size[x]);
	if(wt[root>wt[x]])root=x;
}

void dfs(int x,int d,int fa)
{
	arr[++cnt]=d;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(y==fa||vis[y])continue;
		dfs(y,d+val[i],x);
	}
}

int calc(int x,int d){
    cnt=0; dfs(x,d,0); int l=1,r=cnt,sum=0;
    sort(arr+1,arr+cnt+1);
    for(;;++l){
        while(r&&arr[l]+arr[r]>k) --r;
        if(r<l) break;
        sum+=r-l+1;
    }
    return sum;
}

void solve(int x){
    ans+=calc(x,0); vis[x]=1;
    for(int i=head[x];i;i=nxt[i])
    {
    	int y=to[i];
    	if(vis[y])continue;
    	ans-=calc(y,val[i]);
        root=0, siz=size[y],get(y,0);
        solve(root);
	}
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0),cout.tie(0);
	cin>>n;
	for(int i=1;i<n;i++)
	{
		int x,y,z;
		cin>>x>>y>>z;
		add(x,y,z);add(y,x,z);
	}
	cin>>k;
	wt[root=0]=0x7f7f7f;
	siz=n;
	get(1,0);
	solve(root);
	cout<<ans-n;
	return 0;
}

/*
7
1 6 13 
6 3 9 
3 5 7 
4 1 3 
2 4 20 
4 7 2 
10
*/


posted @ 2024-09-04 21:31  _君の名は  阅读(26)  评论(0)    收藏  举报