树上问题

树链剖分

多指重链剖分

可以解决的问题

\(1\). 修改 树上两点之间的路径上 所有点的值。

\(2\). 查询 树上两点之间的路径上 节点权值的 和/极值/其它在序列上可以用数据结构维护,便于合并的信息

概念

  • 重儿子:对于每一个非叶子节点,它的儿子中 儿子数量最多 的那一个儿子为该节点的重儿子
  • 轻儿子:对于每一个非叶子节点,它的儿子中非 重儿子 的剩下所有儿子即为轻儿子
  • 叶子节点没有重儿子也没有轻儿子(因为它没有儿子)
  • 重边:连接任意两个重儿子的边叫做重边
  • 轻边:剩下的即为轻边
  • 重链:相邻重边连起来的 连接一条重儿子 的链叫重链
  • 对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为 \(1\) 的链
  • 每一条重链以轻儿子为起点

实现

第一个 \(dfs\)

  • 标记每个点的深度 \(dep\)
  • 标记每个点的父亲 \(fa\)
  • 标记每个非叶子节点的子树大小(含它自己)
  • 标记每个非叶子节点的重儿子编号 \(hes\)
void dfs(int x){
	siz[x]=1;
	dep[x]=dep[fa[x]]+1;
	for(int i=0;i<q[x].size();i++){
		int y=q[x][i];
		if(y==fa[x]) continue;
		fa[y]=x;
		dfs(y);
		siz[x]+=siz[y];
		if(!hes[x]||siz[hes[x]]<siz[y]) hes[x]=y;
	}
}

第二个 \(dfs\)

  • 标记每个点的新编号
  • 赋值每个点的初始值到新编号上
  • 处理每个点所在链的顶端
  • 处理每条链
void dfs1(int x,int v){
	top[x]=v;
	if(hes[x]) dfs1(hes[x],v);
	for(int i=0;i<q[x].size();i++){
		int y=q[x][i];
		if(y==fa[x]||y==hes[x]) continue;
		dfs1(y,y);
	}
}

因为用的是 \(dfs\) 序遍历,所以每条链的编号都是连续的。

接下来可以在此基础上做一些操作来解决其他问题。

比如说下面两道题:

P3379 【模板】最近公共祖先(LCA)P3384 【模板】重链剖分/树链剖分

一个就是纯树链剖分,另一个还要用线段树进行区间查询。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int N=2e6+10;
vector<int> q[N];
struct Node{
	int l,r,val,siz,f;
}tr[N];
int p,cnt,a[N],b[N];
int dep[N],fa[N],hes[N],siz[N],top[N],idx[N];
void dfs(int x){
	dep[x]=dep[fa[x]]+1;
	siz[x]=1;
	for(int i=0;i<q[x].size();i++){
		int y=q[x][i];
		if(y==fa[x]) continue;
		fa[y]=x;
		dfs(y);
		siz[x]+=siz[y];
		if(!hes[x]||siz[hes[x]]<siz[y]) hes[x]=y;
	}
}
void update(int x){
	tr[x].val=(tr[x*2].val+tr[x*2+1].val+p)%p;
}
void build(int x,int l,int r){
	tr[x].l=l,tr[x].r=r;
	tr[x].siz=r-l+1;
	if(l==r){
		tr[x].val=a[l]%p;
		return;
	}
	int mid=(l+r)>>1;
	build(x*2,l,mid);
	build(x*2+1,mid+1,r);
	update(x);
}
void dfs1(int x,int v){
	idx[x]=++cnt;
	a[cnt]=b[x];
	top[x]=v;
	if(!hes[x]) return;
	dfs1(hes[x],v);
	for(int i=0;i<q[x].size();i++){
		int y=q[x][i];
		if(!idx[y]) dfs1(y,y);
	}
}
void pushdown(int x){
	if(!tr[x].f) return;
	tr[x*2].val=(tr[x*2].val+tr[x*2].siz*tr[x].f)%p;
	tr[x*2+1].val=(tr[x*2+1].val+tr[x*2+1].siz*tr[x].f)%p;
	tr[x*2].f=(tr[x*2].f+tr[x].f)%p;
	tr[x*2+1].f=(tr[x*2+1].f+tr[x].f)%p;
	tr[x].f=0;
}
void intervaladd(int x,int l,int r,int val){
	if(l<=tr[x].l&&tr[x].r<=r){
		tr[x].val+=tr[x].siz*val;
		tr[x].f+=val;
		return;
	}
	pushdown(x);
	int mid=(tr[x].l+tr[x].r)>>1;
	if(l<=mid) intervaladd(x*2,l,r,val);
	if(r>mid) intervaladd(x*2+1,l,r,val);
	update(x);
}
void treeadd(int x,int y,int val){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		intervaladd(1,idx[top[x]],idx[x],val);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	intervaladd(1,idx[x],idx[y],val);
}
int intervalsum(int x,int l,int r){
	int ans=0;
	if(l<=tr[x].l&&tr[x].r<=r) return tr[x].val;
	pushdown(x);
	int mid=(tr[x].l+tr[x].r)>>1;
	if(l<=mid) ans=(ans+intervalsum(x*2,l,r))%p;
	if(r>mid) ans=(ans+intervalsum(x*2+1,l,r))%p;
	return ans;
}
void treesum(int x,int y){
	int ans=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans=(ans+intervalsum(1,idx[top[x]],idx[x]))%p;
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	ans=(ans+intervalsum(1,idx[x],idx[y]))%p;
	cout<<ans<<endl;
}
int main(){
	int n,m,s,x,y,z,opt;
	cin>>n>>m>>s>>p;
	for(int i=1;i<=n;i++) cin>>b[i];
	for(int i=1;i<n;i++){
		cin>>x>>y;
		q[x].push_back(y);
		q[y].push_back(x);
	}
	dfs(s);
	dfs1(s,s);
	build(1,1,n);
	while(m--){
		cin>>opt;
		if(opt==1){
			cin>>x>>y>>z,z=z%p;
			treeadd(x,y,z);
		}
		else if(opt==2){
			cin>>x>>y;
			treesum(x,y);
		}
		else if(opt==3){
			cin>>x>>z;
			intervaladd(1,idx[x],idx[x]+siz[x]-1,z%p);
		}
		else if(opt==4){
			cin>>x;
			cout<<intervalsum(1,idx[x],idx[x]+siz[x]-1)<<endl;
		}
	}
	return 0;
}
posted @ 2025-10-02 21:31  筝小鱼  阅读(7)  评论(0)    收藏  举报