树上启发式合并

什么是启发式算法?

启发式算法就是根据人们的直觉对一些算法进行优化(这很启发式了),典型的就是并查集的路径压缩

fa[x] == x ? x : find(fa[x])

让高度小的树成为高度较大的树的子树,这个优化可以称为启发式合并算法。

image

AC自动机里面对fail指针绕圈的优化也是启发式算法。

树上启发式合并

树上启发式合并也叫 dsu on tree (树上并查集很形象吧),它是一种离线的算法。

特征为:

  • 没有修改操作。
  • 可以通过遍历子树,建立信息统计,然后统计查询的所有答案。

先引入一个概念,现在有若干单元素集合,每次都可以选择两个集合进行合并,合并的花费为元素少的集合的元素个数,问合并所有集合的最小花费是多少。

这个问题的思考角度是从每个元素最多被考虑多少次入手,x 元素被合并,它作为小的姿态消耗花费,合并完之后集合的大小最少增加两倍,如果下一次接着作为小的姿态做出贡献,集合的大小同样最少增加两倍,假设一开始的元素集合有 n 个,那么 x 被考虑的次数不超过 logn 次,因此最小的花费不超过 nlogn次。

这就是树上启发式合并的原理,需要用到树链剖分里面的重链剖分。

树上启发式合并的过程:

  • void dfs(u, keep) u表示当前节点,keep表示是否保留贡献。
  • 先遍历所有子树的轻儿子,遍历结束后清除对信息的贡献,dfs(轻儿子,0)。
  • 再遍历重儿子,遍历结束后保留对信息的贡献,dfs(重儿子,1)。
  • 考虑单个节点 u,统计其贡献
  • 此时 u 节点时间刻的信息表里面只有重儿子信息,遍历所有轻儿子的子树,把信息再收回来。
  • 做出 u 的答案。
  • 如果 keep == 0,清除贡献,否则保留贡献。

https://www.luogu.com.cn/problem/U41492

#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>


using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 1e5 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int fa[N], cnt[N], son[N];
vector<int> edges[N];
int n, a[N], m;
int colcnt[N], diff, col[N], ans[N];

void dfs1(int x, int f)
{
    fa[x] = f;
    cnt[x] = 1;
    for(auto& y : edges[x])
    {
        if(y == f) continue;
        dfs1(y, x);
        cnt[x] += cnt[y];
        if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
    }
}

void add(int x)
{
    if(colcnt[col[x]]++ == 0) diff++;
    for(auto& y : edges[x])
    {
        if(y == fa[x]) continue;
        add(y);
    }
}

void del(int x)
{
    if(colcnt[col[x]]-- == 1) diff--;
    for(auto& y : edges[x])
    {
        if(y == fa[x]) continue;
        del(y);
    }
}

void dfs2(int x, int keep)
{
    // 遍历轻儿子
    for(auto& y : edges[x])
    {
        if(y == fa[x] || y == son[x]) continue;
        dfs2(y, 0); // 轻儿子的贡献不保留
    }
    if(son[x]) dfs2(son[x], 1); // 重儿子

    // 结点自身的贡献
    if(colcnt[col[x]]++ == 0) diff++;

    // 收回轻儿子的信息
    for(auto& y : edges[x])
    {
        if(y == son[x] || y == fa[x]) continue;
        add(y);
    }

    // 统计结点 u 的答案
    ans[x] = diff;

    // 如果是轻儿子清空信息
    if(keep == 0) del(x);
}

void solve()
{
    cin >> n;
    for(int i = 1; i <= n - 1; i++)
    {
        int x, y; cin >> x >> y;
        edges[x].push_back(y);
        edges[y].push_back(x);
    }    
    for(int i = 1; i <= n; i++) cin >> col[i];
    // cerr << 1 << endl;
    dfs1(1, 0);
    dfs2(1, 0);
    cin >> m;
    while(m--)
    {
        int x; cin >> x;
        cout << ans[x] << endl;
    }
}

int main()
{
    cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
    int T = 1;
    // cin >> T;
    while(T--)
    {
        solve();
    }
    return 0;
}

时间复杂度为:O(nlogn)

证明:

  • 如果不考虑什么清空轻儿子,收回轻儿子的贡献,那么时间复杂度显然为O(n)。
  • 这两个操作看似暴力实则不然。
  • 还记得上面说到的集合合并操作吗?
  • 对于清空轻儿子,考虑每个节点最多被清空的次数,如果该节点被清空一次,那么他一定是作为某个子树根节点的轻儿子所在的子树,由于重儿子的大小一定大于等于该轻儿子子树的大小,该节点被清空一次意味着子树的大小一定倍增,因此每个节点最多被清空 logn次。
  • 对于重新收回轻儿子的贡献操作,同样考虑每个节点被收回的次数,由于重儿子的节点信息保留不会被重新收回,上述节点被收回一次后子树还是倍增,因此收回次数同样不超过logn
  • 总结到,时间复杂度为O(nlogn)。

