树上启发式合并

\(Write-by-尹以豪\)

What is “启发式” ?

启发式指的是在树上处理问题的过程中,利用问题已经拥有的信息来引导解决问题,达到减少计算范围,降低问题复杂度的目的。

用人话来说就是基于人类的经验和直观感觉,对一些算法的优化。

举个例子,你要合并并查集,你肯定将小的数组追加的大数组的后面,代码如下。

void merge(int x, int y)
{
  int xx = find(x), yy = find(y);
  if(size[xx] < size[yy]) swap(xx, yy);
  fa[yy] = xx;
  size[xx] += size[yy];
}

这是最常见的启发式优化,这个优化可以提升我们寻找父亲节点的速度,和减少更新迭代的次数,这就是启发式优化。

dsu on tree

先引入一道例题。

\[\boxed{给出一棵 n 个节点以 1 为根的树,节点 u 的颜色为 c_u,\\ 于每个结点 u 询问以 u 为根的子树里一共出现了多少种不同的颜色。} \\ \boxed {n\le 2\times 10^5。} \]

考虑离线算法,\(O(1)\),直接朴素暴力的解法:

枚举一棵子树的根节点\(O(n)\),然后枚举每一个根节点的所有子节点\(O(n)\),时间复杂度\(O(n^2)\)

显然是很没救的。

不难想到优化,我们可以将每一棵子树(从叶子到根),的答案保留下来,这样似乎可以做到 \(O(nm)\)\(m\) 是有多少种颜色。

但是同样会爆炸,带着着一种思想,我们可不可以结合我们之前的重链剖分,先说做法,待会再讲解。

分析这个问题的性质,每一个子树的答案都是根据它的子树和它本身得到。

做法

以下用 \(cnt_i\) 表示 \(i\) 的颜色出现次数,\(ans_u\) 表示 \(u\) 的答案。

  1. 预处理出所有的重链剖分数据。
  2. 遍历它的重儿子保留它对 \(cnt\) 数组的影响
  3. 便利它的轻儿子不保留它对 \(cnt\) 数组的影响
  4. 再此遍历他的轻儿子,统计它对 \(ans_u\) 的贡献。

这个不难理解,不就是省去了几个重儿子的计算吗?能优化这么多?我一开始也费解,接下来,将证明这个算法的时间复杂度。

证明

我们像树链剖分一样定义重边和轻边(连向重儿子的为重边,其余为轻边)。关于重儿子和重边的定义,可以见下图,对于一棵有 \(n\) 个节点的树:

根节点到树上任意节点的轻边数不超过 \(\log n\) 条。我们设根到该节点有 \(x\) 条轻边该节点的子树大小为 \(y\),显然轻边连接的子节点的子树大小小于父亲的一半(若大于一半就不是轻边了),则 \(y<n/2^x\) ,显然 \(n>2^x\) ,所以 \(x<\log n\)

又因为如果一个节点是其父亲的重儿子,则它的子树必定在它的兄弟之中最多,所以任意节点到根的路径上所有重边连接的父节点在计算答案时必定不会遍历到这个节点,所以一个节点的被遍历的次数等于它到根节点路径上的轻边数 \(+1\)(之所以要 \(+1\) 是因为它本身要被遍历到),所以一个节点的被遍历次数 \(=\log n+1\), 总时间复杂度则为 \(O(n(\log n+1))=O(n\log n)\) ,输出答案花费 \(O(m)\)

练习题

CF741D Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths

这题还是说一下吧,这一道题是找一条路径!!!,那怎么办,还是先分析吧,题目要求找出排序后的回文串最长,那也就是说一个串内字母出现次数为奇数的个数最多为 \(1\),那就满足题目要求。因为我们只需要统计奇偶性(也就是\(0/1\)),\(0/1 ?\) ,我们可以用一个二进制数表示一个穿内字母出现次数的奇偶性,设这一个数为 \(dis_i\)。回归这个问题的难点,要找的是 一条路径 ,一条路径每个字母出现次数的奇偶性我们可以这样表示:\(dis_x \oplus dis_{lca(x, y)} \oplus dis_y \oplus dis_{lca(x, y)} = dis_x \oplus dis_y\) ,OK,解释清楚了,\(dis_x \oplus dis_y\) 一共有 \(23\) 种方案,分别是全是 \(0\) ,和一个位置是 \(1\)

