树上数颜色

树上数颜色

题目描述

给一棵根为 $1$ 的树,每次询问子树颜色种类数。

输入格式

第一行一个整数 $n$,表示树的结点数。

接下来 $n-1$ 行,每行一条边。

接下来一行 $n$ 个数,表示每个结点的颜色 $c[i]$。

接下来一个数 $m$,表示询问数。

接下来 $m$ 行表示询问的子树。

输出格式

对于每个询问,输出该子树颜色数。

样例 #1

样例输入 #1

5
1 2
1 3
2 4
2 5
1 2 2 3 3
5
1
2
3
4
5

样例输出 #1

3
2
1
1
1

提示

对于前三组数据,有 $1\leq m,c[i]\leq n\leq 100$。

而对于所有数据,有$1\leq m,c[i]\leq n\leq 10^5$。

 

解题思路

  今天学了 dsu on tree,大概记录一下原理。

  dsu on tree 主要用于解决静态(没有修改操作)的子树问题,名字上虽然有并查集(dsu),但实际上与并查集或启发式合并没有太大的关系,反而用到了树链剖分中重儿子的思想。对于节点 $u$,其重儿子就是 $u$ 的所有子节点中子树最大的子节点。如果有多个子树最大的子节点则任取其一。如果没有子节点,则无重儿子。对应的轻儿子就是除重儿子外剩余的所有子节点。

  大部分 dsu on tree 的做法都是通过 dfs 遍历每个节点 $u$ 并进行以下 $3$ 个步骤:

  1. 先对所有轻儿子进行 dfs 求答案,但不记录 dfs 过程中每个节点对 $u$ 的贡献。
  2. 对其重儿子进行 dfs 求答案,并记录 dfs 过程中每个节点对 $u$ 的贡献。
  3. 再次对所有轻儿子进行 dfs,记录 dfs 过程中每个节点对 $u$ 的贡献,从而求出节点 $u$ 的答案。

  在这道题目中,我们用 $\text{cnt}$ 数组来记录颜色的出现次数,并用 $s$ 来维护 $\text{cnt}$ 中出现了多少种不同的颜色。

  看上去第 $3$ 步完全可以合并到第 $1$ 步中,实际上这是不对的,这会导致 $\text{cnt}$ 数组被重复使用,即其他子树的记录会影响到当前子树的答案。而如果对每个节点都开一个 $\text{cnt}$ 数组,则会导致 $O(n^2)$ 的空间复杂度。

  dsu on tree 的时间复杂度是 $O(n \log{n})$,大概的解释是树中任意节点到根的路径中的轻边数量不超过 $\log{n}$ 条,意味着在 dfs 的过程中每个节点会被遍历 $O(\log{n})$ 次,因此所有节点被遍历的次数就是 $O(n \log{n})$。具体证明参见:树上启发式合并

  下面大概讲一下本题的代码实现。

  首先先通过 dfs 求出每个节点 $u$ 的子节点的子树大小,并选择子树大小最大的子节点作为 $u$ 的重儿子 $\text{son}[u]$,时间复杂度为 $O(n)$。

  然后再 dfs 进行 dsu on tree,对于每个节点都按照上面描述的 $3$ 个步骤进行。其中记录贡献的部分就是对 $u$ 的所有轻儿子执行另一个 dfs,并将每个节点的颜色记录到 $\text{cnt}$ 中。最后如果 $u$ 是其父节点的轻儿子,还需要清除 $\text{cnt}$ 中的所有记录。

  AC 代码如下,时间复杂度为 $O(n \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 5, M = N * 2;

int a[N];
int h[N], e[M], ne[M], idx;
int sz[N], son[N];
int ans[N], cnt[N], s;

void add(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}

void dfs(int u, int p) {
    sz[u] = 1;
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

void modify(int u, int p, int c, int pson) {
    cnt[a[u]] += c;
    if (cnt[a[u]] == 1 && c == 1) s++;
    if (cnt[a[u]] == 0 && c == -1) s--;
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p || v == pson) continue;
        modify(v, u, c, pson);
    }
}

void dfs(int u, int p, int keep) {
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p || v == son[u]) continue;
        dfs(v, u, 0);
    }
    if (son[u]) dfs(son[u], u, 1);
    modify(u, p, 1, son[u]);
    ans[u] = s;
    if (!keep) modify(u, p, -1, -1);
}

int main() {
    int n, m;
    scanf("%d", &n);
    memset(h, -1, sizeof(h));
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        add(u, v), add(v, u);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    dfs(1, -1);
    dfs(1, -1, 0);
    scanf("%d", &m);
    while (m--) {
        int x;
        scanf("%d", &x);
        printf("%d\n", ans[x]);
    }
    
    return 0;
}

  这题还可以用启发式合并来做(貌似大部分 dsu on tree 的题都可以用启发式合并来实现)。

  对于每个节点 $u$ 都开一个 std::set<int> 用来存储子树 $u$ 中包含的不同颜色,表示为 $\text{st}[u]$。通过 dfs 求出 $u$ 的每个子节点 $v$ 的 $\text{st}[v]$,并将 $\text{st}[v]$ 的元素合并到 $\text{st}[u]$ 中。如果 $\text{st}[v]$ 的大小不超过 $\text{st}[u]$,则直接合并即可。否则需要将两个集合进行互换,再将 $\text{st}[v]$ 合并到 $\text{st}[u]$ 中。这就是启发式合并。

  用 std::set<int> 实现的时间复杂度为 $O(n \log^2{n})$,改为 std::unordered_set<int> 的话就是 $O(n \log{n})$,不过由于容易被卡哈希函数还是建议用 std::set<int>

  AC 代码如下,时间复杂度为 $O(n \log^2{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 5, M = N * 2;

int a[N];
int h[N], e[M], ne[M], idx;
set<int> st[N];
int ans[N];

void add(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}

void dfs(int u, int p) {
    st[u].insert(a[u]);
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        dfs(v, u);
        if (st[v].size() > st[u].size()) st[u].swap(st[v]);
        st[u].insert(st[v].begin(), st[v].end());
        st[v].clear();
    }
    ans[u] = st[u].size();
}

int main() {
    int n, m;
    scanf("%d", &n);
    memset(h, -1, sizeof(h));
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        add(u, v), add(v, u);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    dfs(1, -1);
    scanf("%d", &m);
    while (m--) {
        int x;
        scanf("%d", &x);
        printf("%d\n", ans[x]);
    }
    
    return 0;
}

 

参考资料

  树上启发式合并:https://oi-wiki.org/graph/dsu-on-tree/

posted @ 2024-04-03 22:47  onlyblues  阅读(9)  评论(0编辑  收藏  举报
Web Analytics