树链剖分

树链剖分是基于线段树的一种算法。这种算法主要用于对树的维护,可以在 \(O(logn^2)\) 的时间内实现对于树上的一条简单路径的修改与查询。
树链剖分主要分为重链剖分、长链剖分等几类。这里讲的是重链剖分。
洛谷P3384为例,我们可以把一棵树分为几个重链,再给子节点编号,最后按照线段树的做法解决。
在任意一棵树中,我们可以记录每一个节点所在的子树的节点数(包括本身),最终,一个节点的几个孩子中,把节点数最多的成为重孩子。如果有很多个重孩子,随便选择其中一个,我们把一个节点到它的重孩子的边叫做重边。而重链,就是把多条重边连起来的链。可以证明,一个子树可以倍分为很多条重链,且无重复、无遗漏。
然后,我们给节点编号时,就可以先给重孩子的子树编号,再给轻孩子(不是重孩子的子节点)编号。这样的编号带来了两个好处,分别是:
1.同一条重链上的编号是连续的。
2.同一个子树内的编号也是连续的。
这样,我们再按照这个编号去建立线段树。
很明显,找到重孩子和编号,分别需要两次Dfs来实现。两次Dfs的思路如下:
第一个Dfs:记录每个节点的重孩子、节点的子树大小、节点的父亲和节点的深度;
第二个Dfs:记录每个节点的编号、每个编号对应的节点和链头。
代码如下:

void dfs1(int u)
{
	wc[u] = 0;
	siz[u] = 1; //节点本身也属于这个子树
	for (int i = 0; i < v[u].size(); ++ i)
	{
		if (v[u][i] == fa[u]) continue; //确保不是父亲
		dep[v[u][i]] = dep[u] + 1; //儿子的深度比父亲多一。
		fa[v[u][i]] = u; //记录父亲
		dfs1(v[u][i]); //递归处理
		siz[u] += siz[v[u][i]]; //累加节点数
		if (siz[wc[u]] < siz[v[u][i]]) wc[u] = v[u][i]; //记录重孩子
	}
}
void dfs2(int u, int tp)
{
	bh[u] = ++ cnt; //记录编号
	dy[cnt] = u; //记录编号对应的节点
	top[u] = tp; //记录链头
	if (wc[u] == 0) return; //没有重孩子
	dfs2(wc[u], tp); //先遍历重孩子
	for (int i = 0; i < v[u].size(); ++ i)
	{
		if (v[u][i] == fa[u] || v[u][i] == wc[u]) continue; //是轻孩子才便利
		dfs2(v[u][i], v[u][i]); //递归处理轻孩子
	}
}

线段树的部分就不用详细讲了吧,大家都会(注意要取模)。

void pushup(int u)
{
	w[u] = (w[u * 2] + w[u * 2 + 1]) % p;
}
void build(int u, int l, int r)
{
	if (l == r)
	{
		w[u] = a[dy[l]] % p;
		return;
	}
	int mid = (l + r) >> 1;
	build(u * 2, l, mid);
	build(u * 2 + 1, mid + 1, r);
	pushup(u);
}
void maketag(int u, int len, int x)
{
	w[u] += len * x; w[u] %= p;
	lzy[u] += x; lzy[u] %= p;
}
void pushdown(int u, int l, int r)
{
	int mid = (l + r) >> 1;
	maketag(u * 2, mid - l + 1, lzy[u]);
	maketag(u * 2 + 1, r - mid, lzy[u]);
	lzy[u] = 0;
}
bool inrange(int l, int r, int L, int R)
{
	return (l <= L) && (R <= r);
}
bool outofrange(int l, int r, int L, int R)
{
	return (R < l) || (r < L);
}
void update(int u, int l, int r, int L, int R, int x)
{
	if (inrange(L, R, l, r))
	{
		maketag(u, r - l + 1, x);
	}
	else if (!outofrange(l, r, L, R))
	{
		pushdown(u, l, r);
		int mid = (l + r) >> 1;
		update(u * 2, l, mid, L, R, x);
		update(u * 2 + 1, mid + 1, r, L, R, x);
		pushup(u);
	}
}
int query(int u, int l, int r, int L, int R)
{
	if (inrange(L, R, l, r))
	{
		return w[u];
	}
	else if (!outofrange(l, r, L, R))
	{
		pushdown(u, l, r);
		int mid = (l + r) >> 1;
		return (query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R)) % p;
	}
	return 0;
}

写完线段树的模板,现在,我们需要思考路径上的查询了。我们可以一直把两个节点往上提,每次提链头较低的节点,一直提到在同一条链上,每提一次,计算次节点到链头的距离(节点到链头一定是连续的),统计答案,提到一条链上之后,计算两个节点之间的距离,累加到答案里,最后返回答案即可。查询同理,最后可以得到以下代码:

