谁都能懂的树链剖分

树链剖分

警钟敲烂

  • 赋值赋反
  • 线段树L,R和l,r分不清(会导致p疯狂扩张)
  • p<<1+1和p<<1|1是两个东西!!!!!

思想

把一颗树划分成若干条重链,利用重链性质在树上维护树上路径的信息。

概念

定义 重子节点 表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。

定义 轻子节点 表示剩余的所有子结点。

从这个结点到重子节点的边为 重边

到其他轻子节点的边为 轻边

若干条首尾衔接的重边构成 重链

性质

  • 重边优先遍历后重链上节点的dfs序是连续的;

  • 某个节点子树上节点的dfs序也是连续的;

  • 树上每个节点都属于且仅属于一条重链

  • 每条重链的开头都是轻儿子;

  • 叶子的个数和重链的条数相等;

应用

  • 求两点LCA(目前最快,码长略长...)
  • 在树上实现线段树的各种操作(把segment tree在序列上干的那些破事搬到树上路径和子树上)
  • 待续...

实现

树链剖分通过两次dfs完成对树的解剖(乙树解剖

第一遍dfs

求出\(fa(x),size(x),hson(x),dep(x)\).

void dfsb(int u,int deep,int lst){
	f[u]=u;
	dep[u]=deep;
	siz[u]=1;
	for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(v==lst)continue;
		dfsb(v,deep+1,u);
		siz[u]=siz[u]+siz[v];
		f[v]=u;
	
		if(siz[v]>siz[hson[u]]){
			hson[u]=v;
		} 
	}
}

第二遍dfs

求出 $$top(x),dfn(x),rnk(x)$$ 值得注意的是,这次dfs就是所谓的“重边优先搜索”。

void dfsp(int u,int topf){
	dfn[u]=++cnt;
	w[cnt]=a[u];
	top[u]=topf;
	if(!hson[u])return ;
	dfsp(hson[u],topf);
	for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(v==f[u]||v==hson[u])continue;
		dfsp(v,v);
	}
}

路径

利用类似树上倍增的写法,不断使所在链较短的节点向长链跳,直到两点在同一条链上,深度较低的点就是LCA。

跳的过程中把自己和链头之间的一段区间加,最后把两点之间的部分区间加即可。

void crange(int u,int v,int k){
	
	while(top[u]!=top[v]){
		
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		
		update(1,dfn[top[u]],dfn[u],1,n,k);
		u=f[top[u]]; 
	}
	if(dep[u]>dep[v])swap(u,v);
	update(1,dfn[u],dfn[v],1,n,k);
} 

子树

更加简单,只需要把$$dfn(u)$$到$$dfn(u)+size(u)-1$$这一段处理就可以了www

void cson(int u,int k){
	update(1,dfn[u],dfn[u]+siz[u]-1,1,n,k);
}

把™合起来

从开始写到调好花了三天半...中间出了一堆bug。

#include<iostream>
using namespace std;
int n,m,r;
int head[200005],nxt[200005],to[200005],tot;
int a[100005],f[100005],dep[100005],siz[100005],hson[100005];
int top[100005],dfn[100005],cnt,w[100005];
int mod;
int d[400005],lazy[400005],rank[100005];
void add(int u,int v){
	nxt[++tot]=head[u],to[head[u]=tot]=v;
}

void build(int p,int l,int r){
	if(l==r){
		d[p]=w[l]%mod;
		return ;
	}
	int mid=(l+r)>>1;
	build(p<<1,l,mid);
	build(p<<1|1,mid+1,r);
	d[p]=(d[p<<1]+d[p<<1|1])%mod;
}

