树上差分

树上差分与线性差分差不多,只不过是在树上进行差分,每次将两个点x和y的标志加1,将lca(x,y)和fa(lca(x,y))的标志减1,最后来一次深搜求和,就可以得到值了

下面给出几道例题

1.P3128 [USACO15DEC] Max Flow P

解析:

树上差分板子题,直接套班子,求完值后,求最大值即可

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e6+39+7;
int depth[N],f[N][21],n,head[N],tot,k,cnt[N],ans;
struct node{
	int u,v;
}edge[N<<1];
void add(int x,int y){
	edge[++tot].u=head[x];
	edge[tot].v=y;
	head[x]=tot;
}
void dfs(int u,int fa){
	depth[u]=depth[fa]+1;
	f[u][0]=fa;
	for(int i=1;(1<<i)<=depth[u];i++)f[u][i]=f[f[u][i-1]][i-1];
	for(int i=head[u];i;i=edge[i].u){
		if(edge[i].v==fa)continue;
		dfs(edge[i].v,u);
	}
}
int lca(int x,int y){
	if(depth[x]>depth[y])swap(x,y);
	for(int i=20;i>=0;i--)if(depth[y]-(1<<i)>=depth[x])y=f[y][i];
	if(x==y)return x;
	for(int i=20;i>=0;i--){
		if(f[x][i]==f[y][i])continue;
		x=f[x][i],y=f[y][i];
	}
	return f[x][0];
}
void dfss(int u,int fa){
	for(int i=head[u];i;i=edge[i].u){
		if(edge[i].v==fa)continue;
		dfss(edge[i].v,u);
		cnt[u]+=cnt[edge[i].v];
	}
}
int main(){
	cin>>n>>k;
	for(int i=1,a,b;i<n;i++)cin>>a>>b,add(a,b),add(b,a);
	dfs(1,0);
	for(int i=1,x,y,la;i<=k;i++){
		cin>>x>>y;
		la=lca(x,y);
		cnt[x]++;cnt[y]++;
		cnt[la]--;cnt[f[la][0]]--;
	}
	dfss(1,0);
	for(int i=1;i<=n;i++)ans=max(ans,cnt[i]);
	cout<<ans;
	return 0;
}

  

2.P3258 [JLOI2014] 松鼠的新家

解析:

树上差分板子题,每个点都会多算1次,所以,在深搜求完值之后,需要把每个数减1,依次输出即可

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e6+39+7;
int a[N],depth[N],f[N][21],n,head[N],tot,k,cnt[N],ans;
struct node{
	int u,v;
}edge[N<<1];
void add(int x,int y){
	edge[++tot].u=head[x];
	edge[tot].v=y;
	head[x]=tot;
}
void dfs(int u,int fa){
	depth[u]=depth[fa]+1;
	f[u][0]=fa;
	for(int i=1;(1<<i)<=depth[u];i++)f[u][i]=f[f[u][i-1]][i-1];
	for(int i=head[u];i;i=edge[i].u){
		if(edge[i].v==fa)continue;
		dfs(edge[i].v,u);
	}
}
int lca(int x,int y){
	if(depth[x]>depth[y])swap(x,y);
	for(int i=20;i>=0;i--)if(depth[y]-(1<<i)>=depth[x])y=f[y][i];
	if(x==y)return x;
	for(int i=20;i>=0;i--){
		if(f[x][i]==f[y][i])continue;
		x=f[x][i],y=f[y][i];
	}
	return f[x][0];
}
void dfss(int u,int fa){
	for(int i=head[u];i;i=edge[i].u){
		if(edge[i].v==fa)continue;
		dfss(edge[i].v,u);
		cnt[u]+=cnt[edge[i].v];
	}
}
int main(){
	cin>>n;
	for(int i=1;i<=n;i++)cin>>a[i];
	for(int i=1,a,b;i<n;i++){
		cin>>a>>b;
		add(a,b);
		add(b,a);
	}
	dfs(1,0);
	for(int i=1,LCA;i<n;i++){
		LCA=lca(a[i],a[i+1]);
		cnt[a[i]]++;cnt[a[i+1]]++;
		cnt[LCA]--;cnt[f[LCA][0]]--;
	}
	dfss(1,0);
	for(int i=2;i<=n;i++)cnt[a[i]]--;
	for(int i=1;i<=n;i++)cout<<cnt[i]<<'\n';
	return 0;
}

  

3.P2680 [NOIP2015 提高组] 运输计划

解析:

这道题使用了树上差分和记录路径的方法,预处理init数组,fa数组,dep数组等,进行求解,使用静态算法,存储每一次的问题,和两点之间的距离和lca,使用二分枚举时间,即可得到答案

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e6+39+7;
struct node{
	int x,y,lca,dis;
	bool operator <(const node &a)const{
		return dis<a.dis;
	}
}query[N];
struct edg{
	int to,next,w;
}e[N<<1];
int l,r,m,dep[N],fa[N][21],d[N],n,head[N],tot=-1,k,ans,init[N],cnt[N];
void add(int x,int y,int z){
	e[++tot]=(edg){y,head[x],z};
	head[x]=tot;
}
void dfs(int x,int father,int dis){
	dep[x]=dep[father]+1;
	fa[x][0]=father;init[x]=dis;
	for(int i=1;(1<<i)<=dep[x];i++)fa[x][i]=fa[fa[x][i-1]][i-1];
	for(int i=head[x];~i;i=e[i].next){
		int y=e[i].to;
		if(y==father)continue;
		d[y]=d[x]+e[i].w;
		dfs(y,x,e[i].w);
	}
}
int lca(int x,int y){
	if(dep[x]>dep[y])swap(x,y);
	for(int i=20;i>=0;i--)if(dep[y]-(1<<i)>=dep[x])y=fa[y][i];
	if(x==y)return x;
	for(int i=20;i>=0;i--){
		if(fa[x][i]==fa[y][i])continue;
		x=fa[x][i];y=fa[y][i];
	}
	return fa[x][0];
}
void dfss(int u,int father){
	for(int i=head[u];~i;i=e[i].next){
		int y=e[i].to;
		if(y==father)continue;
		dfss(y,u);
		cnt[u]+=cnt[y];
	}
}
bool ok(int x){
	int num=0,now=0;
	for(int i=1;i<=n;i++)cnt[i]=0;
	for(int i=1;i<=m;i++){
		if(query[i].dis<=x)continue;
		cnt[query[i].x]++;cnt[query[i].y]++;
		cnt[query[i].lca]-=2;
		num++;
	}
	dfss(1,0);
	for(int i=1;i<=n;i++)if(cnt[i]==num)now=max(now,init[i]);
	return query[m].dis-now<=x;
}
int main(){
	memset(head,-1,sizeof(head));
	cin>>n>>m;
	for(int i=1,x,y,z;i<n;i++){
		cin>>x>>y>>z;
		add(x,y,z);add(y,x,z);
	}
	dfs(1,0,0);
	for(int i=1,x,y;i<=m;i++){
		cin>>x>>y;
		query[i].lca=lca(x,y);
		query[i].dis=d[x]+d[y]-2*d[query[i].lca];
		r=max(r,query[i].dis);
		query[i].x=x;query[i].y=y;
	}
	sort(query+1,query+m+1);
	while(l<=r){
		int mid=(l+r)/2;
		if(ok(mid))r=mid-1;
		else l=mid+1;
	}
	cout<<l;
	return 0;
}

  

posted @ 2023-09-08 22:45  天雷小兔  阅读(24)  评论(0)    收藏  举报