【学习笔记】重链剖分

先直接看模板题:【模板】重链剖分/树链剖分

这道题的操作还是修改和查询和,看起来是线段树可以维护的操作,区别在于这是树上的操作。

那怎么把树上的和转化为区间和呢?这里就要用到DFS序(图和树的是一样的)。这样每个节点就有了一个在序列中的下标。加的时候只要对相应的节点所组成的若干区间加即可。

但是这带来了一个问题:有的时候,一条路径需要修改的区间特别多,那么如何减少这种情况呢?这就要用到重链剖分的策略了。

重链剖分往往可以把树上问题转化为区间问题,通常与线段树和最近公共祖先紧密结合。

【1】重孩子、重链、轻边:

重孩子的定义:

对于任意节点 \(u\),其中子树大小最大的一个孩子被称为重孩子,又称重节点

例如这么一棵树: (假设根是 \(1\) 号节点)

1 -> 2
1 -> 3
1 -> 4
2 -> 5
2 -> 6
4 -> 7
5 -> 8
5 -> 9
6- > 10

其中,\(1\) 号节点的重孩子是 \(2\) 号节点,\(2\) 号节点的重孩子是 \(5\) 号节点。

重链的定义:

对于任意一条链,除了链头,其他节点都是重节点的链被称为重链

还是以上面的树举例,其中 \(1 \to 2 \to 5\) 这条链就是一条重链,\(4 \to 7\) 也是一条重链。

轻边的定义:

不属于任何一条重链的边被称为轻边。

不难发现,任何两条重链之间都至少由 \(1\) 条轻边连接,所以树上任何一个节点的路径上经过的重链条数不会超过路径上的轻边条数减 \(1\)

【2】重链剖分的优势:

接下来证明一个对复杂度分析至关重要的结论:

定理1: 对树上任何一个节点,它到根的路径上的重链条数和轻边条数不会超过 \(\log n\)

证明:

根据上文“重链条数不会超过的轻边条数减 \(1\)”的结论,只要证明轻边条数不超过 \(\log n\)

\(u \to v\) 是路径上的一条轻边,其中 \(u\) 是父节点,则 \(u\) 存在一个重孩子 \(w\),使得 \(siz_w \geq siz_v\)。于是 \(siz_u \geq siz_v + siz_w \geq 2siz_v\)。这就是说,从任意节点开始,向上走到根,每经过一条轻边,当前节点所在的子树大小至少翻倍。经过至多 \(\log n\) 次翻倍后,树的大小就会超过 \(n\),也就是走到了根。

因此路径上至多有 \(\log n\) 条轻边,从而推出至多 \(\log n\) 条重链。

这有什么用呢?对于路径 \(x,y\) 的修改,设 \(\operatorname{LCA}(x,y) = u\),把修改的路径拆分成 \(x \sim u\)\(y \sim u\) 两条,这两条路径分别是 \(x\) 到根,\(y\) 到根的一部分,因此也至多经过\(\log n\) 条轻边和至多 \(\log n\) 条重链。而重链上节点的 DFS 序是连续的,所以只需做 \(O(\log n)\) 次修改即可对路径进行修改。

【3】如何用重链剖分:

对于后两个问题,相对简单。因为一棵子树上所有节点的 DFS 序肯定是连续的,所以我们以根节点 \(x\) 的 DFS 序为左区间端点,\(x\) 的子树大小作为区间长度,对区间 \([dfn_x,dfn_x + siz_x - 1]\) 进行区间修改和查询,方法和线段树一样。

对于前两个问题,稍微麻烦。如果我们要更新 \(x,y\) 路径上的所有节点,设 \(\operatorname{LCA}(x,y) = u\),那么应该先修改 \(x\)\(u\) 之间的路径,再修改 \(y\)\(u\) 上的路径。

考虑 \(x\)\(y\) 的链头 \(top_x,top_y\)

  1. \(top_x = top_y\) 时,说明两节点在同一链上。不妨设 \(dfn_x < dfn_y\),则直接修改 \([dfn_x,dfn_y]\) 即可(反之就把 \(dfn_x,dfn_y\) 反过来)。

  2. 否则,比较 \(top_x,top_y\) 的深度。不妨设 \(top_x\) 深度较小(否则交换 \(x,y\)),则 \(x \sim top_x\) 的路径上不可能有 \(x,y\)\(\operatorname{LCA}\),直接修改 \([dfn_{{top}_x},dfn_x]\),然后将 \(x\) 跳到 \(x\) 所在重链链头的父节点,然后继续修改,直至满足条件 1,退出修改。

已证明,修改的复杂度为 \(O(\log n)\)

