【学习笔记】树链剖分

具体思想

树链剖分是将树形结构转化为线性结构,配合线段树进而将O(n)复杂度优化至O(log n)的数据结构
如何将树形结构转化为线性结构呢?显然的,将树上的节点打上编号,映射到数组上,这样可以转化为线性结构
但是如果随意打上编号,这样的话在P3384中难以操作,我们需要通过一种方式使得数组容易维护
一种方式是按照方式记录dfn序,再通过dfn序映射到数组中,这样的话,会形成若干个链条,同一个链条的dfn序是连续的
我们可以在链条之间跳跃,进而优化复杂度
但是如果这样操作的话仍然会被卡到O(n),可以通过优先遍历重儿子,这样复杂度为O(log n)

转换线性方式

我们采取dfs的方式记录dfn序
然后将信息按照dfn序映射到数组中
但是预处理重儿子也需要一次dfs
所以我们需要两次dfs转换线性
然后构建线段树

操作方式

同上,在链条之间跳跃

代码(P3384)

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=1e5+10;int n,m,root,mod;
int nw[N];
vector<int> mp[N];
int fa[N],dep[N],son[N],siz[N];
void dfs1(int u){
	dep[u]=dep[fa[u]]+1;
	siz[u]=1;
	for(auto v:mp[u]){
		if(fa[u]==v) continue;
		fa[v]=u;
		dfs1(v);
		siz[u]+=siz[v];
		if(son[u]==0||siz[son[u]]<siz[v]) son[u]=v;
	}
}
int top[N],a[N],dfn[N],num;
void dfs2(int u,int tp){
	dfn[u]=++num;
	a[num]=nw[u];
	top[u]=tp;
	if(!son[u]) return ;
	dfs2(son[u],tp);
	for(auto v:mp[u]){
		if(v==fa[u]||v==son[u]) continue;
		dfs2(v,v);
	}
}
struct seg_line{
	struct node{
		int l,r,sum,tag;
		#define l(q) tree[q].l
		#define r(q) tree[q].r
		#define sum(q) tree[q].sum
		#define tag(q) tree[q].tag
	}tree[N<<2];
	void push_up(int q){
		sum(q)=sum(q<<1)+sum(q<<1|1);   sum(q)%=mod;
	}
	void add(int q,int d){
		int l=l(q),r=r(q);
		sum(q)+=(r-l+1)*d;   sum(q)%=mod;
		tag(q)+=d;   tag(q)%=mod;
	}
	void push_down(int q){
		add(q<<1,tag(q));
		add(q<<1|1,tag(q));
		tag(q)=0;
	}
	void build(int q,int l,int r){
		l(q)=l;r(q)=r;
		tag(q)=0;
		if(l==r){
			sum(q)=a[l];
			return ;
		}
		int mid=l+r>>1;
		build(q<<1,l,mid);
		build(q<<1|1,mid+1,r);
		push_up(q);
	}
	void update(int q,int L,int R,int d){
		int l=l(q),r=r(q);
		if(L<=l&&r<=R){
			add(q,d);
			return ;	
		}
		push_down(q);
		int mid=l+r>>1;
		if(L<=mid) update(q<<1,L,R,d);
		if(mid<R) update(q<<1|1,L,R,d);
		push_up(q);
	}
	int query(int q,int L,int R){
		int l=l(q),r=r(q);
		if(L<=l&&r<=R){
			return sum(q);
		}
		push_down(q);
		int mid=l+r>>1;
		int res=0;
		if(L<=mid) res+=query(q<<1,L,R);   res%=mod;
		if(mid<R) res+=query(q<<1|1,L,R);   res%=mod;
		push_up(q);
		return res;
	}
}lt;
void path_update(int x,int y,int z){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		lt.update(1,dfn[top[x]],dfn[x],z);
		x=fa[top[x]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	lt.update(1,dfn[y],dfn[x],z);	
}
int path_query(int x,int y){
	int res=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		res+=lt.query(1,dfn[top[x]],dfn[x]);   res%=mod;
		x=fa[top[x]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	res+=lt.query(1,dfn[y],dfn[x]);	   res%=mod;
	return res%mod;
}
void tree_update(int x,int z){
	lt.update(1,dfn[x],dfn[x]+siz[x]-1,z);
}
int tree_query(int x){
	return lt.query(1,dfn[x],dfn[x]+siz[x]-1)%mod;
}
signed main(){
	cin>>n>>m>>root>>mod;
	for(int i=1;i<=n;i++) cin>>nw[i];
	for(int i=1;i<n;i++){
		int a,b;cin>>a>>b;
		mp[a].push_back(b);
		mp[b].push_back(a);
	}
	dfs1(root);
	dfs2(root,root);
	lt.build(1,1,n);
	while(m--){
		int op,x,y,z;cin>>op>>x;
		if(op==1||op==2) cin>>y;
		if(op==1||op==3) cin>>z;
		
		if(op==1) path_update(x,y,z);
		else if(op==2) cout<<path_query(x,y)<<"\n";
		else if(op==3) tree_update(x,z);
		else cout<<tree_query(x)<<"\n";
	}
	return 0;
}
posted @ 2025-12-02 18:40  Ming3398  阅读(2)  评论(0)    收藏  举报