点击查看代码
#include<bits/stdc++.h>
#define LL long long
using namespace std;

const LL maxn = 5e5 + 5;
LL n, cnt, idx, head[maxn], dep[maxn], sz[maxn], son[maxn];
LL book[maxn * 10], vis[maxn], ans[maxn], dis[maxn], Id[maxn];
LL L[maxn], R[maxn];

struct Edge
{
    LL to, next, w;
} e[maxn];

void addEdge(LL x, LL y, LL w)
{
    e[++cnt].to = y;
    e[cnt].w = w;
    e[cnt].next = head[x];
    head[x] = cnt;
}

void dfs1(LL u, LL father)
{
    sz[u] = 1;
    dep[u] = dep[father] + 1;
    L[u] = ++idx;
    Id[idx] = u;
    LL maxson = -1;
    
    for(LL i = head[u]; i; i = e[i].next)
	{
        LL v = e[i].to;
        if(v == father) continue;
        dis[v] = dis[u] ^ e[i].w;
        dfs1(v, u);
        sz[u] += sz[v];
        if(sz[v] > maxson)
		{
            son[u] = v;
            maxson = sz[v];
        }
    }
    R[u] = idx;
}

void dfs2(LL u, LL keep) {
    for(LL i = head[u]; i; i = e[i].next)
	{
        LL v = e[i].to;
        if(v == son[u]) continue;
        dfs2(v, 0);
        ans[u] = max(ans[u], ans[v]);
    }
    
    if(son[u])
	{
        dfs2(son[u], 1);
        ans[u] = max(ans[u], ans[son[u]]);
    }
    
    if(book[dis[u]]) 
        ans[u] = max(ans[u], book[dis[u]] - dep[u]);
    
    for(LL i = 0; i <= 21; i++) 
        if(book[dis[u] ^(1LL << i)]) 
            ans[u] = max(ans[u], book[dis[u] ^(1LL << i)] - dep[u]);
    
    book[dis[u]] = max(dep[u], book[dis[u]]);
    
    for(LL i = head[u]; i; i = e[i].next)
	{
        LL v = e[i].to;
        if(v == son[u]) continue;
        
        for(LL j = L[v]; j <= R[v]; j++)
		{
            LL x = Id[j];
            if(book[dis[x]]) 
                ans[u] = max(ans[u], book[dis[x]] + dep[x] - 2 * dep[u]);
            
            for(LL k = 0; k <= 21; k++) 
                if(book[dis[x] ^(1LL << k)]) 
                    ans[u] = max(ans[u], book[dis[x] ^(1LL << k)] + dep[x] - 2 * dep[u]);
        }
        
        for(LL j = L[v]; j <= R[v]; j++) 
            book[dis[Id[j]]] = max(book[dis[Id[j]]], dep[Id[j]]);
    }
    
    if(!keep) 
        for(LL i = L[u]; i <= R[u]; i++) 
            book[dis[Id[i]]] = 0;
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    
    cin >> n;
    for(LL i = 2; i <= n; i++)
	{
        LL x;
        cin >> x;
        char ch;
		cin >> ch;
        addEdge(x, i, 1LL << (ch - 'a'));
    }
    
    dfs1(1, 0);
    dfs2(1, 1);
    
    for(LL i = 1; i <= n; i++) 
        cout << ans[i] << " ";
    
    return 0;
}

A - Lomsat gelral CodeForces - 600E

点击查看代码
#include<bits/stdc++.h>
#define LL long long
using namespace std;

