树链剖分

前言

树剖代码往往很长,但是逻辑并不复杂,理清每一步在干什么就可以。

基础概念

重儿子与轻儿子

重儿子就是树上一个节点的儿子中,子树节点数最多的那一个儿子(多个任选一个),其余的都是轻儿子。

重边与轻边

由一个节点连向它重儿子的边就是重边,另外的边就是轻边。

重链

重链是首尾相接的重边组成的链(也可以是不是任何一条重边端点的一个单独的点)。

重链的数量等于叶子节点的数量,感性理解一下,从上往下出现的重链都会在叶子节点才会结束,并且不会共用也不会遗漏。

重链剖分(树链剖分)

由于每个点都会在一条重链上,我们就可以把树看成由很多条重链和一些轻边组成,而我们可以把在树上路径上点的问题变成一段一段的重链上的问题。

初始化

为了方便对重链上的一段进行处理,我们可以把一条重链放到线段树上。

一种方法是对每一条重链都建一颗线段树,但是有些麻烦。

所以我们可以把所有节点放到同一颗线段树上,只要保证同一条重链上的节点是一个连续区间就可以了。

我们通过 dfs 序建树,如果我们每次先遍历重儿子就可以保证一条重链是连续的一段。

所以我们通过两个 dfs,一个用于求重儿子,一个求 dfs 序。

求好后我们就可以愉快的建树了。

路径修改/查询

对于从 \(x\)\(y\) 的路径上的修改,我们可以一边往上跳,找他们的 LCA,一边对经过的重链修改。

我们可以记录下每一个点 \(i\) 所在重链上深度最小的节点 \(anc_i\),往上跳时从当前节点到此节点做一次区间修改。

\(x\)\(y\) 一起往上跳,每次跳深度 \(d\) 大的,保证跳的过程中两者都还未到 LCA,如果 \(anc_x=anc_y\),那么最后跳一次,使 \(x=y\),结束。

子树修改/查询

因为在 dfs 的过程中,只有遍历完一棵子树才会离开这颗子树,所以一棵子树内的所有节点的 dfs 序也是连续的。

记录下每颗子树内 dfs 序最大是多少,最小肯定是子树根的 dfs 序。

总结

树剖的每一步都很好理解也很好实现,但步骤繁多,对于刚开始写长代码的同学不太友好,需要练习。

以下是 P3384 【模板】重链剖分/树链剖分 的代码。

#include <bits/stdc++.h>
#define int long long
#define debug() cout << "--,--" << endl
using namespace std;

const int N = 100010;
int n, m, root, mod, timestamp;
int a[N];
vector <int> v[N];
int son[N], siz[N], dfn[N], maxn[N], anc[N], anti[N], d[N], father[N];

void dfs1(int u, int fa)
{
	father[u] = fa;
	siz[u] = 1;
	int maxn = 0;
	for (int j : v[u])
	{
		if (j == fa) continue;
		d[j] = d[u] + 1;
		dfs1 (j, u);
		siz[u] += siz[j];
		if (siz[j] > siz[son[u]]) son[u] = j;
	}
}

void dfs2(int u, int fa)
{
	if (u != son[fa]) anc[u] = u;
	else anc[u] = anc[fa];
	maxn[u] = dfn[u] = ++timestamp;
	anti[timestamp] = u;
	
	if (son[u])
		dfs2 (son[u], u), maxn[u] = max (maxn[u], maxn[son[u]]);
	for (int j : v[u])
	{
		if (j == fa || j == son[u]) continue;
		dfs2 (j, u);
		maxn[u] = max (maxn[u], maxn[j]);
	}
}

struct node
{
	int l, r, val, tag;
} tr[N << 2];

int len(int u)
{
	return tr[u].r - tr[u].l + 1;
}
void pushup(int u)
{
	tr[u].val = (tr[u << 1].val + tr[u << 1 | 1].val) % mod;
}

void build(int u, int l, int r)
{
	tr[u].l = l, tr[u].r = r;
	if (l == r)
	{
		tr[u].val = a[anti[l]];
		return;
	}
	int mid = l + r >> 1;
	build (u << 1, l, mid);
	build (u << 1 | 1, mid + 1, r);
	pushup (u);
}

void pushdown(int u)
{
	int& tag = tr[u].tag;
	if (!tag) return;
	tr[u << 1].tag = (tr[u << 1].tag + tag) % mod;
	tr[u << 1 | 1].tag = (tr[u << 1 | 1].tag + tag) % mod;
	tr[u << 1].val = (tr[u << 1].val + len (u << 1) * tag % mod) % mod;
	tr[u << 1 | 1].val = (tr[u << 1 | 1].val + len (u << 1 | 1) * tag % mod) % mod;
	tag = 0;
}

void add(int u, int l, int r, int x)
{
	if (l <= tr[u].l && tr[u].r <= r)
	{
		tr[u].tag = (tr[u].tag + x + mod) % mod;
		tr[u].val = (tr[u].val + len (u) * x % mod + mod) % mod;
		return;
	}
	pushdown (u);
	int mid = tr[u].l + tr[u].r >> 1;
	if (l <= mid) add (u << 1, l, r, x);
	if (r > mid) add (u << 1 | 1, l, r, x);
	pushup (u);
}

int find(int u, int l, int r)
{
	if (l <= tr[u].l && tr[u].r <= r)
		return tr[u].val;
	pushdown (u);
	int mid = tr[u].l + tr[u].r >> 1;
	int ans = 0;
	if (l <= mid) ans += find (u << 1, l, r);
	if (r > mid) ans += find (u << 1 | 1, l, r);
	return ans % mod;
}

signed main()
{
	cin >> n >> m >> root >> mod;
	for (int i = 1; i <= n; i++)
		cin >> a[i], a[i] %= mod;
	for (int i = 1; i < n; i++)
	{
		int x, y;
		cin >> x >> y;
		v[x].push_back (y);
		v[y].push_back (x);
	}
	dfs1 (root, 0);
	dfs2 (root, 0);
	
	build (1, 1, n);
	for (int i = 1; i <= m; i++)
	{
		int op, x, y, z;
		cin >> op;
		if (op == 1)
		{
			cin >> x >> y >> z;
			z %= mod;
			while (anc[x] != anc[y])
			{
				if (d[anc[x]] > d[anc[y]])
					add (1, dfn[anc[x]], dfn[x], z), x = father[anc[x]];
				else add (1, dfn[anc[y]], dfn[y], z), y = father[anc[y]];
			}
			add (1, min (dfn[x], dfn[y]), max (dfn[x], dfn[y]), z);
		}
		else if (op == 2)
		{
			cin >> x >> y;
			int ans = 0;
			while (anc[x] != anc[y])
			{
				if (d[anc[x]] > d[anc[y]])
					ans = (ans + find (1, dfn[anc[x]], dfn[x])) % mod, x = father[anc[x]];
				else ans = (ans + find (1, dfn[anc[y]], dfn[y])) % mod, y = father[anc[y]];
			}
			ans = (ans + find (1, min (dfn[x], dfn[y]), max (dfn[x], dfn[y]))) % mod;
			cout << ans << endl;
		}
		else if (op == 3)
		{
			cin >> x >> z;
			add (1, dfn[x], maxn[x], z);
		}
		else
		{
			cin >> x;
			cout << find (1, dfn[x], maxn[x]) << endl;
		}
	}
	return 0;
}

练习

posted @ 2025-02-09 18:14  wo2011  阅读(47)  评论(0)    收藏  举报
//雪花飘落效果