树上启发式合并可以离线处理有根树的k级孩子相关的问题,比如下面两道题都,一道求 k 级孩子中不同的个数,一道可以转化成 k 级孩子的个数,没有强制在线要求的树形dp,只要涉及到合并,最后用树上启发式合并,时间要优于树上莫队。

https://www.luogu.com.cn/problem/CF208E

#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>


using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 1e5 + 10, M = 25;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int n;
int dep[N], cnt[N], fa[N], son[N];
vector<int> edges[N];
vector<PII> op[N];
int ans[N];
int st[N][M];
// 层次为 i 的孩子数量
unordered_map<int, int> mp; 
void dfs1(int x, int f)
{
    cnt[x] = 1;
    st[x][0] = f;
    for(int i = 1; i <= 20; i++) 
        st[x][i] = st[st[x][i - 1]][i - 1];
    dep[x] = dep[f] + 1;
    for(auto& y : edges[x])
    {
        if(y == fa[x]) continue;
        dfs1(y, x);
        cnt[x] += cnt[y];
        if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
    }
}

void add(int x)
{
    mp[dep[x]]++;
    for(auto& y : edges[x])
    {
        if(y != fa[x])
        add(y);
    }
}

void del(int x)
{
    mp[dep[x]] = 0;
    for(auto& y : edges[x])
    {
        if(y != fa[x]) del(y);
    }
}

void dfs2(int x, int keep)
{
    for(auto& y : edges[x])
    {
        if(y == son[x] || y == fa[x]) continue;
        dfs2(y, 0);
    }
    if(son[x]) dfs2(son[x], 1);
    mp[dep[x]]++;
    for(auto& y : edges[x])
    {
        if(y == son[x] || y == fa[x]) continue;
        add(y);
    }
    for(auto& [id, k] : op[x])
    {
        ans[id] = mp[dep[x] + k] - 1;
    }
    if(keep == 0) del(x);
}

void solve()
{
    cin >> n;
    for(int i = 1; i <= n; i++)
    {
        int x; cin >> x;
        fa[i] = x;
        edges[x].push_back(i);
    }
    int m; cin >> m;
    for(int x = 1; x <= n; x++)
    {
        if(fa[x] == 0) dfs1(x, 0);
    }
    for(int i = 1; i <= m; i++) 
    {
        int v, k; cin >> v >> k;
        for(int j = 20; j >= 0; j--)
        {
            if((k >> j) & 1) v = st[v][j];
        }
        // cout << v << ' ' << k << endl;
        op[v].push_back({i, k});
    }
    for(int x = 1; x <= n; x++)
    {
        if(fa[x] == 0) dfs2(x, 0);
    }
    for(int i = 1; i <= m; i++) cout << ans[i] << " ";
} 

int main()
{
    cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
    int T = 1;
    // cin >> T;
    while(T--)
    {
        solve();
    }
    return 0;
}

https://www.luogu.com.cn/problem/CF246E

#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>


using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 2e5 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
unordered_map<string, int> mp;
int id, ans[N];
vector<int> edges[N];
unordered_map<int, set<int>> depcnt; // 层数为 i 的字符串列表
int dep[N], fa[N], son[N], cnt[N];
int n;
string a[N];
vector<PII> op[N];

void dfs1(int x, int f)
{
    cnt[x] = 1;
    dep[x] = dep[f] + 1;
    for(auto& y : edges[x])
    {
        dfs1(y, x);
        cnt[x] += cnt[y];
        if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
    }
}

void add(int x)
{
    depcnt[dep[x]].insert(mp[a[x]]);
    for(auto& y : edges[x])
    {
        add(y);
    }
}

void del(int x)
{
    depcnt[dep[x]].clear();
    for(auto& y : edges[x])
    {
        del(y);
    }
}

void dfs2(int x, int keep)
{
    for(auto& y : edges[x])
    {
        if(y == son[x]) continue;
        dfs2(y, 0);
    }
    if(son[x]) dfs2(son[x], 1);

    depcnt[dep[x]].insert(mp[a[x]]);
    for(auto& y : edges[x])
    {
        if(y == son[x]) continue;
        add(y);
    }
    for(auto& [id, k] : op[x])
    {
        ans[id] = depcnt[dep[x] + k].size();
    }
    if(keep == 0) del(x);
}

