AcWing 2476. 树套树 题解：线段树套平衡树

题目链接:https://www.acwing.com/problem/content/description/2478/

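线段树套 Splay Tree：线段树的每个节点维护一棵 Splay，存放该节点区间内的所有数（外加 -inf、inf 两个哨兵）。操作 1、4、5 把询问区间拆成 O(log n) 个线段树节点，分别在对应的 Splay 上查询后合并结果；操作 3 沿根到叶子的路径，在每棵 Splay 中删除旧值、插入新值；操作 2 在值域上二分答案，借助操作 1 的计数求区间第 k 小。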

示例程序:

#include <bits/stdc++.h>
using namespace std;
// maxn: capacity of the 1-indexed sequence a[] and of the segment tree (maxn<<2 nodes).
// maxm: capacity of the shared splay-node pool. Nodes are never recycled
//       (del() does not give them back), so every insertion consumes one slot.
// inf:  sentinel value; -inf and inf are inserted into every tree as brackets.
//       Assumed strictly greater than any value ever stored — TODO confirm
//       against the problem's value range.
const int maxn = 5e4 + 5, maxm = 6e6, inf = 1e8 + 1;

// A pool of splay trees that share one node array. A tree is identified by the
// index of its root node; index 0 is the null node. Its sz is never written and
// stays 0. f_s() may write to its s/p fields, which is harmless because no
// traversal ever follows them.
struct SplayTree {

    struct Node {
        int s[2], v, p, sz;   // children (0 = left, 1 = right), value, parent, subtree size

        Node() {}
        Node(int _v, int _p) {
            s[0] = s[1] = 0;
            v = _v;
            p = _p;
            sz = 1;
        }
    } tr[maxm];

    int idx;   // index of the most recently allocated node (0 = none yet)

    // Recompute the subtree size of u from its children.
    void push_up(int u) {
        int l = tr[u].s[0], r = tr[u].s[1];
        tr[u].sz = tr[l].sz + tr[r].sz + 1;
    }

    // Attach u as child k (0 = left, 1 = right) of p.
    void f_s(int p, int u, int k) {
        tr[p].s[k] = u;
        tr[u].p = p;
    }

    // Rotate x above its parent, keeping in-order order and subtree sizes.
    void rot(int x) {
        int y = tr[x].p, z = tr[y].p;
        int k = tr[y].s[1] == x;          // side of y that x hangs on
        f_s(z, x, tr[z].s[1] == y);       // x takes y's place under z
        f_s(y, tr[x].s[k^1], k);          // x's inner child moves over to y
        f_s(x, y, k^1);                   // y becomes x's child
        push_up(y), push_up(x);           // y is now below x, so update it first
    }

    // Splay x upward until its parent is k. When k == 0, x becomes the tree
    // root and `root` is updated accordingly.
    void splay(int &root, int x, int k) {
        while (tr[x].p != k) {
            int y = tr[x].p, z = tr[y].p;
            if (z != k)
                // zig-zag: rotate x twice; zig-zig: rotate y first, then x.
                (tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
            rot(x);
        }
        if (!k) root = x;
    }

    // Number of values strictly less than x, not counting the -inf sentinel.
    // Assumes the tree contains that sentinel, since 1 is subtracted for it.
    // NOTE: does not splay, so its cost is the current depth of the tree.
    int get_rnk(int root, int x) {
        int cnt = 0, u = root;
        while (u) {
            if (tr[u].v < x) {
                cnt += 1 + tr[ tr[u].s[0] ].sz;
                u = tr[u].s[1];
            }
            else
                u = tr[u].s[0];
        }
        return cnt - 1; // exclude the -inf sentinel
    }

    // Insert v as a new node and splay it to the root. Duplicates are allowed;
    // a value equal to a node's value descends into that node's left subtree.
    // Ancestors' sz is bumped on the way down; the splay then recomputes sizes
    // along the path.
    void ins(int &root, int v) {
        int u = root, p = 0, k = 0;
        while (u) {
            tr[u].sz++;
            k = tr[u].v < v;
            p = u;
            u = tr[u].s[k];
        }
        u = ++idx;
        tr[u] = Node(v, p);
        if (p) tr[p].s[k] = u;
        splay(root, u, 0);
    }

    // Remove one occurrence of v, which must be present in the tree.
    void del(int &root, int v) {
        int u = root;
        while (u && tr[u].v != v)
            u = tr[u].s[tr[u].v < v];
        // Check u itself: the null node's v is 0, so comparing values alone
        // would wrongly accept deleting a missing 0.
        assert(u != 0 && tr[u].v == v);
        splay(root, u, 0);
        int l = tr[u].s[0], r = tr[u].s[1];
        if (!l || !r) {
            // At most one child: it simply becomes the new root.
            root = l + r;
            tr[root].p = 0;
        }
        else {
            // l = in-order predecessor of u, r = in-order successor.
            while (tr[l].s[1]) l = tr[l].s[1];
            while (tr[r].s[0]) r = tr[r].s[0];
            // With l at the root and r just below it, u is r's childless
            // left child, so it can be cut off directly.
            splay(root, l, 0);
            splay(root, r, l);
            tr[r].s[0] = 0;
            push_up(r);
            push_up(l);
        }
    }

    // Largest value strictly less than x (-inf if none other than the sentinel).
    int get_pre(int root, int x) {
        int ans = -inf, u = root;
        while (u) {
            if (tr[u].v < x)
                ans = tr[u].v, u = tr[u].s[1];
            else
                u = tr[u].s[0];
        }
        return ans;
    }

    // Smallest value strictly greater than x (inf if none other than the sentinel).
    int get_suc(int root, int x) {
        int ans = inf, u = root;
        while (u) {
            if (tr[u].v > x)
                ans = tr[u].v, u = tr[u].s[0];
            else
                u = tr[u].s[1];
        }
        return ans;
    }

} splay_t;

// n: sequence length, m: number of operations (both read in main).
// a[1..n]: current sequence values, 1-indexed; kept in sync by op 3.
int n, m, a[maxn];

// Segment tree over positions 1..n. Every node owns a splay tree (in splay_t)
// holding all values of its interval plus the -inf / inf sentinels.
struct SegmentTree {
    // tr[u] = root index, inside splay_t, of node u's splay tree. seg_t has
    // static storage, so every entry starts at 0 (the empty tree).
    int tr[maxn<<2];

    #define lson l, mid, u<<1
    #define rson mid+1, r, u<<1|1

    // Build node u covering a[l..r]: insert both sentinels and every value of
    // the interval into its splay tree, then recurse into the children.
    void build(int l, int r, int u) {
        splay_t.ins(tr[u], -inf);
        splay_t.ins(tr[u], inf);
        for (int i = l; i <= r; i++)
            splay_t.ins(tr[u], a[i]);
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(lson);
        build(rson);
    }

    // Number of values in a[L..R] strictly less than x. The caller adds 1 to
    // get the rank of x. Each per-node count already excludes that node's
    // -inf sentinel.
    int get_rnk(int L, int R, int x, int l, int r, int u) {
        if (L <= l && r <= R)
            return splay_t.get_rnk(tr[u], x);   // query the covered node exactly once
        int res = 0, mid = (l + r) >> 1;
        if (L <= mid) res += get_rnk(L, R, x, lson);
        if (R > mid) res += get_rnk(L, R, x, rson);
        return res;
    }

    // Position p changes from x to y: remove one x and insert y in every splay
    // tree on the root-to-leaf path. x must be the current value a[p].
    void update(int p, int x, int y, int l, int r, int u) {
        splay_t.del(tr[u], x);
        splay_t.ins(tr[u], y);
        if (l == r) return;
        int mid = (l + r) >> 1;
        (p <= mid) ? update(p, x, y, lson) : update(p, x, y, rson);
    }

    // Predecessor: largest value in a[L..R] strictly less than x, or -inf.
    int get_pre(int L, int R, int x, int l, int r, int u) {
        if (L <= l && r <= R)
            return splay_t.get_pre(tr[u], x);
        int res = -inf, mid = (l + r) >> 1;
        if (L <= mid) res = max(res, get_pre(L, R, x, lson));
        if (R > mid) res = max(res, get_pre(L, R, x, rson));
        return res;
    }

    // Successor: smallest value in a[L..R] strictly greater than x, or inf.
    int get_suc(int L, int R, int x, int l, int r, int u) {
        if (L <= l && r <= R)
            return splay_t.get_suc(tr[u], x);
        int res = inf, mid = (l + r) >> 1;
        if (L <= mid) res = min(res, get_suc(L, R, x, lson));
        if (R > mid) res = min(res, get_suc(L, R, x, rson));
        return res;
    }

    // Keep the child-range helper macros local to this struct.
    #undef lson
    #undef rson
} seg_t;

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        scanf("%d", a+i);
    seg_t.build(1, n, 1);
    while (m--) {
        int op, l, r, x, p, k;
        scanf("%d", &op);
        if (op == 1) {  // 1 l r x
            scanf("%d%d%d", &l, &r, &x);
            int ans = seg_t.get_rnk(l, r, x, 1, n, 1) + 1;
            printf("%d\n", ans);
        }
        else if (op == 2) { // 2 l r k
            scanf("%d%d%d", &l, &r, &k);
            int L = 0, R = inf, res = -1;
            while (L <= R) {
                int mid = L + R >> 1;
                int cnt = seg_t.get_rnk(l, r, mid, 1, n, 1);
                if (cnt <= k-1)
                    res = mid, L = mid + 1;
                else if (cnt > k-1)
                    R = mid - 1;
            }
            printf("%d\n", res);
        }
        else if (op == 3) { // 3 p x
            scanf("%d%d", &p, &x);
            seg_t.update(p, a[p], x, 1, n, 1);
            a[p] = x;
        }
        else if (op == 4) { // 4 l r x
            scanf("%d%d%d", &l, &r, &x);
            int ans = seg_t.get_pre(l, r, x, 1, n, 1);
            printf("%d\n", ans);
        }
        else {  // 5 l r x
            scanf("%d%d%d", &l, &r, &x);
            int ans = seg_t.get_suc(l, r, x, 1, n, 1);
            printf("%d\n", ans);
        }
    }
    return 0;
}
posted @ 2025-12-31 16:42  quanjun  阅读(1)  评论(0)    收藏  举报