[学习笔记] 树上启发式合并
知识讲解
1.定义:
英文名\(\text{dsu on tree}\),是用来处理一类离线的树上询问问题的方法。一般时间复杂度\(O(n \log n)\)。
2.操作步骤:
1.先遍历其非重儿子,获取它的ans,但不保留遍历后它的信息
2.遍历它的重儿子,保留它的信息
3.再次遍历其非重儿子及其父亲,用重儿子的信息对遍历到的节点进行计算,获取整棵子树的ans
3.代码
//重链剖分
inline void dfs (int u, int fa) {
size[u] = 1;
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa) continue;
dfs (v, u);
size[u] += size[v];
if (size[v] > size[son[u]]) son[u] = v;
}
}
inline void add (int u, int fa, int val) {
// 将轻儿子暴力合并到点上
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa || v == Son) continue;
add (v, u, val);
}
}
inline void dfs (int u, int fa, bool keep) {
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa || v == son[u]) continue;
dfs (v, u, 0);//遍历轻儿子 不保存轻儿子信息
}
if (son[u]) dfs (son[u], u, 1);//走重儿子 保存重儿子信息
Son = son[u];
add (u, fa, 1);//再次遍历轻儿子 暴力合并
Son = 0;
ans[u] = sum;
if (!keep) {
add (u, fa, -1);//清空不需要保存的情况
sum = 0;
}
}
4.时间复杂度证明
首先我们知道一个点到根节点上有\(\log n\)条重链和\(\log n\)条轻边。
注意到每条轻边会导致下面的子树被多遍历一次,因此每个点最多被遍历\(\log n\)次。
因此总时间复杂度为\(O(n \log n)\)。
例题
U41492 树上数颜色
题面描述
给定一棵树,每次询问子树颜色种类数。
题解
先考虑暴力\(O(n^2)\)算法
inline void dfs (int u, int fa) {
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa) continue;
dfs (v, u);
for (int j = 0; j <= 10000; j ++ ) mp[u][j] += mp[v][j];
}
int cnt = 0;
for (int i = 0; i <= 10000; i ++ ) if (mp[u][i]) cnt ++ ;
ans[u] = cnt;
}
然后考虑如何优化 因为没有修改操作,而且是子树查询,所以用\(\text{dsu on tree}\)
我们可以先遍历轻儿子,然后遍历重儿子保存重儿子的信息,然后再加上轻儿子的信息就行,只要实现一下\(\text{add}\)操作即可
inline void add (int u, int fa, int opt) {
if (opt > 0 && !cnt[col[u]]) sum ++ ;
cnt[col[u]] += opt;
if (opt < 0 && !cnt[col[u]) sum -- ;
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa || v == Son) continue;
add (v, u, opt);
}
}
CF600E
题面
一棵树有\(n\)个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
题解
与上题类似,同样是先一下\(\text{add}\)函数即可
inline void add (int u, int fa, int val) {
cnt[col[u]] += val;
if (cnt[col[u]] > mx) {
mx = cnt[col[u]];
sum = col[u];
} else if (cnt[col[u]] == mx)
sum += col[u];
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == fa || v == Son) continue;
add (v, u, val);
}
}