void upd(int x, int y, int z)
{
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		update(1, 1, n, bh[top[x]], bh[x], z);
		x = fa[top[x]];
	}
	update(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]), z);
}
int qry(int x, int y)
{
	int ans = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		ans += query(1, 1, n, bh[top[x]], bh[x]);
		ans %= p;
		x = fa[top[x]];
	}
	return (ans + query(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]))) % p;
}

最后考虑如何修改或查询一个节点的子树中的信息。因为每一个子树中的节点编号都是连续的,而且根节点是编号最小的,因此我们可以直接修改或查询区间 \([bh_u, bh_u + siz_u - 1]\)\(u\) 是节点编号)。代码直接在全部代码中找吧。
因此,全部代码如下:

#include <bits/stdc++.h>
#define int long long
using namespace std;
int w[505000], dep[505000], a[505000], wc[505000], fa[505000];
int lzy[505000], bh[505000], n, m, r, p, siz[505000], cnt = 0;
int dy[505000], top[505000];
vector <int> v[505000];
void dfs1(int u)
{
	wc[u] = 0;
	siz[u] = 1;
	for (int i = 0; i < v[u].size(); ++ i)
	{
		if (v[u][i] == fa[u]) continue;
		dep[v[u][i]] = dep[u] + 1;
		fa[v[u][i]] = u;
		dfs1(v[u][i]);
		siz[u] += siz[v[u][i]];
		if (siz[wc[u]] < siz[v[u][i]]) wc[u] = v[u][i];
	}
}
void dfs2(int u, int tp)
{
	bh[u] = ++ cnt;
	dy[cnt] = u;
	top[u] = tp;
	if (wc[u] == 0) return;
	dfs2(wc[u], tp);
	for (int i = 0; i < v[u].size(); ++ i)
	{
		if (v[u][i] == fa[u] || v[u][i] == wc[u]) continue;
		dfs2(v[u][i], v[u][i]);
	}
}
void pushup(int u)
{
	w[u] = (w[u * 2] + w[u * 2 + 1]) % p;
}
void build(int u, int l, int r)
{
	if (l == r)
	{
		w[u] = a[dy[l]] % p;
		return;
	}
	int mid = (l + r) >> 1;
	build(u * 2, l, mid);
	build(u * 2 + 1, mid + 1, r);
	pushup(u);
}
void maketag(int u, int len, int x)
{
	w[u] += len * x; w[u] %= p;
	lzy[u] += x; lzy[u] %= p;
}
void pushdown(int u, int l, int r)
{
	int mid = (l + r) >> 1;
	maketag(u * 2, mid - l + 1, lzy[u]);
	maketag(u * 2 + 1, r - mid, lzy[u]);
	lzy[u] = 0;
}
bool inrange(int l, int r, int L, int R)
{
	return (l <= L) && (R <= r);
}
bool outofrange(int l, int r, int L, int R)
{
	return (R < l) || (r < L);
}
void update(int u, int l, int r, int L, int R, int x)
{
	if (inrange(L, R, l, r))
	{
		maketag(u, r - l + 1, x);
	}
	else if (!outofrange(l, r, L, R))
	{
		pushdown(u, l, r);
		int mid = (l + r) >> 1;
		update(u * 2, l, mid, L, R, x);
		update(u * 2 + 1, mid + 1, r, L, R, x);
		pushup(u);
	}
}
int query(int u, int l, int r, int L, int R)
{
	if (inrange(L, R, l, r))
	{
		return w[u];
	}
	else if (!outofrange(l, r, L, R))
	{
		pushdown(u, l, r);
		int mid = (l + r) >> 1;
		return (query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R)) % p;
	}
	return 0;
}
void upd(int x, int y, int z)
{
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		update(1, 1, n, bh[top[x]], bh[x], z);
		x = fa[top[x]];
	}
	update(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]), z);
}
int qry(int x, int y)
{
	int ans = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		ans += query(1, 1, n, bh[top[x]], bh[x]);
		ans %= p;
		x = fa[top[x]];
	}
	return (ans + query(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]))) % p;
}
signed main()
{
	cin >> n >> m >> r >> p;
	for (int i = 1; i <= n; ++ i) cin >> a[i];
	for (int i = 1; i < n; ++ i)
	{
		int x, y;
		cin >> x >> y;
		v[x].push_back(y);
		v[y].push_back(x);
	}
	dep[r] = 1, fa[r] = r;
	dfs1(r); dfs2(r, r);
	build(1, 1, n);
	while (m --)
	{
		int op, x, y, z;
		cin >> op;
		if (op == 1)
		{
			cin >> x >> y >> z;
			upd(x, y, z);
		}
		else if (op == 2)
		{
			cin >> x >> y;
			cout << qry(x, y) << endl;
		}
		else if (op == 3)
		{
			cin >> x >> z;
			update(1, 1, n, bh[x], bh[x] + siz[x] - 1, z);
		}
		else
		{
			cin >> x;
			cout << query(1, 1, n, bh[x], bh[x] + siz[x] - 1) << endl;
		}
	}
}
posted @ 2025-05-02 10:28  langni2013  阅读(19)  评论(0)    收藏  举报