dsu on Tree

这是利用了并查集的启发式合并设计的一种算法。
常见模板如下(后面会介绍对其dfs优化)

void solve(int x) {
	找到重儿子 big
	for (y is son of x)
		if (y != big) solve(y), 删除y这棵子树
	solve(big);// 不删除big对应的子树
	for (y is son of x)
		if (y != big) 加入y这棵子树
	加入x节点
	回答在x节点上的询问
}

以下翻译自CF大佬的博客

什么是 dsu on tree?

使用 dsu 可以帮助我们回答以下类型的查询:

如何在 \(O(nlog\;n)\) 时间内,查询以每一个节点为根的子树内,所有满足某一性质的点的个数?

例如:

给定一棵树,每一个节点都有颜色。查询节点 v 的子树中有多少个点的颜色是 c

我们看看如何解决这个问题。

首先,我们需要树上每个节点的子树大小,可以用 dfs 轻松搞定:

int sz[N];
void getsz(int v, int p){
    sz[v] = 1;  // v 也属于以自己为根的子树
    for (auto u : g[v])
        if (u != p){
            getsz(u, v);
            sz[v] += sz[u]; // 将子树 u 的大小加入 v 中
        }
}

现在 sz[v] 即为以 v 为根的子树大小。

先来看一下暴力的写法(\(O(n^2)\)):

int cnt[N];
void add(int v, int p, int x){
    cnt[ col[v] ] += x;
    for (auto u: g[v])
        if (u != p)
            add(u, v, x)
}
void dfs(int v, int p){
    add(v, p, 1); //求出以 v 为根的子树中包含的颜色为 c 的节点个数,存在 cnt[c] 中
    add(v, p, -1);//回溯,清掉 cnt 的值,准备计算别的子树
    for (auto u : g[v])
        if (u != p)
            dfs(u, v);
}

现在想想,如何改进?这里介绍如下几种代码技巧:

  1. 编写简单,但需要 \(O(nlog^2\;n)\)
map<int, int> *cnt[N];
void dfs(int v, int p){
    int mx = -1, bigChild = -1;
    for (auto u : g[v])
       if (u != p){
           dfs(u, v);
           if (sz[u] > mx)
               mx = sz[u], bigChild = u;
       }
    if (bigChild != -1)
        cnt[v] = cnt[bigChild];
    else
        cnt[v] = new map<int, int> ();
    (*cnt[v])[ col[v] ] ++;
    for (auto u : g[v])
       if (u != p && u != bigChild){
           for (auto x : *cnt[u])
               (*cnt[v])[x.first] += x.second;
       }
    // 现在 (*cnt[v])[c] 就是以 v 为根的子树中包含的颜色为 c 的节点个数
}
  1. 编写简单且是 \(O(nlog\;n)\)
vector<int> *vec[N];
int cnt[N];
void dfs(int v, int p, bool keep){
    int mx = -1, bigChild = -1;
    for (auto u : g[v])
       if (u != p && sz[u] > mx)
           mx = sz[u], bigChild = u;
    for (auto u : g[v])
       if (u != p && u != bigChild)
           dfs(u, v, 0);
    if (bigChild != -1)
        dfs(bigChild, v, 1), vec[v] = vec[bigChild];
    else
        vec[v] = new vector<int> ();
    vec[v]->push_back(v);
    cnt[ col[v] ]++;
    for (auto u : g[v])
       if (u != p && u != bigChild)
           for (auto x : *vec[u]){
               cnt[ col[x] ]++;
               vec[v] -> push_back(x);
           }
    // 现在 cnt[c] 就是以 v 为根的子树中包含的颜色为 c 的节点个数
    // 注意:到这一步 *vec[v] 包含了节点 v 的所有子树
    if (keep == 0)
        for (auto u : *vec[v])
            cnt[ col[u] ]--;
}
  1. 重链剖分版本的 \(O(nlog\;n)\).