事实上,这也引出了用重链剖分在 \(O(\log n)\) 的时间复杂度下求 \(\operatorname{LCA}\) 的方法,具体可以看

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 9;
int n,m,root,p;
int op,x,y,k;
struct Tree{
	struct egde{
		int to,nex;
	} e[N << 1];
	int ecnt,head[N];
	int fa[N],weight_child[N],siz[N];
	int weight_link_top[N],dep[N];
	int dfn[N],rdfn[N],id;
	Tree(){}
	void addegde(int u,int v){
		ecnt++;
		e[ecnt] = (egde){v,head[u]};
		head[u] = ecnt;
	}
	void dfs1(int cur,int father){
		fa[cur] = father;
		siz[cur] = 1;
		dep[cur] = dep[father] + 1;
		for(int i = head[cur];i;i = e[i].nex){
			int v = e[i].to;
			if(v != father){
				dfs1(v,cur);
				siz[cur] += siz[v];
				if(siz[v] > siz[weight_child[cur]])
					weight_child[cur] = v;
			}
		}
	}
	void dfs2(int cur,int link_top){
		id++;
		dfn[cur] = id;
		rdfn[id] = cur;
		weight_link_top[cur] = link_top;
		if(weight_child[cur]){
			dfs2(weight_child[cur],link_top);
			for(int i = head[cur];i;i = e[i].nex){
				int v = e[i].to;
				if(v != fa[cur] && v != weight_child[cur])
					dfs2(v,v);
			}
		}
	}
} t;
int a[N];
struct seg_tree{
	struct node {
		int val,add;
	} tree[N << 2];
	bool in_range(int l,int r,int now_l,int now_r){
		return l <= now_l && now_r <= r;
	}
	bool out_range(int l,int r,int now_l,int now_r){
		return now_r < l || now_l > r;
	}
	int len(int l,int r){
	    return r - l + 1;
	}
	void push_up(int root){
		int lchild = root * 2,rchild = root * 2 + 1;
		tree[root].val = (tree[lchild].val + tree[rchild].val) % p;
	}
	void make_tag(int Len,int root,int add){
		tree[root].add = (tree[root].add + add) % p;
		tree[root].val = (tree[root].val + Len * add % p) % p;
	}
	void push_down(int l,int r,int root){
	    int mid = (l + r) / 2,lchild = root * 2,rchild = root * 2 + 1;
		make_tag(len(l,mid),lchild,tree[root].add % p);
		make_tag(len(mid + 1,r),rchild,tree[root].add % p);
		tree[root].add = 0;
	}
	void build(int l,int r,int root) {
		if(l == r) {
			tree[root].val = a[t.rdfn[l]] % p;
			return;
		}
		int mid = (l + r) / 2,lchild = root * 2,rchild = root * 2 + 1;
		build(l,mid,lchild);
		build(mid + 1,r,rchild);
		push_up(root);
	}
	void update(int l,int r,int now_l,int now_r,int root,int add) {
		if(in_range(l,r,now_l,now_r)) {
			tree[root].val = (tree[root].val + len(now_l,now_r) * add) % p;
			tree[root].add = (tree[root].add + add) % p;
		}
		else if(!out_range(l,r,now_l,now_r)){
			int mid = (now_l + now_r) / 2,lchild = root * 2,rchild = root * 2 + 1;
			push_down(now_l,now_r,root);
			update(l,r,mid + 1,now_r,rchild,add);
			update(l,r,now_l,mid,lchild,add);
			push_up(root);
		}
		return;
	}
	void path_update(int x,int y,int k){
		while(t.weight_link_top[x] != t.weight_link_top[y]){
			if(t.dep[t.weight_link_top[x]] < t.dep[t.weight_link_top[y]])
				swap(x,y);
			update(t.dfn[t.weight_link_top[x]],t.dfn[x],1,n,1,k);
			x = t.fa[t.weight_link_top[x]];
		}
		update(min(t.dfn[x],t.dfn[y]),max(t.dfn[x],t.dfn[y]),1,n,1,k);
	}
	int getsum(int l, int r, int now_l, int now_r, int root) {
		int mid = (now_l + now_r) / 2,lchild = root * 2,rchild = root * 2 + 1;
		if(in_range(l,r,now_l,now_r))
			return tree[root].val % p;
		else if(!out_range(l,r,now_l,now_r)){
			push_down(now_l,now_r,root);
			return (getsum(l,r,now_l,mid,lchild) + getsum(l,r,mid + 1,now_r,rchild)) % p;
		}
		else
			return 0;
	}
	int path_query(int x,int y){
		int ans = 0;
		while(t.weight_link_top[x] != t.weight_link_top[y]){
			if(t.dep[t.weight_link_top[x]] < t.dep[t.weight_link_top[y]])
				swap(x,y);
			ans = (ans + getsum(t.dfn[t.weight_link_top[x]],t.dfn[x],1,n,1)) % p;
			x = t.fa[t.weight_link_top[x]];
		}
		return (ans + getsum(min(t.dfn[x],t.dfn[y]),max(t.dfn[x],t.dfn[y]),1,n,1));
	}
} seg;
signed main() {
	scanf("%d%d%d%d", &n, &m ,&root, &p);
	for(int i = 1; i <= n; i++)
		scanf("%d", &a[i]);
	for(int i = 1;i < n;i++){
		scanf("%d%d", &x, &y);
		t.addegde(x,y);
		t.addegde(y,x);
	}
	t.dfs1(root,0);
	t.dfs2(root,0);
	seg.build(1,n,1);
	for(int i = 1;i <= m;i++){
		scanf("%d" ,&op);
		if(op == 1){
			scanf("%d%d%d", &x ,&y, &k);
			k %= p;
			seg.path_update(x,y,k);
		}
		if(op == 2){
			scanf("%d%d", &x ,&y);
			printf("%d\n",seg.path_query(x,y) % p);
		}
		if(op == 3){
			scanf("%d%d", &x ,&k);
			k %= p;
			seg.update(t.dfn[x],t.dfn[x] + t.siz[x] - 1,1,n,1,k);
		}
		if(op == 4){
			scanf("%d", &x);
			printf("%d\n",seg.getsum(t.dfn[x],t.dfn[x] + t.siz[x] - 1,1,n,1) % p);
		}
	}
	return 0;
}
posted @ 2024-01-31 16:10  5t0_0r2  阅读(83)  评论(0)    收藏  举报