const LL maxn = 1e5 + 5;
LL n, c[maxn], x, y, newpos, sz[maxn], son[maxn], cnt[maxn], maxx, flag, ans[maxn], sum;
vector <int> e[maxn];
void dfs1(int u, int father)
{
	sz[u] = 1;
	for(int i : e[u])
	{
		int v = i;
		if(v == father) continue;
		dfs1(v, u);
		sz[u] += sz[v];
		if(sz[v] > sz[son[u]]) son[u] = v;
	}
}
void del(int u, int father)
{
	cnt[c[u]]--;
	for(int i : e[u])
	{
		int v = i;
		if(v == father) continue;
		del(v, u);
	}
}
void calc(int u, int father)
{
	cnt[c[u]]++;
	if(cnt[c[u]] > maxx)
	{
		maxx = cnt[c[u]];
		sum = c[u];
	}
	else if(cnt[c[u]] == maxx)
		sum += c[u];
	for(int i : e[u])
	{
		int v = i;
		if(v == father or v == flag) continue;
		calc(v, u);
	}
}
void solve(int u, int father, bool keep)
{
	for(int i : e[u])
	{
		int v = i;
		if(v == father or v == son[u]) continue;
		solve(v, u, 0);
	}
	if(son[u]) solve(son[u], u, 1), flag = son[u];
	calc(u, father);
	flag = 0;
	ans[u] = sum;
	if(!keep) del(u, father), sum = maxx = 0;
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	
	cin >> n;
	for(int i = 1; i <= n; i++)
		cin >> c[i];
	for(int i = 1; i < n; i++)
	{
		cin >> x >> y;
		e[x].push_back(y);
		e[y].push_back(x);
	}
	dfs1(1, 1);
	solve(1, 1, 1);
	for(int i = 1; i <= n; i++)
		cout << ans[i] << ' ';
	return 0;
}

Dominant Indices CodeForces - 1009F

点击查看代码
#include<bits/stdc++.h>
#define LL long long
using namespace std;

const LL maxn = 1e6 + 5;
LL n, x, y;
vector<int> e[maxn], f[maxn];
LL fa[maxn], len[maxn], son[maxn], ans[maxn];

void dfs1(int u, int father)
{
    fa[u] = father;
    for(int i : e[u])
    {
        int v = i;
        if(v == father) continue;
        dfs1(v, u);
        if(len[v] >= len[son[u]])
            son[u] = v, len[u] = len[v] + 1;
    }
}

