树上启发式合并
什么是启发式算法?
启发式算法就是根据人们的直觉对一些算法进行优化(这很启发式了),典型的就是并查集的路径压缩
fa[x] == x ? x : find(fa[x])
让高度小的树成为高度较大的树的子树,这个优化可以称为启发式合并算法。

AC自动机里面对fail指针绕圈的优化也是启发式算法。
树上启发式合并
树上启发式合并也叫 dsu on tree (树上并查集很形象吧),它是一种离线的算法。
特征为:
- 没有修改操作。
- 可以通过遍历子树,建立信息统计,然后统计查询的所有答案。
先引入一个概念,现在有若干单元素集合,每次都可以选择两个集合进行合并,合并的花费为元素少的集合的元素个数,问合并所有集合的最小花费是多少。
这个问题的思考角度是从每个元素最多被考虑多少次入手,x 元素被合并,它作为小的姿态消耗花费,合并完之后集合的大小最少增加两倍,如果下一次接着作为小的姿态做出贡献,集合的大小同样最少增加两倍,假设一开始的元素集合有 n 个,那么 x 被考虑的次数不超过 logn 次,因此最小的花费不超过 nlogn次。
这就是树上启发式合并的原理,需要用到树链剖分里面的重链剖分。
树上启发式合并的过程:
- void dfs(u, keep) u表示当前节点,keep表示是否保留贡献。
- 先遍历所有子树的轻儿子,遍历结束后清除对信息的贡献,dfs(轻儿子,0)。
- 再遍历重儿子,遍历结束后保留对信息的贡献,dfs(重儿子,1)。
- 考虑单个节点 u,统计其贡献
- 此时 u 节点时间刻的信息表里面只有重儿子信息,遍历所有轻儿子的子树,把信息再收回来。
- 做出 u 的答案。
- 如果 keep == 0,清除贡献,否则保留贡献。
https://www.luogu.com.cn/problem/U41492
#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>
using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 1e5 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int fa[N], cnt[N], son[N];
vector<int> edges[N];
int n, a[N], m;
int colcnt[N], diff, col[N], ans[N];
void dfs1(int x, int f)
{
fa[x] = f;
cnt[x] = 1;
for(auto& y : edges[x])
{
if(y == f) continue;
dfs1(y, x);
cnt[x] += cnt[y];
if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
}
}
void add(int x)
{
if(colcnt[col[x]]++ == 0) diff++;
for(auto& y : edges[x])
{
if(y == fa[x]) continue;
add(y);
}
}
void del(int x)
{
if(colcnt[col[x]]-- == 1) diff--;
for(auto& y : edges[x])
{
if(y == fa[x]) continue;
del(y);
}
}
void dfs2(int x, int keep)
{
// 遍历轻儿子
for(auto& y : edges[x])
{
if(y == fa[x] || y == son[x]) continue;
dfs2(y, 0); // 轻儿子的贡献不保留
}
if(son[x]) dfs2(son[x], 1); // 重儿子
// 结点自身的贡献
if(colcnt[col[x]]++ == 0) diff++;
// 收回轻儿子的信息
for(auto& y : edges[x])
{
if(y == son[x] || y == fa[x]) continue;
add(y);
}
// 统计结点 u 的答案
ans[x] = diff;
// 如果是轻儿子清空信息
if(keep == 0) del(x);
}
void solve()
{
cin >> n;
for(int i = 1; i <= n - 1; i++)
{
int x, y; cin >> x >> y;
edges[x].push_back(y);
edges[y].push_back(x);
}
for(int i = 1; i <= n; i++) cin >> col[i];
// cerr << 1 << endl;
dfs1(1, 0);
dfs2(1, 0);
cin >> m;
while(m--)
{
int x; cin >> x;
cout << ans[x] << endl;
}
}
int main()
{
cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
int T = 1;
// cin >> T;
while(T--)
{
solve();
}
return 0;
}
时间复杂度为:O(nlogn)
证明:
- 如果不考虑什么清空轻儿子,收回轻儿子的贡献,那么时间复杂度显然为O(n)。
- 这两个操作看似暴力实则不然。
- 还记得上面说到的集合合并操作吗?
- 对于清空轻儿子,考虑每个节点最多被清空的次数,如果该节点被清空一次,那么他一定是作为某个子树根节点的轻儿子所在的子树,由于重儿子的大小一定大于等于该轻儿子子树的大小,该节点被清空一次意味着子树的大小一定倍增,因此每个节点最多被清空 logn次。
- 对于重新收回轻儿子的贡献操作,同样考虑每个节点被收回的次数,由于重儿子的节点信息保留不会被重新收回,上述节点被收回一次后子树还是倍增,因此收回次数同样不超过logn
- 总结到,时间复杂度为O(nlogn)。
树上启发式合并可以离线处理有根树的k级孩子相关的问题,比如下面两道题都,一道求 k 级孩子中不同的个数,一道可以转化成 k 级孩子的个数,没有强制在线要求的树形dp,只要涉及到合并,最后用树上启发式合并,时间要优于树上莫队。
https://www.luogu.com.cn/problem/CF208E
#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>
using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 1e5 + 10, M = 25;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int n;
int dep[N], cnt[N], fa[N], son[N];
vector<int> edges[N];
vector<PII> op[N];
int ans[N];
int st[N][M];
// 层次为 i 的孩子数量
unordered_map<int, int> mp;
void dfs1(int x, int f)
{
cnt[x] = 1;
st[x][0] = f;
for(int i = 1; i <= 20; i++)
st[x][i] = st[st[x][i - 1]][i - 1];
dep[x] = dep[f] + 1;
for(auto& y : edges[x])
{
if(y == fa[x]) continue;
dfs1(y, x);
cnt[x] += cnt[y];
if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
}
}
void add(int x)
{
mp[dep[x]]++;
for(auto& y : edges[x])
{
if(y != fa[x])
add(y);
}
}
void del(int x)
{
mp[dep[x]] = 0;
for(auto& y : edges[x])
{
if(y != fa[x]) del(y);
}
}
void dfs2(int x, int keep)
{
for(auto& y : edges[x])
{
if(y == son[x] || y == fa[x]) continue;
dfs2(y, 0);
}
if(son[x]) dfs2(son[x], 1);
mp[dep[x]]++;
for(auto& y : edges[x])
{
if(y == son[x] || y == fa[x]) continue;
add(y);
}
for(auto& [id, k] : op[x])
{
ans[id] = mp[dep[x] + k] - 1;
}
if(keep == 0) del(x);
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++)
{
int x; cin >> x;
fa[i] = x;
edges[x].push_back(i);
}
int m; cin >> m;
for(int x = 1; x <= n; x++)
{
if(fa[x] == 0) dfs1(x, 0);
}
for(int i = 1; i <= m; i++)
{
int v, k; cin >> v >> k;
for(int j = 20; j >= 0; j--)
{
if((k >> j) & 1) v = st[v][j];
}
// cout << v << ' ' << k << endl;
op[v].push_back({i, k});
}
for(int x = 1; x <= n; x++)
{
if(fa[x] == 0) dfs2(x, 0);
}
for(int i = 1; i <= m; i++) cout << ans[i] << " ";
}
int main()
{
cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
int T = 1;
// cin >> T;
while(T--)
{
solve();
}
return 0;
}
https://www.luogu.com.cn/problem/CF246E
#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>
using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 2e5 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
unordered_map<string, int> mp;
int id, ans[N];
vector<int> edges[N];
unordered_map<int, set<int>> depcnt; // 层数为 i 的字符串列表
int dep[N], fa[N], son[N], cnt[N];
int n;
string a[N];
vector<PII> op[N];
void dfs1(int x, int f)
{
cnt[x] = 1;
dep[x] = dep[f] + 1;
for(auto& y : edges[x])
{
dfs1(y, x);
cnt[x] += cnt[y];
if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
}
}
void add(int x)
{
depcnt[dep[x]].insert(mp[a[x]]);
for(auto& y : edges[x])
{
add(y);
}
}
void del(int x)
{
depcnt[dep[x]].clear();
for(auto& y : edges[x])
{
del(y);
}
}
void dfs2(int x, int keep)
{
for(auto& y : edges[x])
{
if(y == son[x]) continue;
dfs2(y, 0);
}
if(son[x]) dfs2(son[x], 1);
depcnt[dep[x]].insert(mp[a[x]]);
for(auto& y : edges[x])
{
if(y == son[x]) continue;
add(y);
}
for(auto& [id, k] : op[x])
{
ans[id] = depcnt[dep[x] + k].size();
}
if(keep == 0) del(x);
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++)
{
string s; cin >> s;
int r; cin >> r;
if(!mp.count(s)) mp[s] = ++id;
fa[i] = r;
a[i] = s;
edges[r].push_back(i);
}
int m; cin >> m;
for(int i = 1; i <= m; i++)
{
int u, k; cin >> u >> k;
op[u].push_back({i, k});
}
for(int x = 1; x <= n; x++)
{
if(fa[x] == 0) dfs1(x, 0);
}
for(int x = 1; x <= n; x++)
{
if(fa[x] == 0) dfs2(x, 0);
}
for(int i = 1; i <= m; i++) cout << ans[i] << endl;
}
int main()
{
cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
int T = 1;
// cin >> T;
while(T--)
{
solve();
}
return 0;
}
由于树上启发式合并的算法思想很简单,所以篮球杯可能也会考一点这东西,没活硬整。
#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>
using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 2e5 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int fa[N], cnt[N], son[N];
// 每种颜色出现的次数,出现次数为 i 的颜色种类数
int colcnt[N], cntnum[N];
vector<int> edges[N];
int n, a[N], ans;
void dfs1(int x, int f)
{
fa[x] = f;
cnt[x] = 1;
for(auto& y : edges[x])
{
dfs1(y, x);
cnt[x] += cnt[y];
if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
}
}
void add(int x)
{
cntnum[colcnt[a[x]]]--;
cntnum[++colcnt[a[x]]]++;
for(auto& y : edges[x])
{
add(y);
}
}
void del(int x)
{
cntnum[colcnt[a[x]]]--;
cntnum[--colcnt[a[x]]]++;
for(auto& y : edges[x])
{
del(y);
}
}
void dfs2(int x, int keep)
{
// 遍历轻儿子
for(auto& y : edges[x])
{
if(y == son[x]) continue;
dfs2(y, 0); // 贡献清空
}
// 重儿子
if(son[x]) dfs2(son[x], 1); // 贡献保留
cntnum[colcnt[a[x]]]--;
cntnum[++colcnt[a[x]]]++;
// 收回轻儿子的贡献
for(auto& y : edges[x])
{
if(y == son[x]) continue;
add(y);
}
// 统计 x 结点的答案
if(colcnt[a[x]] * cntnum[colcnt[a[x]]] == cnt[x]) ans++;
// 如果是轻儿子清空贡献
if(keep == 0) del(x);
}
void solve()
{
cin >> n;
for(int i = 1; i <= n; i++)
{
cin >> a[i];
int x; cin >> x;
edges[x].push_back(i);
}
dfs1(1, 0);
dfs2(1, 0);
cout << ans << endl;
}
int main()
{
cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
int T = 1;
// cin >> T;
while(T--)
{
solve();
}
return 0;
}
另外树上启发式合并也一些挺难的题,可能会求一些子树的最长回文路径的长度什么。
https://codeforces.com/problemset/status?my=on
- 首先把边权下方到点权,点权记录的是从 a 开始第 i 个字符的奇偶性,因此路径上字符组成的路径的奇偶性信息就清楚了,异或出来的答案中,如果为 0 或者只有一位为 1 就说明可以组成回文字符串,记录答案。
- 树上启发式合并的过程中,根节点的答案会收上所有子树的答案,它的maxdep中保留了重儿子的信息,去收轻儿子点权的时候枚举所有的可能,然后再把点权收上来。
时间复杂度的瓶颈在于树上启发式合并,为nlogn但是有一个22的常数。
浙公网安备 33010602011771号