int cnt[N];
bool big[N];
void add(int v, int p, int x){
    cnt[ col[v] ] += x;
    for (auto u: g[v])
        if (u != p && !big[u])
            add(u, v, x)
}
void dfs(int v, int p, bool keep){
    int mx = -1, bigChild = -1;
    for (auto u : g[v])
       if (u != p && sz[u] > mx)
          mx = sz[u], bigChild = u;
    for (auto u : g[v])
        if (u != p && u != bigChild)
            dfs(u, v, 0);  // 在轻儿子的子树上 dfs,结束后清空 cnt 内容
    if (bigChild != -1)
        dfs(bigChild, v, 1), big[bigChild] = 1;  // 重儿子的 keep 标记为 1,dfs 后不清空 cnt
    add(v, p, 1);
    // 现在 cnt[c] 就是以 v 为根的子树中包含的颜色为 c 的节点个数
    if (bigChild != -1)
        big[bigChild] = 0;
    if (keep == 0)
        add(v, p, -1);
}
  1. 我的发明,也是 \(O(nlog\;n)\)
    这种 Dsu on tree 的实现方法是我(CF上的一位大佬)新发明的。它的代码更容易实现。令 st[v]ft[v] 分别表示节点 v 的 dfs 序的起始时间和结束时间,ver[time] 表示起始时间为 time 的点的编号。
int cnt[N];
void dfs(int v, int p, bool keep){
    int mx = -1, bigChild = -1;
    for (auto u : g[v])
       if(u != p && sz[u] > mx)
          mx = sz[u], bigChild = u;
    for (auto u : g[v])
        if (u != p && u != bigChild)
            dfs(u, v, 0);  // 在轻儿子的子树上 dfs,结束后清空 cnt 内容
    if (bigChild != -1)
        dfs(bigChild, v, 1);  // 重儿子的 keep 标记为 1,dfs 后不清空 cnt
    for (auto u : g[v])
	  if (u != p && u != bigChild)
	    for (int p = st[u]; p < ft[u]; p++)
		  cnt[ col[ ver[p] ] ]++;
    cnt[ col[v] ]++;
    // 现在 cnt[c] 就是以 v 为根的子树中包含的颜色为 c 的节点个数
    if (keep == 0)
        for (int p = st[v]; p < ft[v]; p++)
	      cnt[ col[ ver[p] ] ]--;
}

但是,为什么是 \(O(nlog\;n)\)?你知道 dsu 是 \(O(qlog\;n)\) 的,这段代码使用了相同的方法:将较小的合并到较大的,即启发式合并

(按难度从小到大排序)

600E - Lomsat gelral: heavy-light decomposition style : Link, easy style : Link. I think this is the easiest problem of this technique in CF and it's good to start coding with this problem.

570D - Tree Requests : 17961189 Thanks to Sora233; this problem is also good for start coding.

Sgu507 (SGU is unavailable, read the problem statements here) This problem is also good for the start.

246E - Blood Cousins Return : 15409328

208E - Blood Cousins : 16897324

IOI 2011, Race (See SaYami's comment below).

291E - Tree-String Problem : See bhargav104's comment below.

1009F - Dominant Indices : 40332812 Arpa-Style. Thanks to Tanmoy_Datta.

343D - Water Tree : 15063078 Note that problem is not easy and my code doesn't use this technique (dsu on tree), but AmirAz 's solution to this problem uses this technique : 14904379.

375D - Tree and Queries : 15449102 Again note that problem is not easy 😃).

716E - Digit Tree : 20776957 A hard problem. Also can be solved with centroid decomposition.

741D - Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths : 22796438 A hard problem. You must be very familiar with Dsu on tree to solve it.

CF600E Lomsat gelral

方法一:递归处理轻儿子子树的添加与删除(202 ms)