int getsum(int p,int L,int R,int l,int r){、
	if(l>=L&&r<=R){
		return d[p];
	}
	int mid=(l+r)>>1;
	if(lazy[p]&&l!=r){
		d[p<<1]+=lazy[p]*(mid-l+1);
		d[p<<1|1]+=lazy[p]*(r-mid);
		lazy[p<<1]+=lazy[p];
		lazy[p<<1|1]+=lazy[p];
		lazy[p]=0;
		d[p<<1]%=mod;
		d[p<<1|1]%=mod;
	} 
	int sum=0;
	if(mid>=L){
		sum=(sum+getsum(p<<1,L,R,l,mid))%mod;
	}
	if(mid<R){
		sum=(sum+getsum(p<<1|1,L,R,mid+1,r))%mod;
	}
	return sum;
}

void update(int p,int L,int R,int l,int r,long long w){
	if(l>=L&&r<=R){
		lazy[p]+=w;
		d[p]=(d[p]+w*(r-l+1))%mod;
		return ;
	}
	int mid=l+r>>1;
	if(lazy[p]&&l!=r){
		d[p<<1]=(d[p<<1]+lazy[p]*(mid-l+1))%mod;
		d[p<<1|1]+=lazy[p]*(r-mid);
		lazy[p<<1]+=lazy[p];
		lazy[p<<1|1]+=lazy[p];
		lazy[p]=0;
		d[p<<1]%=mod;
		d[p<<1|1]%=mod;
	} 
	if(mid>=L)update(p<<1,L,R,l,mid,w);
	if(mid<R)update(p<<1|1,L,R,mid+1,r,w);
	d[p]=(d[p<<1|1]+d[p<<1])%mod;
	
}
int t;
void dfsb(int u,int deep,int lst){
	f[u]=u;
	dep[u]=deep;
	siz[u]=1;
	for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(v==lst)continue;
		dfsb(v,deep+1,u);
		siz[u]=siz[u]+siz[v];
		f[v]=u;
	
		if(siz[v]>siz[hson[u]]){
			hson[u]=v;
		} 
	}
}

void dfsp(int u,int topf){
	dfn[u]=++cnt;
	w[cnt]=a[u];
	top[u]=topf;
	if(!hson[u])return ;
	dfsp(hson[u],topf);
	for(int i=head[u];i;i=nxt[i]){
		int v=to[i];
		if(v==f[u]||v==hson[u])continue;
		dfsp(v,v);
	}
}

void crange(int u,int v,int k){
	
	while(top[u]!=top[v]){
		
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		
		update(1,dfn[top[u]],dfn[u],1,n,k);
		u=f[top[u]]; 
	}
	if(dep[u]>dep[v])swap(u,v);
	update(1,dfn[u],dfn[v],1,n,k);
} 

void cson(int u,int k){
	update(1,dfn[u],dfn[u]+siz[u]-1,1,n,k);
}

int grange(int u,int v){
	
	int ans=0;
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]])swap(u,v);
		ans=(ans+getsum(1,dfn[top[u]],dfn[u],1,n))%mod; 
		
		u=f[top[u]];
	}
	if(dep[u]>dep[v])swap(u,v);
	ans=(ans+getsum(1,dfn[u],dfn[v],1,n))%mod;
	return ans;
}

int gson(int u){
	return getsum(1,dfn[u],dfn[u]+siz[u]-1,1,n);
}

int main(){
	scanf("%d%d%d%d",&n,&m,&r,&mod);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	for(int i=1;i<=n-1;i++){
		int u,v;
		cin>>u>>v;
		add(u,v);
		add(v,u);
	} 
	dfsb(r,1,-1);
	dfsp(r,r); 
	
	build(1,1,n);
	for(int i=1;i<=m;i++){
		int k;
		scanf("%d",&k);
		if(k==1){
			int x,y,z;
			scanf("%d%d%d",&x,&y,&z);
			crange(x,y,z);
		}
		else if(k==2){
			int x,y;
			scanf("%d%d",&x,&y);
			cout<<grange(x,y)<<endl;
		}
		else if(k==3){
			int x,z;
			scanf("%d%d",&x,&z);
			cson(x,z);
		}
		else if(k==4){
			int x;
			scanf("%d",&x);
			cout<<gson(x)<<endl;
		}
	}
}
posted @ 2022-10-25 23:14  狐适之  阅读(50)  评论(1)    收藏  举报