树链剖分(重链剖分)结合线段树,维护子树、路径等信息

https://www.luogu.com.cn/problem/P3384

//author:kzssCCC

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

class HLD{
public:
	int n,rt;
	vector<vector<int>> adj;

	//son表示重儿子
	//top表示该点对应重链的头
	//ref是dfn的逆映射

	vector<int> par,depth,sz,son,dfn,top,ref;
	int timer = 1;

	//第一遍dfs,统计par,depth,sz,son
	void dfs1(int u){
		sz[u] = 1;
		int mx = 0;
		int pos = 0;

		for (auto& v:adj[u]){
			if (v==par[u]) continue;

			par[v] = u;
			depth[v] = depth[u]+1;
			dfs1(v);
			sz[u] += sz[v];

			//子树最大的儿子为重儿子
			if (sz[v]>mx){
				mx = sz[v];
				pos = v;
			}
		}

		son[u] = pos;
	}

	//第二遍dfs,统计top,dfn,ref
	//head为重链的头
	void dfs2(int u,int head){
		dfn[u] = timer++;
		ref[dfn[u]] = u;
		top[u] = head;

		if (son[u]==0) return;

		//先访问重儿子
		dfs2(son[u],head);

		//访问其他子节点,此时为新的重链,注意链头更新
		for (auto& v:adj[u]){
			if (v==par[u] || v==son[u]) continue;

			dfs2(v,v);
		}
	}

	HLD(int _n,int _rt,vector<vector<int>>& _adj){
		n = _n;
		rt = _rt;
		adj = _adj;
		par = sz = depth = dfn = top = ref = vector<int>(n+1);
		son = vector<int>(n+1,0);

		par[rt] = -1;
		depth[rt] = 0;
		dfs1(rt);
		dfs2(rt,rt);
	}

	//计算LCA,不断向上跳跃即可,最后u、v属于同一重链时,dfn小的(或深度小的)为LCA
	int getlca(int u,int v){
		while (top[u]!=top[v]){
			if (depth[top[u]]<depth[top[v]]){
				swap(u,v);
			}

			u = par[top[u]];
		}

		if (dfn[u]<dfn[v]){
			return u;
		}
		else{
			return v;
		}
	}
};

int MOD;

class node{
public:
	ll sum,lazy;
};

class segmentTree{
public:
	int n;
	vector<node> seg;

	segmentTree(int _n){
		n = _n;
		seg = vector<node>(4*n+1);
	}

	node merge(node p1,node p2){
		node temp;
		temp.sum = (p1.sum+p2.sum)%MOD;
		return temp;
	}

	void build(vector<ll>& a){
		build(1,1,n,a);
	}	

	void build(int rt,int l,int r,vector<ll>& a){
		if (l==r){
			seg[rt].sum = a[l];

			return;
		}	

		int mid = l+r >> 1;
		build(rt<<1,l,mid,a);
		build(rt<<1|1,mid+1,r,a);

		seg[rt] = merge(seg[rt<<1],seg[rt<<1|1]);			
	}

	void push_down(int rt,int l,int r){
		if (seg[rt].lazy==0) return;
		int mid = l+r >> 1;

		seg[rt<<1].sum = (seg[rt<<1].sum+seg[rt].lazy*(mid-l+1)%MOD)%MOD;
		seg[rt<<1].lazy = (seg[rt<<1].lazy+seg[rt].lazy)%MOD;

		seg[rt<<1|1].sum = (seg[rt<<1|1].sum+seg[rt].lazy*(r-mid)%MOD)%MOD;
		seg[rt<<1|1].lazy = (seg[rt<<1|1].lazy+seg[rt].lazy)%MOD;

		seg[rt].lazy = 0;
	}

	void update(int pos,ll val){
		update(1,1,n,pos,val);
	}

	void update(int rt,int l,int r,int pos,ll val){
		if (l==r){
			//

			return;
		}		

		int mid = l+r >> 1;
		push_down(rt,l,r);

		if (pos<=mid){
			update(rt<<1,l,mid,pos,val);
		}
		else{
			update(rt<<1|1,mid+1,r,pos,val);
		}

		seg[rt] = merge(seg[rt<<1],seg[rt<<1|1]);
	}

	void update_range(int x,int y,ll val){
		update_range(1,1,n,x,y,val);
	}