void solve()
{
    cin >> n;
    for(int i = 1; i <= n; i++) 
    {
        string s; cin >> s;
        int r; cin >> r;
        if(!mp.count(s)) mp[s] = ++id;
        fa[i] = r;
        a[i] = s;
        edges[r].push_back(i);
    }    
    int m; cin >> m;
    for(int i = 1; i <= m; i++)
    {
        int u, k; cin >> u >> k;
        op[u].push_back({i, k});
    }
    for(int x = 1; x <= n; x++)
    {
        if(fa[x] == 0) dfs1(x, 0);
    }
    for(int x = 1; x <= n; x++)
    {
        if(fa[x] == 0) dfs2(x, 0);
    }
    for(int i = 1; i <= m; i++) cout << ans[i] << endl;
}

int main()
{
    cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
    int T = 1;
    // cin >> T;
    while(T--)
    {
        solve();
    }
    return 0;
}

由于树上启发式合并的算法思想很简单,所以篮球杯可能也会考一点这东西,没活硬整。

#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <unordered_set>
#include <set>
#include <algorithm>
#include <cmath>
#include <string>
#include <cstring>
#include <queue>
#include <cstring>


using namespace std;
#define endl '\n'
typedef long long LL;
typedef pair<int, int> PII;
#define lc p << 1
#define rc p << 1 | 1
#define lowbit(x) (x & -x)
const int N = 2e5 + 10;
const LL MOD = 1e9 + 7;
const double ln2 = log(2);
const double rec_ln2 = 1.0 / ln2;
int fa[N], cnt[N], son[N];
// 每种颜色出现的次数,出现次数为 i 的颜色种类数
int colcnt[N], cntnum[N]; 
vector<int> edges[N];
int n, a[N], ans;

void dfs1(int x, int f)
{
    fa[x] = f;
    cnt[x] = 1;
    for(auto& y : edges[x])
    {
        dfs1(y, x);
        cnt[x] += cnt[y];
        if(!son[x] || cnt[son[x]] < cnt[y]) son[x] = y;
    }
}


void add(int x)
{
    cntnum[colcnt[a[x]]]--;
    cntnum[++colcnt[a[x]]]++;
    for(auto& y : edges[x])
    {
        add(y);
    }
}

void del(int x)
{
    cntnum[colcnt[a[x]]]--;
    cntnum[--colcnt[a[x]]]++;
    for(auto& y : edges[x])
    {
        del(y);
    }
}

void dfs2(int x, int keep)
{
    // 遍历轻儿子
    for(auto& y : edges[x])
    {
        if(y == son[x]) continue;
        dfs2(y, 0); // 贡献清空
    }

    // 重儿子
    if(son[x]) dfs2(son[x], 1); // 贡献保留

    cntnum[colcnt[a[x]]]--;
    cntnum[++colcnt[a[x]]]++;

    // 收回轻儿子的贡献
    for(auto& y : edges[x])
    {
        if(y == son[x]) continue;
        add(y);
    }

    // 统计 x 结点的答案
    if(colcnt[a[x]] * cntnum[colcnt[a[x]]] == cnt[x]) ans++;

    // 如果是轻儿子清空贡献
    if(keep == 0) del(x);
}

void solve()
{
    cin >> n;
    for(int i = 1; i <= n; i++)
    {
        cin >> a[i];
        int x; cin >> x;
        edges[x].push_back(i);
    }    
    dfs1(1, 0);
    dfs2(1, 0);
    cout << ans << endl;
}

int main()
{
    cin.tie(0), cout.tie(0), ios::sync_with_stdio(false);
    int T = 1;
    // cin >> T;
    while(T--)
    {
        solve();
    }
    return 0;
}

另外树上启发式合并也一些挺难的题,可能会求一些子树的最长回文路径的长度什么。

https://codeforces.com/problemset/status?my=on

  • 首先把边权下方到点权,点权记录的是从 a 开始第 i 个字符的奇偶性,因此路径上字符组成的路径的奇偶性信息就清楚了,异或出来的答案中,如果为 0 或者只有一位为 1 就说明可以组成回文字符串,记录答案。
  • 树上启发式合并的过程中,根节点的答案会收上所有子树的答案,它的maxdep中保留了重儿子的信息,去收轻儿子点权的时候枚举所有的可能,然后再把点权收上来。

时间复杂度的瓶颈在于树上启发式合并,为nlogn但是有一个22的常数。


posted on 2026-06-29 10:47  我不爱吃汉堡  阅读(3)  评论(0)    收藏  举报

导航