【数据结构】树上启发式合并
https://codeforces.com/contest/1923/problem/E
这一道题有三种不一样的写法,有一种是用启发式合并(跟树上没啥关系)去优化dp的做法。
int n, k;
int c[200005];
vector<int> G[200005];
ll ans;
map<int, int> cnt[200005];
void dfs (int u, int p) {
cnt[u][c[u]] = 1; // cnt[u][x]表示子树u内部,颜色为x的有效节点的数量
for (int v : G[u]) {
if (v == p) {
continue;
}
dfs (v, u);
if (cnt[u].size() < cnt[v].size()) {
// 启发式合并优化dp,保证一定是小的数组合并到大的里
swap (cnt[u], cnt[v]);
}
for (auto [x, y] : cnt[v]) {
ans += 1LL * cnt[u][x] * y; // 相当于把同为x的答案,从v子树合并到u的已合并部分中
cnt[u][x] += y;
}
cnt[u][c[u]] = 1; // 这道题比较特殊,退出节点v的时候回到了节点u,要把节点u重新堵上
cnt[v].clear();
}
}
void solve() {
RD (n);
RDN (c, n);
for (int i = 1; i <= n; ++i) {
G[i].clear();
cnt[i].clear();
}
for (int i = 1; i <= n - 1; ++i) {
int x, y;
RD (x, y);
G[x].push_back (y);
G[y].push_back (x);
}
ans = 0;
dfs (1, 1);
WT (ans);
}
https://codeforces.com/contest/375/problem/D
这道题也有用树上莫队的解法,这里再给一个树上启发式合并的解法。
树上启发式合并,类似轻重链剖分,先算出每个节点的重儿子,然后计算答案时先递归计算轻儿子的答案,标记clr为true,然后计算重儿子的答案,clr为false,然后把轻儿子的节点暴力插到重儿子里,最后把父节点自己加入。
int n;
vector<int> G[MAXN];
int siz[MAXN], mch[MAXN];
void dfs1(int u, int p) {
// 初始化siz,找到重儿子
siz[u] = 1, mch[u] = 0;
for (int &v : G[u]) {
if (v == p)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[mch[u]] < siz[v])
mch[u] = v;
}
}
由于树上启发式合并并不关心深度,所以没有必要维护深度。
void calc(int u, int p, int skip, int d) {
bit.Add(cnt[c[u]], -1);
cnt[c[u]] += d;
bit.Add(cnt[c[u]], 1);
for (int v : G[u]) {
if (v == p || v == skip)
continue;
calc(v, u, 0, d);
}
}
void dfs2(int u, int p, bool keep) {
// 启发式合并dfs(这个例子不好扩展)
// 先遍历所有的轻儿子,计算轻儿子“内部”的答案
for (int &v : G[u]) {
if (v == p || v == mch[u])
continue;
dfs2(v, u, false);
}
// 如果存在重儿子,则计算重儿子“内部的答案”,将重儿子的答案数组继承给u
if (mch[u])
dfs2(mch[u], u, true);
// 把节点u和其他所有的轻儿子,合并到重儿子的答案中,并计算贡献(合并的写法)
calc(u, p, mch[u], 1);
for (pii &q : Q[u]) {
int id = q.first, k = q.second;
ans[q.first] = bit.Sum(k, n);
}
// 如果不是重儿子,则消除当前节点的所有影响
if (!keep)
calc(u, p, 0, -1);
}
然后是主要的计算过程dfs2,dfs2优先进入所有的轻儿子,并且不keep轻儿子的答案,保持树状数组为空。然后进入重儿子计算并keep重儿子的结果。这里使用一个辅助函数calc,calc的修改值为1时表示向树状数组中添加,然后命令其在添加时skip掉重儿子。计算完毕后树状数组中存着这棵子树对应的状态,然后取出所有的询问进行回答。那之后,假如不keep树状数组,调用calc修改值为-1,并且不跳过重儿子,把整棵子树删除干净。
时间复杂度为 \(O(nlog^2n)\)
https://codeforces.com/gym/102832/problem/F
这里的查询要去重,所以要先计算再查询。而且要注意cache的命中。一次树遍历就统计出所有的信息,把常用的局部值放在数组的低维。
int n, k;
int a[MAXN];
vector<int> G[MAXN];
int siz[MAXN], mch[MAXN];
void dfs1(int u, int p) {
// 维护siz,计算重儿子
siz[u] = 1, mch[u] = 0;
for (int &v : G[u]) {
if (v == p)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[mch[u]] < siz[v])
mch[u] = v;
}
}
int cnt[1 << 20][17][2];
ll ans;
void calc1(int u, int p, int LCA) {
// 单纯计算u子树内的节点合并入cnt中时产生的贡献,并不真正把u节点中的子树合并入cnt中(因为会和“轻儿子内部的答案”重复计算)
int val = a[u] ^ a[LCA];
for (int k = 16; k >= 0; --k) {
int uk = (u >> k) & 1;
ans += (1LL << k) * cnt[val][k][uk ^ 1];
}
for (int &v : G[u]) {
if (v == p)
continue;
calc1(v, u, LCA);
}
}
void calc2(int u, int p) {
// 只是把u节点的子树合并入cnt中,并不计算他们的贡献
int val = a[u];
for (int k = 16; k >= 0; --k) {
int uk = (u >> k) & 1;
++cnt[val][k][uk];
}
for (int &v : G[u]) {
if (v == p)
continue;
calc2(v, u);
}
}
void calc3(int u, int p) {
// 消除u子树的影响,注意这里的memset不是对整个数组进行的
int val = a[u];
memset(cnt[val], 0, sizeof(cnt[val]));
for (int &v : G[u]) {
if (v == p)
continue;
calc3(v, u);
}
}
void dfs2(int u, int p, bool keep) {
// 计算轻儿子内部的贡献
for (int &v : G[u]) {
if (v == p || v == mch[u])
continue;
dfs2(v, u, false);
}
// 计算重儿子内部的贡献
if (mch[u])
dfs2(mch[u], u, true);
// 将u节点和重儿子合并,称为“已合并部分”
int val = a[u];
for (int k = 16; k >= 0; --k) {
int uk = (u >> k) & 1;
ans += (1LL << k) * cnt[0][k][uk ^ 1];
++cnt[val][k][uk];
}
// 将轻儿子逐个添加到已合并部分
for (int &v : G[u]) {
if (v == p || v == mch[u])
continue;
// 只计算轻儿子和已合并部分中间的贡献,不要把轻儿子加入(避免重复计算轻儿子内部的贡献)
calc1(v, u, u);
// 只把轻儿子加入,不计算他们内部和他们和已合并部分之间的贡献
calc2(v, u);
}
if (!keep) {
// 如果是轻儿子,就清空他的贡献
memset(cnt[val], 0, sizeof(cnt[val]));
for (int &v : G[u]) {
if (v == p)
continue;
calc3(v, u);
}
}
}
void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", &a[i]);
for (int i = 1; i <= n; ++i)
G[i].clear();
for (int i = 1; i <= n - 1; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0, true);
printf("%lld\n", ans);
}