Atcoder ABC421F. Erase between X and Y 题解 splay tree
题目链接:https://atcoder.jp/contests/abc421/tasks/abc421_f
题目大意:
给你一个初始值包含一个 \(0\) 的序列,然后依次进行 \(Q\) 次操作,第 \(i\) 次操作可以是如下两种类型之一:
1 x
:将数字 \(i\) 插入到数字 \(x\) 的后面(保证执行此操作时序列中存在数字 \(x\))2 x y
:将数字 \(x\) 和 \(y\) 之间所有数字(不包括 \(x\) 和 \(y\))都删除,同时输出这一次操作删除的所有数字之和(保证执行此操作时序列中存在数字 \(x\) 和 \(y\))
解题思路:
基本上就是 splay tree 的基本操作了。
使用 splay tree 的中序遍历序列来表示这个序列。
每个节点维护一个 sum 信息,他表示以该结点为根的所有(节点对应的)数字之和。
初始时只有数字 \(0\) 对应的节点。
对于 1 x
操作:
先将 \(x\) splay 到根节点,设 \(x\) 的右儿子为 \(y\),则:
- 将 \(x\) 的右儿子变成 \(i\);
- 将 \(i\) 的左儿子变成 \(y\)。
或者:
- 将 \(i\) 的左儿子变成 \(x\);
- 将 \(i\) 的右儿子变成 \(y\);
- 将 \(x\) 的右儿子变为空(因为 \(y\) 移动到 \(i\) 的右儿子去了)
- 更新根节点为 \(i\)。
以上两种操作 二选一 执行即可,都可以保证中序遍历序列就是 \(x\) 后面插入 \(i\) 的效果。
对于 2 x y
操作:
先将 \(x\) splay 到根节点,再将 \(y\) splay 到是 \(x\) 的儿子节点。这也就是说,此时 \(x\) 是根节点,\(y\) 是 \(x\) 的左儿子节点或右儿子节点(这主要取决于序列里 \(y\) 在 \(x\) 的前面还是后面)。
当 \(y\) 是 \(x\) 的左儿子时:
- 设 \(z\) 是 \(y\) 的右儿子,则输出 \(z\) 维护的 sum(就是答案),然后删除以 \(z\) 为根的子树。
当 \(y\) 是 \(x\) 的右儿子时:可以按照上述逻辑一样分析,如下
- 设 \(z\) 是 \(y\) 的左儿子,则输出 \(z\) 维护的 sum(就是答案),然后删除以 \(z\) 为根的子树。
具体实现时,因为我一般用 \(0\) 表示空节点,而本题中数字是 \(0\) 是存在的,所以来一个 \(+1\) 的偏移,具体来说,就是令节点 \(x+1\) 来表示数字 \(x\) 的信息即可。
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
struct Node {
int s[2], p, key;
long long sum;
void init(int _key, int _p) {
s[0] = s[1] = 0;
p = _p;
sum = key = _key;
}
} tr[maxn];
int root;
void push_up(int x) {
auto &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
u.sum = l.sum + r.sum + u.key;
}
void f_s(int p, int u, int k) {
tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1] == y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
int Q, op, x, y;
void test_dfs(int u) {
if (!u) return;
test_dfs(tr[u].s[0]);
if (u == root) printf("*");
printf("%d, ", tr[u].key);
test_dfs(tr[u].s[1]);
}
int main() {
tr[1].init(0, 0); root = 1;
scanf("%d", &Q);
for (int i = 1; i <= Q; i++) {
scanf("%d%d", &op, &x);
if (op == 1) {
splay(x+1, 0);
tr[i+1].init(i, 0);
f_s(i+1, x+1, 0);
f_s(i+1, tr[x+1].s[1], 1);
tr[x+1].s[1] = 0;
push_up(x+1);
push_up(i+1);
root = i+1;
}
else { // op == 2
scanf("%d", &y);
splay(x+1, 0);
splay(y+1, x+1);
long long res;
if (tr[x+1].s[0] == y+1) {
int p = tr[y+1].s[1];
res = tr[p].sum;
tr[y+1].s[1] = 0;
push_up(y+1);
push_up(x+1);
}
else {
int p = tr[y+1].s[0];
res = tr[p].sum;
tr[y+1].s[0] = 0;
push_up(y+1);
push_up(x+1);
}
printf("%lld\n", res);
}
}
return 0;
}