树链剖分
树链剖分是基于线段树的一种算法。这种算法主要用于对树的维护,可以在 \(O(logn^2)\) 的时间内实现对于树上的一条简单路径的修改与查询。
树链剖分主要分为重链剖分、长链剖分等几类。这里讲的是重链剖分。
以洛谷P3384为例,我们可以把一棵树分为几个重链,再给子节点编号,最后按照线段树的做法解决。
在任意一棵树中,我们可以记录每一个节点所在的子树的节点数(包括本身),最终,一个节点的几个孩子中,把节点数最多的成为重孩子。如果有很多个重孩子,随便选择其中一个,我们把一个节点到它的重孩子的边叫做重边。而重链,就是把多条重边连起来的链。可以证明,一个子树可以倍分为很多条重链,且无重复、无遗漏。
然后,我们给节点编号时,就可以先给重孩子的子树编号,再给轻孩子(不是重孩子的子节点)编号。这样的编号带来了两个好处,分别是:
1.同一条重链上的编号是连续的。
2.同一个子树内的编号也是连续的。
这样,我们再按照这个编号去建立线段树。
很明显,找到重孩子和编号,分别需要两次Dfs来实现。两次Dfs的思路如下:
第一个Dfs:记录每个节点的重孩子、节点的子树大小、节点的父亲和节点的深度;
第二个Dfs:记录每个节点的编号、每个编号对应的节点和链头。
代码如下:
void dfs1(int u)
{
wc[u] = 0;
siz[u] = 1; //节点本身也属于这个子树
for (int i = 0; i < v[u].size(); ++ i)
{
if (v[u][i] == fa[u]) continue; //确保不是父亲
dep[v[u][i]] = dep[u] + 1; //儿子的深度比父亲多一。
fa[v[u][i]] = u; //记录父亲
dfs1(v[u][i]); //递归处理
siz[u] += siz[v[u][i]]; //累加节点数
if (siz[wc[u]] < siz[v[u][i]]) wc[u] = v[u][i]; //记录重孩子
}
}
void dfs2(int u, int tp)
{
bh[u] = ++ cnt; //记录编号
dy[cnt] = u; //记录编号对应的节点
top[u] = tp; //记录链头
if (wc[u] == 0) return; //没有重孩子
dfs2(wc[u], tp); //先遍历重孩子
for (int i = 0; i < v[u].size(); ++ i)
{
if (v[u][i] == fa[u] || v[u][i] == wc[u]) continue; //是轻孩子才便利
dfs2(v[u][i], v[u][i]); //递归处理轻孩子
}
}
线段树的部分就不用详细讲了吧,大家都会(注意要取模)。
void pushup(int u)
{
w[u] = (w[u * 2] + w[u * 2 + 1]) % p;
}
void build(int u, int l, int r)
{
if (l == r)
{
w[u] = a[dy[l]] % p;
return;
}
int mid = (l + r) >> 1;
build(u * 2, l, mid);
build(u * 2 + 1, mid + 1, r);
pushup(u);
}
void maketag(int u, int len, int x)
{
w[u] += len * x; w[u] %= p;
lzy[u] += x; lzy[u] %= p;
}
void pushdown(int u, int l, int r)
{
int mid = (l + r) >> 1;
maketag(u * 2, mid - l + 1, lzy[u]);
maketag(u * 2 + 1, r - mid, lzy[u]);
lzy[u] = 0;
}
bool inrange(int l, int r, int L, int R)
{
return (l <= L) && (R <= r);
}
bool outofrange(int l, int r, int L, int R)
{
return (R < l) || (r < L);
}
void update(int u, int l, int r, int L, int R, int x)
{
if (inrange(L, R, l, r))
{
maketag(u, r - l + 1, x);
}
else if (!outofrange(l, r, L, R))
{
pushdown(u, l, r);
int mid = (l + r) >> 1;
update(u * 2, l, mid, L, R, x);
update(u * 2 + 1, mid + 1, r, L, R, x);
pushup(u);
}
}
int query(int u, int l, int r, int L, int R)
{
if (inrange(L, R, l, r))
{
return w[u];
}
else if (!outofrange(l, r, L, R))
{
pushdown(u, l, r);
int mid = (l + r) >> 1;
return (query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R)) % p;
}
return 0;
}
写完线段树的模板,现在,我们需要思考路径上的查询了。我们可以一直把两个节点往上提,每次提链头较低的节点,一直提到在同一条链上,每提一次,计算次节点到链头的距离(节点到链头一定是连续的),统计答案,提到一条链上之后,计算两个节点之间的距离,累加到答案里,最后返回答案即可。查询同理,最后可以得到以下代码:
void upd(int x, int y, int z)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);
update(1, 1, n, bh[top[x]], bh[x], z);
x = fa[top[x]];
}
update(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]), z);
}
int qry(int x, int y)
{
int ans = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);
ans += query(1, 1, n, bh[top[x]], bh[x]);
ans %= p;
x = fa[top[x]];
}
return (ans + query(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]))) % p;
}
最后考虑如何修改或查询一个节点的子树中的信息。因为每一个子树中的节点编号都是连续的,而且根节点是编号最小的,因此我们可以直接修改或查询区间 \([bh_u, bh_u + siz_u - 1]\)(\(u\) 是节点编号)。代码直接在全部代码中找吧。
因此,全部代码如下:
#include <bits/stdc++.h>
#define int long long
using namespace std;
int w[505000], dep[505000], a[505000], wc[505000], fa[505000];
int lzy[505000], bh[505000], n, m, r, p, siz[505000], cnt = 0;
int dy[505000], top[505000];
vector <int> v[505000];
void dfs1(int u)
{
wc[u] = 0;
siz[u] = 1;
for (int i = 0; i < v[u].size(); ++ i)
{
if (v[u][i] == fa[u]) continue;
dep[v[u][i]] = dep[u] + 1;
fa[v[u][i]] = u;
dfs1(v[u][i]);
siz[u] += siz[v[u][i]];
if (siz[wc[u]] < siz[v[u][i]]) wc[u] = v[u][i];
}
}
void dfs2(int u, int tp)
{
bh[u] = ++ cnt;
dy[cnt] = u;
top[u] = tp;
if (wc[u] == 0) return;
dfs2(wc[u], tp);
for (int i = 0; i < v[u].size(); ++ i)
{
if (v[u][i] == fa[u] || v[u][i] == wc[u]) continue;
dfs2(v[u][i], v[u][i]);
}
}
void pushup(int u)
{
w[u] = (w[u * 2] + w[u * 2 + 1]) % p;
}
void build(int u, int l, int r)
{
if (l == r)
{
w[u] = a[dy[l]] % p;
return;
}
int mid = (l + r) >> 1;
build(u * 2, l, mid);
build(u * 2 + 1, mid + 1, r);
pushup(u);
}
void maketag(int u, int len, int x)
{
w[u] += len * x; w[u] %= p;
lzy[u] += x; lzy[u] %= p;
}
void pushdown(int u, int l, int r)
{
int mid = (l + r) >> 1;
maketag(u * 2, mid - l + 1, lzy[u]);
maketag(u * 2 + 1, r - mid, lzy[u]);
lzy[u] = 0;
}
bool inrange(int l, int r, int L, int R)
{
return (l <= L) && (R <= r);
}
bool outofrange(int l, int r, int L, int R)
{
return (R < l) || (r < L);
}
void update(int u, int l, int r, int L, int R, int x)
{
if (inrange(L, R, l, r))
{
maketag(u, r - l + 1, x);
}
else if (!outofrange(l, r, L, R))
{
pushdown(u, l, r);
int mid = (l + r) >> 1;
update(u * 2, l, mid, L, R, x);
update(u * 2 + 1, mid + 1, r, L, R, x);
pushup(u);
}
}
int query(int u, int l, int r, int L, int R)
{
if (inrange(L, R, l, r))
{
return w[u];
}
else if (!outofrange(l, r, L, R))
{
pushdown(u, l, r);
int mid = (l + r) >> 1;
return (query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R)) % p;
}
return 0;
}
void upd(int x, int y, int z)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);
update(1, 1, n, bh[top[x]], bh[x], z);
x = fa[top[x]];
}
update(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]), z);
}
int qry(int x, int y)
{
int ans = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);
ans += query(1, 1, n, bh[top[x]], bh[x]);
ans %= p;
x = fa[top[x]];
}
return (ans + query(1, 1, n, min(bh[x], bh[y]), max(bh[x], bh[y]))) % p;
}
signed main()
{
cin >> n >> m >> r >> p;
for (int i = 1; i <= n; ++ i) cin >> a[i];
for (int i = 1; i < n; ++ i)
{
int x, y;
cin >> x >> y;
v[x].push_back(y);
v[y].push_back(x);
}
dep[r] = 1, fa[r] = r;
dfs1(r); dfs2(r, r);
build(1, 1, n);
while (m --)
{
int op, x, y, z;
cin >> op;
if (op == 1)
{
cin >> x >> y >> z;
upd(x, y, z);
}
else if (op == 2)
{
cin >> x >> y;
cout << qry(x, y) << endl;
}
else if (op == 3)
{
cin >> x >> z;
update(1, 1, n, bh[x], bh[x] + siz[x] - 1, z);
}
else
{
cin >> x;
cout << query(1, 1, n, bh[x], bh[x] + siz[x] - 1) << endl;
}
}
}

浙公网安备 33010602011771号