2568. 树链剖分
2568. 树链剖分
原题链接
思路:
- 先将一颗树进行树链剖分,再将它的dfs序求出来
- 利用dfs序将线段树模拟出来(build出来)
- 再将输入的路径进行切割导入线段树处理,直到两个元素进入同一条重链
注意点
1. top数组的元素记录的是自己所在的重链的顶点,但是如果一个元素是轻儿子,那么就自己做自己的顶点!!一定要记录自己为顶点!!
2. dfs序应当从它的重链开始编号,重链中包含的元素最多使得此条路径在转换为线段树后依旧得到连续的元素最多,使得update_path与query_path中的while循环次数减少,一定程度上降低了程序的时间复杂度。
代码:
#include<bits/stdc++.h>
using namespace std;
#define lc(i) i << 1
#define rc(i) i << 1 | 1
const int N = 1e5 + 7;
typedef long long ll;
int son[N], sz[N], dep[N], f[N], top[N];
int idx[N], nw[N], cnt;
struct tree{
	int l, r;
	ll add, sum;
}tr[N * 4];
int tot, w[N], h[N];
struct edge{
	int ne, to;
}e[N * 2];
void add(int u, int v){
	e[++ tot].to = v;
	e[tot].ne = h[u];
	h[u] = tot; 
}
void dfs1(int now, int fa){
	dep[now] = dep[fa] + 1;
	sz[now] = 1, f[now] = fa;
	for(int i = h[now]; i ; i = e[i].ne){
		int to = e[i].to;
		if(to == fa) continue;
		dfs1(to, now); 
		sz[now] += sz[to];
		if(sz[son[now]] < sz[to])  son[now] = to;
	}
}
void dfs2(int u, int t){
	idx[u] = ++ cnt, nw[cnt] = w[u], top[u] = t;
	if(!son[u]) return;
	dfs2(son[u], t);
	for(int i = h[u]; i ; i = e[i].ne){
		int to = e[i].to;
		if(to == f[u] || to == son[u]) continue;
		dfs2(to, to);
	}
}
void pushup(int i){
	tr[i].sum = tr[lc(i)].sum + tr[rc(i)].sum;
}
void pushdown(int i){
	int k = tr[i].add;
	if(k == 0) return;
	tr[lc(i)].add += k, tr[rc(i)].add += k;
	tr[lc(i)].sum += (tr[lc(i)].r - tr[lc(i)].l + 1) * k;
	tr[rc(i)].sum += (tr[rc(i)].r - tr[rc(i)].l + 1) * k;
	tr[i].add = 0;
}
void build(int l, int r, int i){
	tr[i] = {l, r, 0, 0};
	if(l == r){
		tr[i].sum = nw[l];
		return;
	}
	int mid = (l + r) >> 1;
	build(l, mid, lc(i));
	build(mid + 1, r, rc(i));
	pushup(i);
}
void update(int x, int y, int i, int k){
	if(x <= tr[i].l && tr[i].r <= y){
		tr[i].sum += (tr[i].r - tr[i].l + 1) * k;
		tr[i].add += k;
		return;
	}
	pushdown(i);
	int mid = (tr[i].l + tr[i].r) >> 1;
	if(mid >= x) update(x, y, lc(i), k);
	if(mid < y) update(x, y, rc(i), k);
	pushup(i);
}
ll query(int x, int y, int i){
	if(x <= tr[i].l && tr[i].r <= y) return tr[i].sum;
	pushdown(i);
	ll res = 0;
	int mid = (tr[i].r + tr[i].l) >> 1;
	if(mid >= x) res += query(x, y, i << 1);
	if(mid < y) res += query(x, y, i << 1 | 1);
	return res;
}
void update_path(int x, int y, int k){
	while(top[x] != top[y]){
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		update(idx[top[x]], idx[x], 1, k);
		x = f[top[x]];
	}
	if(dep[x] < dep[y]) swap(x, y);
	update(idx[y], idx[x], 1, k);
}
ll query_path(int x, int y){
	ll res = 0;
	while(top[x] != top[y]){
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		res += query(idx[top[x]], idx[x], 1);
		x = f[top[x]];
	}
	if(dep[x] < dep[y]) swap(x, y);
	return res + query(idx[y], idx[x], 1);
}
int main(){
	int n, m;
	scanf("%d", &n);
	for(int i = 1;i <= n;i ++) scanf("%d", &w[i]);
	for(int i = 1;i < n;i ++){
		int x, y;
		scanf("%d%d", &x, &y);
		add(x, y), add(y, x);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	scanf("%d", &m);
	build(1, n, 1);   
	int op, u;
	while(m --){
		scanf("%d%d", &op, &u);
		if(op == 1){
			int v, k;
			scanf("%d%d", &v, &k);
			update_path(u, v, k);
		}
		else if(op == 2){
			int k;
			scanf("%d", &k);
			update(idx[u], idx[u] + sz[u] - 1, 1, k);
		}
		else if(op == 3){
			int v;
			scanf("%d", &v);
			printf("%lld\n", query_path(u, v));
		}
		else printf("%lld\n", query(idx[u], idx[u] + sz[u] - 1, 1));
	}
	return 0;
}
 
                    
                     
                    
                 
                    
                
 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号