#include <iostream>
using namespace std;
const int N = 1e5 + 7;
using LL = long long;
int sz[N], col[N], son[N], ct[N];
struct Edge {
    int v;
    Edge *nx;
    Edge(int _v=0, Edge *_n=nullptr): v(_v), nx(_n){}
} *hd[N];
void dfs1(int u, int fa) {
    sz[u] = 1;
    for (Edge *i = hd[u]; i; i = i->nx) {
        int v = i->v;
        if (v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if (!son[u] || sz[son[u]]<sz[v])
            son[u] = v;
    }
}
int hson, mx;
LL ans[N], sum;
void count(int u, int fa, int val) {
    ct[col[u]] += val;
    if (ct[col[u]] > mx) {
        mx = ct[col[u]];
        sum = col[u];
    } else if (ct[col[u]] == mx)
        sum += col[u];
    for (Edge *i = hd[u]; i; i = i->nx) {
        int v = i->v;
        if (v==fa || v==hson) continue;
        count(v, u, val);
    }
}
void dfs2(int u, int fa, bool inc) {
    for (Edge *i = hd[u]; i; i = i->nx) {
        int v = i->v;
        if (v==fa || v==son[u]) continue;
        dfs2(v, u, false);
    }
    if (son[u]) {
        dfs2(son[u], u, true);
        hson = son[u];
    }
    count(u, fa, 1);
    hson = 0, ans[u] = sum;
    if (!inc) {
        count(u, fa, -1);
        sum = mx = 0;
    }
}
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
        scanf("%d", col + i);
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        hd[u] = new Edge(v, hd[u]);
        hd[v] = new Edge(u, hd[v]);
    }
    dfs1(1, 0);
    dfs2(1, 0, false);
    for (int i = 1; i <= n; ++i)
        printf("%lld ", ans[i]);
}

方法2:dfs 序处理轻儿子子树的添加与删除(171 ms)

#include <iostream>
using namespace std;
const int N = 1e5 + 7;
using LL = long long;
int sz[N], col[N], son[N], ct[N], st[N], ed[N], dfn[N], cur;
struct Edge {
    int v;
    Edge *nx;
    Edge(int _v, Edge *_n): v(_v), nx(_n){}
} *hd[N];
void dfs1(int u, int fa) {
    sz[u] = 1, st[u] = ++cur, dfn[cur] = u;
    for (Edge *i = hd[u]; i; i = i->nx) {
        int v = i->v;
        if (v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if (!son[u] || sz[son[u]]<sz[v])
            son[u] = v;
    }
	ed[u] = cur;
}
int hson, mx;
LL ans[N], sum;
void add(int u) {
    if (++ct[col[u]] > mx) {
        mx = ct[col[u]];
        sum = col[u];
    } else if (ct[col[u]] == mx)
        sum += col[u];
}
void dfs2(int u, int fa, bool inc) {
    for (Edge *i = hd[u]; i; i = i->nx) {
        int v = i->v;
        if (v==fa || v==son[u]) continue;
        dfs2(v, u, false);
    }
    if (son[u]) dfs2(son[u], u, true);
	for (Edge *e = hd[u]; e; e = e->nx) {
		int v = e->v;
		if (v==fa || v==son[u]) continue;
		for (int i = st[v]; i <= ed[v]; ++i) {
			int vv = dfn[i];
			if (vv==fa || vv==hson) continue;
			add(vv);
		}
	}
    add(u), ans[u] = sum;
    if (!inc) {
        for (int i = st[u]; i <= ed[u]; ++i) {
			int v = dfn[i];
			if (v != fa) --ct[col[v]];
		}
        sum = mx = 0;
    }
}
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
        scanf("%d", col + i);
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        hd[u] = new Edge(v, hd[u]);
        hd[v] = new Edge(u, hd[v]);
    }
    dfs1(1, 0);
    dfs2(1, 0, true);
    for (int i = 1; i <= n; ++i)
        printf("%lld ", ans[i]);
}

注:使用数组的前向星速度会更快,但需要创建的数组会多一些,写起来有些麻烦。

posted @ 2024-03-31 11:00  飞花阁  阅读(57)  评论(0)    收藏  举报
//雪花飘落效果