洛谷U644824 简单的平衡树问题 题解 splay tree 模板题

题目链接:https://www.luogu.com.cn/problem/U644824

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5 + 5;

// Array-backed splay tree keyed by int, storing a multiset (equal keys are
// kept as separate nodes). Index 0 is the null sentinel: tr[0].sz stays 0,
// so size lookups on missing children need no special case.
struct Node {
    int s[2], p, v, sz;  // s[0]/s[1]: left/right child, p: parent, v: key, sz: subtree size

    // Reset this slot as a detached leaf holding key _v with parent _p.
    void init(int _v, int _p) {
        s[0] = s[1] = 0;
        v = _v;
        p = _p;
        sz = 1;
    }

} tr[maxn];
int root, idx, n, Q, a[maxn];
// Stack (1-based) of node indices freed by del(); ins() reuses them before
// taking a fresh slot, so the pool only ever has to hold the live elements.
int recycled[maxn], rec_top;

// Recompute tr[u].sz from its two children.
void push_up(int u) {
    int l = tr[u].s[0], r = tr[u].s[1];
    tr[u].sz = tr[l].sz + tr[r].sz + 1;
}

// Make u child k (0 = left, 1 = right) of p. Writes to the sentinel 0 are
// skipped so tr[0] never accumulates stale child/parent links.
void f_s(int p, int u, int k) {
    if (p) tr[p].s[k] = u;
    if (u) tr[u].p = p;
}

// Single rotation lifting x above its parent y; x's grandparent z adopts x.
void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;            // side of y that x hangs on
    f_s(z, x, tr[z].s[1] == y);         // x takes y's place under z
    f_s(y, tr[x].s[k^1], k);            // x's inner subtree moves over to y
    f_s(x, y, k^1);                     // y becomes x's child
    push_up(y), push_up(x);             // y first: it now sits below x
}

// Rotate x upward until its parent is k (k == 0 makes x the new root).
void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            // same-side (zig-zig): rotate y first; zig-zag: rotate x twice
            (tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

// Return the k-th smallest key (1-based), or -1 if k is outside [1, size].
// The node reached is splayed to the root. Every access must splay, or the
// amortized O(log n) bound is lost: sorted inserts build a chain, and
// repeated queries at its bottom would each cost O(n).
int get_k(int k) {
    int u = root, last = 0;
    while (u) {
        last = u;
        int ls = tr[tr[u].s[0]].sz;
        if (k <= ls) u = tr[u].s[0];
        else if (k == ls + 1) {
            splay(u, 0);
            return tr[u].v;
        }
        else k -= ls + 1, u = tr[u].s[1];
    }
    if (last) splay(last, 0);
    return -1;
}

// Return 1 + (number of stored keys strictly less than x).
// Splays the last node visited, to keep the amortized bound.
int get_rnk(int x) {
    int cnt = 1, u = root, last = 0;
    while (u) {
        last = u;
        if (tr[u].v < x) {
            cnt += 1 + tr[ tr[u].s[0] ].sz;
            u = tr[u].s[1];
        }
        else
            u = tr[u].s[0];
    }
    if (last) splay(last, 0);
    return cnt;
}

// Insert key v (duplicates allowed; equal keys descend left) and splay it
// to the root. Subtree sizes along the search path are bumped on the way down.
void ins(int v) {
    int u = root, p = 0, k = 0;
    while (u) {
        tr[u].sz++;
        k = tr[u].v < v;
        p = u;
        u = tr[u].s[k];
    }
    u = rec_top ? recycled[rec_top--] : ++idx;  // prefer a freed slot
    tr[u].init(v, p);
    if (p) tr[p].s[k] = u;
    splay(u, 0);
}

// Remove one node whose key equals v. If no such key exists, this is a
// no-op; the last node visited is still splayed.
void del(int v) {
    int u = root, last = 0;
    while (u && tr[u].v != v) {
        last = u;
        u = tr[u].s[tr[u].v < v];
    }
    if (!u) {
        // Not found. The original splay(0, 0) here could follow a stale
        // sentinel parent pointer and corrupt the tree.
        if (last) splay(last, 0);
        return;
    }
    splay(u, 0);
    int l = tr[u].s[0], r = tr[u].s[1];
    if (!l || !r) {
        // At most one child: it (or nothing) becomes the new root.
        root = l + r;
        if (root) tr[root].p = 0;
    }
    else {
        // Splay u's in-order predecessor l to the root and its successor r
        // directly below l. The only node strictly between them is u, which
        // is now r's childless left child.
        while (tr[l].s[1]) l = tr[l].s[1];
        while (tr[r].s[0]) r = tr[r].s[0];
        splay(l, 0);
        splay(r, l);
        tr[r].s[0] = 0;
        push_up(r);
        push_up(l);
    }
    recycled[++rec_top] = u;  // u is unreachable now; make its slot reusable
}

// Return the largest key strictly less than x, or -1 if there is none.
// NOTE(review): -1 is ambiguous if -1 can itself be a key; confirm against
// the problem statement.
int get_pre(int x) {
    int ans = -1, u = root, last = 0;
    while (u) {
        last = u;
        if (tr[u].v < x)
            ans = tr[u].v, u = tr[u].s[1];
        else
            u = tr[u].s[0];
    }
    if (last) splay(last, 0);
    return ans;
}

// Return the smallest key strictly greater than x, or -1 if there is none.
// (Same -1 ambiguity caveat as get_pre.)
int get_suc(int x) {
    int ans = -1, u = root, last = 0;
    while (u) {
        last = u;
        if (tr[u].v > x)
            ans = tr[u].v, u = tr[u].s[0];
        else
            u = tr[u].s[1];
    }
    if (last) splay(last, 0);
    return ans;
}

int main() {
    scanf("%d%d", &n, &Q);
    // Read the initial sequence a[1..n]; every element is also inserted.
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        ins(a[i]);
    }
    while (Q--) {
        int op;
        scanf("%d", &op);
        switch (op) {
        case 1: {               // 1 p x: replace a[p] by x in both array and tree
            int p, x;
            scanf("%d%d", &p, &x);
            del(a[p]);
            a[p] = x;
            ins(a[p]);
            break;
        }
        case 2: {               // 2 x: print 1 + (number of keys < x)
            int x;
            scanf("%d", &x);
            printf("%d\n", get_rnk(x));
            break;
        }
        case 3: {               // 3 k: print the k-th smallest key
            int k;
            scanf("%d", &k);
            printf("%d\n", get_k(k));
            break;
        }
        case 4: {               // 4 x: print the largest key < x
            int x;
            scanf("%d", &x);
            printf("%d\n", get_pre(x));
            break;
        }
        default: {              // 5 x (any other opcode is handled the same way)
            int x;
            scanf("%d", &x);
            printf("%d\n", get_suc(x));
            break;
        }
        }
    }
    return 0;
}
posted @ 2025-12-26 04:04  quanjun  阅读(3)  评论(0)    收藏  举报