	void update_range(int rt,int l,int r,int x,int y,ll val){
		if (r<x || l>y){
			return;
		}

		if (x<=l && y>=r){
			seg[rt].sum = (seg[rt].sum+val*(r-l+1)%MOD)%MOD;
			seg[rt].lazy = (seg[rt].lazy+val)%MOD;

			return;
		}

		int mid = l+r >> 1;
		push_down(rt,l,r);

		update_range(rt<<1,l,mid,x,y,val);
		update_range(rt<<1|1,mid+1,r,x,y,val);

		seg[rt] = merge(seg[rt<<1],seg[rt<<1|1]);
	}


	node query(int pos){
		return query(1,1,n,pos);
	}

	node query(int rt,int l,int r,int pos){
		if (l==r){
			return seg[rt];
		}		

		int mid = l+r >> 1;
		push_down(rt,l,r);

		if (pos<=mid){
			return query(rt<<1,l,mid,pos);
		}
		else{
			return query(rt<<1|1,mid+1,r,pos);
		}
	}


	node query_range(int l,int r){
		return query_range(1,1,n,l,r);
	}

	node query_range(int rt,int l,int r,int x,int y){
		if (r<x || l>y){
			return node();
		}

		if (x<=l && y>=r){
			return seg[rt];	
		}

		int mid = l+r >> 1;
		push_down(rt,l,r);

		return merge(query_range(rt<<1,l,mid,x,y),query_range(rt<<1|1,mid+1,r,x,y));
	}
};

void solve(){
	int n,q,rt;
	cin >> n >> q >> rt >> MOD;

	vector<ll> a(n+1);
	for (int i=1;i<=n;i++){
		cin >> a[i];
		a[i]%=MOD;
	}	

	vector<vector<int>> adj(n+1);
	for (int i=0;i<n-1;i++){
		int u,v;
		cin >> u >> v;

		adj[u].push_back(v);
		adj[v].push_back(u);
	}

	HLD hld(n,rt,adj);
	segmentTree sg(n);

	vector<ll> val(n+1);
	for (int i=1;i<=n;i++){
		val[i] = a[hld.ref[i]];
	}
	sg.build(val);

	while (q--){
		int op;
		cin >> op;

		//操作1:将树从 x 到 y 结点最短路径上所有节点的值都加上 z
		if (op==1){
			int x,y,z;
			cin >> x >> y >> z;

			while (hld.top[x]!=hld.top[y]){
				//链头深度更大的节点向上跳跃
				if (hld.depth[hld.top[x]]<hld.depth[hld.top[y]]){
					swap(x,y);
				}

				//更新这条重链上的信息
				int tx = hld.top[x];
				sg.update_range(hld.dfn[tx],hld.dfn[x],z);

				x = hld.par[tx];
			}		

			//属于同一重链,更新路径上的信息
			sg.update_range(min(hld.dfn[x],hld.dfn[y]),max(hld.dfn[x],hld.dfn[y]),z);
		}

		//操作2:求树从 x 到 y 结点最短路径上所有节点的值之和
		//与操作1基本一样,查询的过程
		else if (op==2){
			int x,y;
			cin >> x >> y;

			ll res = 0;

			while (hld.top[x]!=hld.top[y]){
				if (hld.depth[hld.top[x]]<hld.depth[hld.top[y]]){
					swap(x,y);
				}

				int tx = hld.top[x];
				res = (res+sg.query_range(hld.dfn[tx],hld.dfn[x]).sum)%MOD; 

				x = hld.par[tx];
			}

			res = (res+sg.query_range(min(hld.dfn[x],hld.dfn[y]),max(hld.dfn[x],hld.dfn[y])).sum)%MOD;
			cout << res << '\n';
		}

		//操作3:将以 x 为根节点的子树内所有节点值都加上 z
		//维护子树信息不需要跳跃,dfn连续,直接更新即可
		else if (op==3){
			int x,z;
			cin >> x >> z;

			sg.update_range(hld.dfn[x],hld.dfn[x]+hld.sz[x]-1,z);
		}

		//操作4:求以 x 为根节点的子树内所有节点值之和
		//同3
		else{
			int x;
			cin >> x;

			cout << sg.query_range(hld.dfn[x],hld.dfn[x]+hld.sz[x]-1).sum << '\n';
		}
	}
}

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	
	int t = 1;
	// cin >> t;
	while (t--) solve();

	return 0;
}
posted @ 2026-04-13 11:34  kzssCCC  阅读(5)  评论(0)    收藏  举报