void dfs2(int u)
{
    if(son[u])
    {
        dfs2(son[u]);
        swap(f[u], f[son[u]]);
        f[u].push_back(1);
        ans[u] = ans[son[u]];
        if(f[u][ans[u]] == 1) ans[u] = len[u];
        
        for(int i : e[u])
        {
            int v = i;
            if(v == fa[u] or v == son[u]) continue;
            dfs2(v);
            for(int i = len[v]; i >= 0; i--)
            {
                LL tmp = i + len[u] - len[v] - 1;
                f[u][tmp] += f[v][i];
                if(f[u][tmp] > f[u][ans[u]] or (f[u][tmp] == f[u][ans[u]] and tmp > ans[u]))
                    ans[u] = tmp;
            }
        }
    }
    else
    {
        f[u].push_back(1);
        ans[u] = 0;
    }
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    
    cin >> n;
    for(int i = 1; i < n; i++)
    {
        cin >> x >> y;
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs1(1, 0);
    dfs2(1);
    for(int i = 1; i <= n; i++)
        cout << len[i] - ans[i] << '\n';
    return 0;
}

Blood Cousins CodeForces - 208E

点击查看代码
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;

const int maxn = 1e5 + 5, MAXK = 20;
int n, m, col[maxn];
struct Query
{
    int id, dep;
};
vector<Query> Qry[maxn];
vector<int> Graph[maxn];

int dep[maxn], Size[maxn], Son[maxn], Anc[maxn][MAXK + 5];
void dfs1(int u, int fa)
{
    dep[u] = dep[fa] + 1, Size[u] = 1;
    Anc[u][0] = fa;
    for(int i = 1; i <= MAXK; i++)
		Anc[u][i] = Anc[Anc[u][i - 1]][i - 1];
    for(int i = 0; i < Graph[u].size(); i++)
	{
        int v = Graph[u][i];
        if (v == fa) continue;
        dfs1(v, u);
        Size[u] += Size[v];
        if (Size[Son[u]] < Size[v])
			Son[u] = v;
    }
}

int Cnt[maxn], Ans[maxn], Sonu;
void Count(int u, int fa, int val)
{
    Cnt[dep[u]] += val;
    for(int i = 0; i < Graph[u].size(); i++)
	{
        int v = Graph[u][i];
        if (v == fa || v == Sonu) continue;
        Count(v, u, val);
    }
}
bool vis[maxn];
void dfs2(int u, int fa, bool is_hs)
{
    vis[u] = 1;
    for(int i = 0; i < Graph[u].size(); i++)
	{
        int v = Graph[u][i];
        if(v == fa || v == Son[u]) continue;
        dfs2(v, u, 0);
    }
    if(Son[u])
		dfs2(Son[u], u, 1);
    Sonu = Son[u];
    Count(u, fa, 1);
    Sonu = 0;
    for(int i = 0; i < Qry[u].size(); i++)
		Ans[Qry[u][i].id] = Cnt[dep[u] + Qry[u][i].dep] - 1;
    if(!is_hs) Count(u, fa, -1);
}

int main() {
    scanf("%d", &n);
    for(int i = 1, fa; i <= n; i++)
	{
        scanf("%d", &fa);
        Graph[fa].push_back(i);
    }
    for(int i = 1; i <= n; i++)
        if(!dep[i]) dfs1(i, 0);
    scanf("%d", &m);
    for(int i = 1, rt, k; i <= m; i++)
	{
        scanf("%d %d", &rt, &k);
        for(int i = 0; i <= MAXK; i++)
            if((k >> i) & 1) rt = Anc[rt][i];
        Qry[rt].push_back({i, k});
    }
    for(int i = 1; i <= n; i++)
        if(!vis[i]) dfs2(i, 0, 0);
    for(int i = 1; i <= m; i++)
        printf("%d ", Ans[i]);
    return 0;
}

Promotion Counting P 洛谷 - P3605

这道题用树状数组优化一下。

点击查看代码
#include<bits/stdc++.h>
#define LL long long
using namespace std;

const LL maxn = 1e5 + 5;
LL n, c[maxn], x, sz[maxn], son[maxn], cnt[maxn], ans[maxn];
vector<int> e[maxn];
struct whole
{
    LL pos, num;
} a[maxn];

bool cmp(whole a, whole b)
{
    return a.num < b.num;
}

struct FenwickTree {
    vector<LL> bit;
    int n;
    
    FenwickTree(int size)
	{
        this->n = size;
        bit.assign(n + 2, 0);
    }
    
    void add(int idx, LL delta)
	{
        for(; idx <= n; idx += idx & -idx)
            bit[idx] += delta;
    }
    
    LL query(int idx)
	{
        LL res = 0;
        for(; idx > 0; idx -= idx & -idx)
            res += bit[idx];
        return res;
    }
    
    LL queryGreater(int x)
	{
        return query(n) - query(x);
    }
};

void dfs1(int u, int father) {
    sz[u] = 1;
    for(int i : e[u]) {
        int v = i;
        if(v == father) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}

void calc(int u, int father, int k, FenwickTree &fenwick)
{
    fenwick.add(c[u], k);
    for(int i : e[u]) {
        int v = i;
        if(v == father) continue;
        calc(v, u, k, fenwick);
    }
}

void solve(int u, int father, bool keep, FenwickTree &fenwick)
{
    for(int i : e[u])
	{
        int v = i;
        if(v == father || v == son[u]) continue;
        solve(v, u, false, fenwick);
    }
    
    if(son[u]) solve(son[u], u, true, fenwick);
    
    for(int i : e[u])
	{
        int v = i;
        if(v == father || v == son[u]) continue;
        calc(v, u, 1, fenwick);
    }
    
    fenwick.add(c[u], 1);
    ans[u] = fenwick.queryGreater(c[u]);
    
    if(!keep)
	{
        calc(u, father, -1, fenwick);
    }
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    
    cin >> n;
    for(int i = 1; i <= n; i++)
	{
        cin >> a[i].num;
        a[i].pos = i;
    }
    sort(a + 1, a + 1 + n, cmp);
    for(int i = 1; i <= n; i++)
        c[a[i].pos] = i;
    
    for(int i = 2; i <= n; i++)
	{
        cin >> x;
        e[x].push_back(i);
        e[i].push_back(x);
    }
    
    dfs1(1, 0);
    FenwickTree fenwick(n);
    solve(1, 0, true, fenwick);
    
    for(int i = 1; i <= n; i++)
        cout << ans[i] << '\n';
    
    return 0;
}
posted @ 2025-07-18 15:07  StudentE  阅读(20)  评论(0)    收藏  举报