dsu on tree 小记
要求:
- 静态,离线
- 一个点的信息从子树合并过来
过程:首先我们需要知道,对于每个点,可以选一个子树的信息作为初值再合并,显然重儿子最好。先求重儿子,每次先处理完轻儿子(不保留),之后遍历重儿子的信息并保留,再将轻儿子的信息合并。时间复杂度 \(\mathcal{O}(n \log n)\)。
常数优化:考虑树剖一遍,然后按照树剖得到的 dfn 倒序枚举,那么如果当前点 \(u\) 是重链顶,其父亲的 dfn 一定不在其前一个,其贡献不保留,然后对于 \(u\) 的答案,直接继承后面的即可。
例题:
树上数颜色
实现 1:
namespace Loop1st {
int n, m, col[N], cnt[N], L[N], R[N], sz[N], son[N], rk[N], idx, ans[N], tot;
vector<int>e[N];
void dfs1(int u, int fa) {
L[u] = ++idx; rk[idx] = u; sz[u] = 1;
for (int v : e[u]) if (v != fa) {
dfs1(v, u);
if (sz[v] > sz[son[u]]) son[u] = v;
}
R[u] = idx;
}
void add(int u) {
cnt[col[u]]++;
if (cnt[col[u]] == 1) tot++;
}
void del(int u) {
cnt[col[u]]--;
if (!cnt[col[u]]) tot--;
}
void dfs2(int u, int fa, bool fl) {
for (int v : e[u]) if (v != fa && v != son[u]) {
dfs2(v, u, 0);
}
if (son[u]) dfs2(son[u], u, 1);
for (int v : e[u]) if (v != fa && v != son[u]) {
for (int i = L[v]; i <= R[v]; i++) add(rk[i]);
}
add(u);
ans[u] = tot;
if (!fl) {
for (int i = L[u]; i <= R[u]; i++) del(rk[i]);
}
}
void main() {
cin >> n;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i <= n; i++) cin >> col[i];
dfs1(1, 0);
dfs2(1, 0, 0);
cin >> m;
while (m--) {
int u; cin >> u;
cout << ans[u] << '\n';
}
}
}
实现 2:
namespace Loop1st {
int n, m, col[N], cnt[N], L[N], R[N], sz[N], son[N], rk[N], idx, ans[N], tot;
vector<int>e[N];
void dfs1(int u, int fa) {
L[u] = ++idx; rk[idx] = u; sz[u] = 1;
for (int v : e[u]) if (v != fa) {
dfs1(v, u);
if (sz[v] > sz[son[u]]) son[u] = v;
}
R[u] = idx;
}
void add(int u) {
cnt[col[u]]++;
if (cnt[col[u]] == 1) tot++;
}
void del(int u) {
cnt[col[u]]--;
if (!cnt[col[u]]) tot--;
}
void dfs2(int u, int fa) {
for (int v : e[u]) if (v != fa && v != son[u]) {
dfs2(v, u);
for (int i = L[v]; i <= R[v]; i++) del(rk[i]);
}
if (son[u]) dfs2(son[u], u);
for (int v : e[u]) if (v != fa && v != son[u]) {
for (int i = L[v]; i <= R[v]; i++) add(rk[i]);
}
add(u);
ans[u] = tot;
}
void main() {
cin >> n;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i <= n; i++) cin >> col[i];
dfs1(1, 0);
dfs2(1, 0);
cin >> m;
while (m--) {
int u; cin >> u;
cout << ans[u] << '\n';
}
}
}
优化版:
namespace Loop1st {
int n, m, col[N], cnt[N], fa[N], top[N], L[N], R[N], sz[N], son[N], rk[N], idx, ans[N], tot;
vector<int>e[N];
void dfs1(int u) {
sz[u] = 1;
for (int v : e[u]) if (v != fa[u]) {
fa[v] = u;
dfs1(v);
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
L[u] = ++idx; rk[idx] = u; top[u] = tp;
if (son[u]) dfs2(son[u], tp);
for (int v : e[u]) if (v != fa[u] && v != son[u]) dfs2(v, v);
R[u] = idx;
}
void add(int u) {
cnt[col[u]]++;
if (cnt[col[u]] == 1) tot++;
}
void del(int u) {
cnt[col[u]]--;
if (!cnt[col[u]]) tot--;
}
void main() {
cin >> n;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i <= n; i++) cin >> col[i];
dfs1(1);
dfs2(1, 0);
for (int t = n; t; t--) {
int u = rk[t];
for (int v : e[u]) if (v != fa[u] && v != son[u]) {
for (int i = L[v]; i <= R[v]; i++) add(rk[i]);
}
add(u);
ans[u] = tot;
if (u == top[u]) {
for (int i = L[u]; i <= R[u]; i++) del(rk[i]);
}
}
cin >> m;
while (m--) {
int u; cin >> u;
cout << ans[u] << '\n';
}
}
}

浙公网安备 33010602011771号