Atcoder ABC421F. Erase between X and Y 题解 splay tree

题目链接:https://atcoder.jp/contests/abc421/tasks/abc421_f

题目大意:

给你一个初始值包含一个 \(0\) 的序列,然后依次进行 \(Q\) 次操作,第 \(i\) 次操作可以是如下两种类型之一:

  • 1 x:将数字 \(i\) 插入到数字 \(x\) 的后面(保证执行此操作时序列中存在数字 \(x\)
  • 2 x y:将数字 \(x\)\(y\) 之间所有数字(不包括 \(x\)\(y\))都删除,同时输出这一次操作删除的所有数字之和(保证执行此操作时序列中存在数字 \(x\)\(y\)

解题思路:

基本上就是 splay tree 的基本操作了。

使用 splay tree 的中序遍历序列来表示这个序列。

每个节点维护一个 sum 信息,他表示以该结点为根的所有(节点对应的)数字之和。

初始时只有数字 \(0\) 对应的节点。

对于 1 x 操作:

先将 \(x\) splay 到根节点,设 \(x\) 的右儿子为 \(y\),则:

  • \(x\) 的右儿子变成 \(i\)
  • \(i\) 的左儿子变成 \(y\)

或者:

  • \(i\) 的左儿子变成 \(x\)
  • \(i\) 的右儿子变成 \(y\)
  • \(x\) 的右儿子变为空(因为 \(y\) 移动到 \(i\) 的右儿子去了)
  • 更新根节点为 \(i\)

以上两种操作 二选一 执行即可,都可以保证中序遍历序列就是 \(x\) 后面插入 \(i\) 的效果。

对于 2 x y 操作:

先将 \(x\) splay 到根节点,再将 \(y\) splay 到是 \(x\) 的儿子节点。这也就是说,此时 \(x\) 是根节点,\(y\)\(x\) 的左儿子节点或右儿子节点(这主要取决于序列里 \(y\)\(x\) 的前面还是后面)。

\(y\)\(x\) 的左儿子时:

  • \(z\)\(y\) 的右儿子,则输出 \(z\) 维护的 sum(就是答案),然后删除以 \(z\) 为根的子树。

\(y\)\(x\) 的右儿子时:可以按照上述逻辑一样分析,如下

  • \(z\)\(y\) 的左儿子,则输出 \(z\) 维护的 sum(就是答案),然后删除以 \(z\) 为根的子树。

具体实现时,因为我一般用 \(0\) 表示空节点,而本题中数字是 \(0\) 是存在的,所以来一个 \(+1\) 的偏移,具体来说,就是令节点 \(x+1\) 来表示数字 \(x\) 的信息即可。

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;

struct Node {
    int s[2], p, key;
    long long sum;

    void init(int _key, int _p) {
        s[0] = s[1] = 0;
        p = _p;
        sum = key = _key;
    }

} tr[maxn];
int root;

void push_up(int x) {
    auto &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
    u.sum = l.sum + r.sum + u.key;
}

void f_s(int p, int u, int k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1] == y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

int Q, op, x, y;

void test_dfs(int u) {
    if (!u) return;
    test_dfs(tr[u].s[0]);
    if (u == root) printf("*");
    printf("%d, ", tr[u].key);
    test_dfs(tr[u].s[1]);
}

int main() {
    tr[1].init(0, 0); root = 1;
    scanf("%d", &Q);
    for (int i = 1; i <= Q; i++) {
        scanf("%d%d", &op, &x);
        if (op == 1) {
            splay(x+1, 0);
            tr[i+1].init(i, 0);
            f_s(i+1, x+1, 0);
            f_s(i+1, tr[x+1].s[1], 1);
            tr[x+1].s[1] = 0;
            push_up(x+1);
            push_up(i+1);
            root = i+1;
        }
        else { // op == 2
            scanf("%d", &y);
            splay(x+1, 0);
            splay(y+1, x+1);
            long long res;
            if (tr[x+1].s[0] == y+1) {
                int p = tr[y+1].s[1];
                res = tr[p].sum;
                tr[y+1].s[1] = 0;
                push_up(y+1);
                push_up(x+1);
            }
            else {
                int p = tr[y+1].s[0];
                res = tr[p].sum;
                tr[y+1].s[0] = 0;
                push_up(y+1);
                push_up(x+1);
            }
            printf("%lld\n", res);
        }
    }
    return 0;
}
posted @ 2025-09-01 18:49  quanjun  阅读(46)  评论(0)    收藏  举报