浅谈两种树上启发式合并的写法
我将用两道例题来说明树上启发式合并的两种写法。
我先说明一种比较好写但是不是特别通用的写法。,因为这个空间复杂度令人堪忧。
树上数颜色
题目概述
给一棵根为 \(1\) 的树,每次询问子树颜色种类数。
数据范围:\(1\leq n,q\leq 10^5\)。
分析
挺模板的一道题目。
一般这种你代码想写得比较顺畅就直接用线段树合并就行了。
但是有一种方法比这个还要简单。
我们考虑是怎么合并上去的,我们肯定是小的合并到大的上面去,这样每次合并的总数不超过原本就在这个集合的数量,总而言之就是 \(\mathcal{O}(\log n)\)。
那么如果当前节点的存下来的桶需要被合并到另外一个上面但是我又想通过这个节点去访问怎么办呢?其实很简单记录一个 \(id\) 即可。
考虑到只道题目是询问子树的颜色种类数,直接用 \(set\)(一般都是这样的)。
代码
时间复杂度 \(\mathcal{O}(n\log n)\),空间复杂度 \(\mathcal{O}(n\log^2 n)\)。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <set>
#define int long long
#define N 100005
using namespace std;
int id[N],n,c[N];
vector<int> g[N];
set<int> st[N];
int ans[N];
void dfs(int cur,int fa) {
for (auto i : g[cur])
if (i != fa) {
dfs(i,cur);
if (st[id[cur]].size() < st[id[i]].size()) swap(id[cur],id[i]);
for (auto j :st[id[i]]) st[id[cur]].insert(j);
}
st[id[cur]].insert(c[cur]);
ans[cur] = st[id[cur]].size();
}
signed main(){
cin >> n;
for (int i = 1;i <= n;i ++) id[i] = i;
for (int i = 1;i < n;i ++) {
int u,v;
scanf("%lld%lld",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
for (int i = 1;i <= n;i ++) scanf("%lld",&c[i]);
dfs(1,0);
int m;
cin >> m;
for (int i = 1;i <= m;i ++) {
int x;
scanf("%lld",&x);
printf("%lld\n",ans[x]);
}
return 0;
}
P3201 [HNOI2009] 梦幻布丁
题目概述
\(n\) 个布丁摆成一行,进行 \(m\) 次操作。每次将某个颜色的布丁全部变成另一种颜色的,然后再询问当前一共有多少段颜色。
例如,颜色分别为 \(1,2,2,1\) 的四个布丁一共有 \(3\) 段颜色.
询问:
- 操作 \(1\):后有两个整数 \(x, y\),表示将颜色 \(x\) 的布丁全部变成颜色 \(y\)。
- 操作 \(2\),表示一次询问。
分析
这不是树上启发式合并了,这是启发式合并,没有树上。
考虑用 \(set_i\) 记录颜色 \(i\) 的各个区间 \(l,r\)。
你可以根据势能分析法,或者说感性理解,这个跟树上合并差不多,而且是像那种 Ynoi 的不会种类数单调不增的,可以直接小的合并到大的上面。
合并的时候判断能不能合并左右区间即可。
代码
时间复杂度 \(\mathcal{O}(n\log^2 n)\)。需要 \(lower\text{_}bound\) 查找。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <set>
#define int long long
#define N 100006
#define M 1000006
#define PII pair<int,int>
using namespace std;
int n,m,a[N];
set<PII> st[M];
int id[M];
signed main(){
cin >> n >> m;
for (int i = 1;i <= n;i ++) scanf("%lld",&a[i]);
for (int i = 1;i <= n;i ++) {
int l = i;
int r = i;
while(a[r] == a[l]) r ++;
st[a[i]].insert({l,r - 1});
i = r - 1;
}
for (int i = 1;i <= 1e6;i ++) id[i] = i;
int cnt = 0;
for (int i = 1;i <= 1e6;i ++) cnt += st[i].size();
// cout << cnt;
for (int op,x,y;m --;) {
scanf("%lld",&op);
if (op == 1) {
scanf("%lld%lld",&x,&y);
if (x == y) continue;
if (st[id[x]].empty()) continue;
if (st[id[y]].empty()) {
swap(id[x],id[y]);
continue;
}
// swap(x,y);//x<----y
int xx = id[x],yy = id[y];
if (st[id[x]].size() < st[id[y]].size()) swap(id[x],id[y]);
int fx = id[x], fy = id[y];
for (auto i : st[fx]) {
int l = i.first, r = i.second;
auto it = st[fy].lower_bound({l, 0});
if (it != st[fy].begin()) {
it--;
if (it->second == l - 1) {
l = it->first;
cnt--;
st[fy].erase(it);
}
}
it = st[fy].lower_bound({r + 1, 0});
if (it != st[fy].end() && it->first == r + 1) {
r = it->second;
cnt--;
st[fy].erase(it);
}
st[fy].insert({l, r});
}
st[fx].clear();
}
else printf("%lld\n",cnt);
}
return 0;
}
听说有 \(\mathcal{O}(n)\) 或者 \(\mathcal{O}(n\log n)\) 做法,还没看。
51Nod 1513 树上的回文
题目概述
给你一颗有 \(n\) 个点的树,根为 \(1\),每个点上面有一个字母,然后给你 \(q\) 次询问,询问的内容是 \(x,y\) 表示以 \(x\) 为根的子树中深度为 \(y\) 的节点(在 \(n\) 个节点的树上的深度)上面的字母全部取出来重新排列能否得到一个回文串。
分析
我们先考虑回文串的判断:
- 对于奇数情况,只能存在一种字母的个数是奇数。
- 对于偶数情况,所有种类的字母个数都应该是偶数。
我们既要存储 \(dep\) 还要存储某个字符出现的个数,好像可以用 \(set\),可以设 \(set_i\) 表示字符 \(i\) 的个数,我不知道空间行不行,我没打,当时没有想到这个所以我就没有这么写。
其实还可以开 \(26\) 颗线段树。
虽然好像可以用 \(set_i\) 写,但是我们还是讲一讲这种经典写法吧。
首先找出每个节点的重儿子(在儿子中所含节点数最多的)。
维护一个桶
要统计答案的时候,先遍历轻儿子,计算所有轻儿子的答案,然后把他们的桶给删掉。
然后遍历重儿子计算答案,不删桶。
最后再把轻儿子的贡献加上,计算当前节点的答案。
前面那个把桶删掉可以考虑写在最后。
代码
时间复杂度 \(\mathcal{O}(n|\Sigma|\log n)\),空间复杂度 \(\mathcal{O}(n|\Sigma|)\)。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <stdlib.h>
#include <vector>
#include <set>
// #define int long long
#define N 500005
#define M 27
#define PII pair<int,int>
using namespace std;
int n,m;
vector<int> g[N];
int deps[N][30],a[N],dep[N];
vector<PII> ned[N];
bool ans[N];
int son[N],sz[N];
void dfs0(int cur,int fa) {
sz[cur] = 1;
dep[cur] = dep[fa] + 1;
for (auto i : g[cur]) {
dfs0(i,cur);
sz[cur] += sz[i];
if (sz[son[cur]] < sz[i]) son[cur] = i;
}
}
void add(int x,int y) {
deps[x][y] ++;
}
void del(int x,int y) {
deps[x][y] --;
}
void change(int cur,bool flag) {//add or del
if (flag) add(dep[cur],a[cur]);
else del(dep[cur],a[cur]);
for (auto i : g[cur]) change(i,flag);
}
void dfs(int cur,bool flag) {
for (auto i : g[cur])
if (i != son[cur]) dfs(i,0);
if (son[cur]) dfs(son[cur],1);
add(dep[cur],a[cur]);
for (auto i : g[cur])
if (i != son[cur]) change(i,1);
for (auto i : ned[cur]) {
int cnt = 0;
for (int j = 0;j < 26;j ++) cnt += deps[i.first][j] & 1;
ans[i.second] = (cnt <= 1);
}
if (!flag) change(cur,0);
}
signed main(){
cin >> n >> m;
for (int i = 2;i <= n;i ++) {
int x;
scanf("%d",&x);
g[x].push_back(i);
}
for (int i = 1;i <= n;i ++) {
char x;
cin >> x;
a[i] = x - 'a';
}
for (int i = 1;i <= m;i ++) {
int u,w;
scanf("%d%d",&u,&w);
ned[u].push_back({w,i});
}
dfs0(1,0);
dfs(1,0);
for (int i = 1;i <= m;i ++)
if (ans[i]) puts("Yes");
else puts("No");
return 0;
}
写了一下,用 \(set\) 的代码,空间会炸,MLE(不知道对不对,只过了样例):
#include <iostream>
#include <cstdio>
#include <cstring>
#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <set>
#define int long long
#define N 500005
#define M 30
#define PII pair<int,int>
using namespace std;
multiset<int> st[M][N];
int n,m;
vector<int> g[N];
int dep[N],a[N];
vector<PII> ned[N];
bool ans[N];
void dfs0(int cur,int fa) {
dep[cur] = dep[fa] + 1;
for (auto i : g[cur]) dfs0(i,cur);
}
int id[N];
void dfs(int cur) {
int cnt1 = 0;
st[a[cur]][cur].insert(dep[cur]);
cnt1 ++;
for (auto i : g[cur]) {
dfs(i);
int cnt2 = 0;
for (int j = 0;j < 26;j ++) cnt2 += st[j][id[i]].size();
if (cnt1 < cnt2) swap(id[i],id[cur]);
for (int j = 0;j < 26;j ++)
for (auto k : st[j][id[i]])
st[j][id[cur]].insert(k);
cnt1 += cnt2;
}
for (auto i : ned[cur]) {
int cnt = 0;
for (int j = 0;j < 26;j ++) cnt += st[j][id[cur]].count(i.first) & 1;
ans[i.second] = (cnt <= 1);
}
}
signed main(){
cin >> n >> m;
for (int i = 2;i <= n;i ++) {
int fa;
scanf("%lld",&fa);
g[fa].push_back(i);
}
for (int i = 1;i <= n;i ++) {
char x;
cin >> x;
a[i] = x - 'a';
id[i] = i;
}
for (int i = 1;i <= m;i ++) {
int x,y;
scanf("%lld%lld",&x,&y);
ned[x].push_back({y,i});
}
dfs0(1,0);
dfs(1);
for (int i = 1;i <= m;i ++)
if (ans[i]) puts("Yes");
else puts("No");
return 0;
}

浙公网安备 33010602011771号