20251208 - 树上启发式合并
树上启发式合并总结
概念
树上启发式合并,俗称 DSU on tree,是一种离线处理子树内的一种方法,时间复杂度大约为 \(O(n\log_2n)\)。
一般用于树上数颜色问题。
DSU 其实是并查集的英文缩写,那么这个是树上并查集吗?不对,这一个算法是基于并查集的启发式合并,也就是把小的集合并到大的集合里面,所以就有了树上启发式合并。
可以发现,最终答案有一部分是要保留的,那么保留谁最好呢?答案是子树最大的子节点。
通俗一点来说,就是重子树和轻子树分开算。
和分块一样,是一种优雅的暴力
重子树:子树最大的子节点
轻子树:除了重子树以外的树
算法步骤
首先先预处理出一些东西(是否为重子树,子树和等),然后再暴力 dfs,如果遇到重子树,就保留结果,否则清空并重开。
代码(写的时候啥都忘了,所以写了亿些注释,不知道是不是一个不好的习惯):
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 1e5 + 7;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int n, c[N], siz[N], a[N], son[N];
ll sum, ans[N], mx = 0;
// a[c[i]] : 颜色 c[i] 出现的次数
// son[i] : 是否为重节点
vector<int> edges[N];
inline void solve(int u, int fa) {
// 求出子树大小 & 是否为重节点
int mx = 0, root = 0;
siz[u] = 1;
for (auto v : edges[u]) {
if (v == fa) continue;
solve(v, u);
siz[u] += siz[v];
if (siz[v] > mx) {
mx = siz[v];
root = v;
}
}
if (root)
son[root] = 1;
}
inline void clear(int u, int fa) {
// 清空
--a[c[u]];
for (auto v : edges[u]) {
if (v == fa) continue;
clear(v, u);
}
}
inline void DFS(int u, int fa, int root) {
++a[c[u]]; // 增加个数
if (a[c[u]] > mx) {
sum = c[u];
mx = a[c[u]];
}else if (a[c[u]] == mx) {
sum += c[u]; // 编号和
}
for (auto v : edges[u]) {
if (v == fa || v == root) continue;
DFS(v, u, root);
}
}
inline void dfs(int u, int fa) {
// 轻节点
int root = 0;
for (auto v : edges[u]) {
if (v == fa) continue;
if (!son[v]) {
dfs(v, u);
clear(v, u);
sum = 0, mx = 0;
}else {
root = v;
}
}
if (root) dfs(root, u);
DFS(u, fa, root); // 不能遍历重儿子
ans[u] = sum;
}
int main() {
n = read();
for (int i = 1; i <= n; i++)
c[i] = read();
for (int i = 1; i < n; i++) {
int x = read(), y = read();
edges[x].pb(y);
edges[y].pb(x);
}
solve(1, -1);
dfs(1, -1);
for (int i = 1; i <= n; i++)
printf("%lld ", ans[i]);
putchar('\n');
return 0;
}
例题:C - 颜色平衡树
思路一:
可以暴力枚举每一个点。
思路二
看到树上数颜色,想到树上启发式合并。。。
其他的就是板子了,问题就在如何判断。。。
看题解都说开两个桶,我怎么和他们的做法不一样呢?
可以考虑暴力,时间复杂度:\(O(n^2)\)。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 1e5 + 7;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int n, c[N], siz[N], a[N], son[N];
ll sum, ans[N], mx = 0;
// a[c[i]] : 颜色 c[i] 出现的次数
// son[i] : 是否为重节点
vector<int> edges[N];
inline void solve(int u, int fa) {
// 求出子树大小 & 是否为重节点
int mx = 0, root = 0;
siz[u] = 1;
for (auto v : edges[u]) {
if (v == fa) continue;
solve(v, u);
siz[u] += siz[v];
if (siz[v] > mx) {
mx = siz[v];
root = v;
}
}
if (root)
son[root] = 1;
}
inline void clear(int u, int fa) {
// 清空
--a[c[u]];
for (auto v : edges[u]) {
if (v == fa) continue;
clear(v, u);
}
}
inline void DFS(int u, int fa, int root) {
++a[c[u]]; // 增加个数
int lst = 0;
bool ok = false;
for (int i = 1; i <= n && !ok; i++) {
if (a[i] != 0 && lst != 0 && lst != a[i])
ok = true;
else if (a[i] != 0)
lst = a[i];
}
sum = !ok;
for (auto v : edges[u]) {
if (v == fa || v == root) continue;
DFS(v, u, root);
}
}
inline void dfs(int u, int fa) {
// 轻节点
int root = 0;
for (auto v : edges[u]) {
if (v == fa) continue;
if (!son[v]) {
dfs(v, u);
clear(v, u);
sum = 0, mx = 0;
}else {
root = v;
}
}
if (root) dfs(root, u);
DFS(u, fa, root); // 不能遍历重儿子
ans[u] = sum;
}
int main() {
n = read();
for (int i = 1; i <= n; i++) {
c[i] = read();
int x = read();
if (i == 1) continue;
edges[i].pb(x);
edges[x].pb(i);
}
solve(1, -1);
dfs(1, -1);
// for (int i = 1; i <= n; i++)
// printf("%lld ", ans[i]);
// putchar('\n');
ll end_ans = 0;
for (int i = 1; i <= n; i++)
end_ans += ans[i];
printf("%lld\n", end_ans);
return 0;
}
这份代码慢在判断上,有不有在 \(\log_2 n\) 的时间复杂度内统计有多少个不同的数?
map / unordered_map & set/ unordered_set,但是我们要统计个数,所以只能用map / unordered_map。
接下来就好做了!
维护一个 map / unordered_map,添加点的时候就删除原来的,加上当前的就好了,删除同理。
警示后人:如果最终的个数为 0 时,请不要删除,不然会挂掉。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define db double
#define sz(x) ((int)x.size())
#define inf (1 << 30)
#define pb push_back
typedef pair<int, int> PII;
const int N = 2e5 + 7;
const int P = 998244353;
int read() {
int x = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9')) {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}
return x * f;
}
int n, c[N], siz[N], a[N], son[N];
ll sum, ans[N], mx = 0;
unordered_map<int, ll> mp;
// a[c[i]] : 颜色 c[i] 出现的次数
// son[i] : 是否为重节点
vector<int> edges[N];
inline void solve(int u, int fa) {
// 求出子树大小 & 是否为重节点
int mx = 0, root = 0;
siz[u] = 1;
for (auto v : edges[u]) {
if (v == fa) continue;
solve(v, u);
siz[u] += siz[v];
if (siz[v] > mx) {
mx = siz[v];
root = v;
}
}
if (root)
son[root] = 1;
}
void del(int u){
if (mp[u] != 1)
mp[u]--;
else
mp.erase(u);
}
inline void clear(int u, int fa) {
// 清空
if (a[c[u]] > 0)
del(a[c[u]]); // 删除
--a[c[u]];
if (a[c[u]] > 0)
mp[a[c[u]]]++; // 添加
for (auto v : edges[u]) {
if (v == fa) continue;
clear(v, u);
}
}
inline void DFS(int u, int fa, int root) {
if (a[c[u]] > 0)
del(a[c[u]]);
++a[c[u]]; // 增加个数
if (a[c[u]] > 0)
mp[a[c[u]]]++;
for (auto v : edges[u]) {
if (v == fa || v == root) continue;
DFS(v, u, root);
}
}
inline void dfs(int u, int fa) {
// 轻节点
int root = 0;
for (auto v : edges[u]) {
if (v == fa) continue;
if (!son[v]) {
dfs(v, u);
clear(v, u);
sum = 0, mx = 0;
}else {
root = v;
}
}
if (root) dfs(root, u);
DFS(u, fa, root); // 不能遍历重儿子
ans[u] = mp.size() == 1;
}
int main() {
n = read();
for (int i = 1; i <= n; i++) {
c[i] = read();
int x = read();
if (i == 1) continue;
edges[i].pb(x);
edges[x].pb(i);
}
solve(1, -1);
dfs(1, -1);
// for (int i = 1; i <= n; i++)
// printf("%lld ", ans[i]);
// putchar('\n');
ll end_ans = 0;
for (int i = 1; i <= n; i++)
end_ans += ans[i];
printf("%lld\n", end_ans);
return 0;
}
后记
如果遇到树上数颜色问题,可以考虑树上启发式合并。

浙公网安备 33010602011771号