【模板】普通平衡树

题目链接

【传送门】

splay代码

#include <bits/stdc++.h>
#define N 100005
#define inf 2147483647
using namespace std;
// Fast signed-integer reader from stdin.
// Skips every non-digit character before the number, remembering whether a '-'
// was seen among them. It then accumulates decimal digits into x.
// If EOF is reached before any digit, x is left as 0 instead of spinning forever.
template <typename T>
inline void read(T &x) {
    x = 0;
    T fl = 1;
    int ch = getchar();  // int, not char: must be able to hold EOF
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') fl = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= fl;
}
// Splay tree over ints that stores duplicates as a per-node multiplicity (cnt).
// Node 0 is the null sentinel. Its fields are zeroed by the constructor and are
// never written afterwards (rotate() skips it), so tr[0].sz == 0 and tr[0].ch == {0, 0}.
// del() needs x to have both a strictly smaller and a strictly larger key in the
// tree; main() provides this by inserting -inf and +inf guard values first.
struct Splay {
    int rt, tot;  // root index (0 = empty tree) and number of nodes allocated so far
    struct node {
        // value, parent, multiplicity, subtree size (counting multiplicities), children
        int val, fa, cnt, sz, ch[2];
        // Turn this slot into a fresh leaf holding one copy of x under parent ft.
        void init(int x, int ft) {
            fa = ft;
            val = x;
            ch[1] = ch[0] = 0;
            sz = cnt = 1;
        }
    }tr[N];
    Splay() {
        memset(tr, 0, sizeof(tr));
        rt = tot = 0;
    }
    // Recompute nod's subtree size from its children and its own multiplicity.
    void pushup(int nod) {
        tr[nod].sz = tr[tr[nod].ch[0]].sz + tr[tr[nod].ch[1]].sz + tr[nod].cnt;
    }
    // Single rotation: lift nod one level above its parent.
    void rotate(int nod) {
        int fa = tr[nod].fa, gf = tr[fa].fa, k = tr[fa].ch[1] == nod;
        int mid = tr[nod].ch[k ^ 1];  // subtree that moves from nod over to fa
        // Never write into node 0: it must stay an empty sentinel for pre()/suc()/kth().
        if (gf) tr[gf].ch[tr[gf].ch[1] == fa] = nod;
        tr[nod].fa = gf;
        tr[fa].ch[k] = mid;
        if (mid) tr[mid].fa = fa;
        tr[nod].ch[k ^ 1] = fa;
        tr[fa].fa = nod;
        pushup(fa);  // fa is now below nod, so update it first
        pushup(nod);
    }
    // Rotate nod upward until its parent is goal (goal == 0 makes nod the root).
    void splay(int nod, int goal) {
        while (tr[nod].fa != goal) {
            int fa = tr[nod].fa, gf = tr[fa].fa;
            if (gf != goal) {
                // zig-zag: rotate nod twice; zig-zig: rotate the parent first
                if ((tr[gf].ch[0] == fa) ^ (tr[fa].ch[0] == nod)) rotate(nod);
                else rotate(fa);
            }
            rotate(nod);
        }
        if (goal == 0) rt = nod;
    }
    // Splay the node holding x to the root. If x is absent, the last node on the
    // search path is splayed instead; that node is x's predecessor or successor.
    void find(int x)  {
        int u = rt;
        if (!u) return;
        while (tr[u].ch[x > tr[u].val] && x != tr[u].val) {
            u = tr[u].ch[x > tr[u].val];
        }
        splay(u, 0);
    }
    // Insert one copy of x and splay its node to the root.
    void ins(int x) {
        int u = rt, ft = 0;
        while (u && tr[u].val != x) {
            ft = u;
            u = tr[u].ch[x > tr[u].val];
        }
        if (u) {
            tr[u].cnt ++;
            // If u is already the root, splay() below performs no rotation and
            // therefore no pushup, so refresh the size here.
            pushup(u);
        }
        else {
            u = ++ tot;
            if (ft) tr[ft].ch[x > tr[ft].val] = u;
            tr[u].init(x, ft);
        }
        splay(u, 0);
    }
    // Index of the node with the largest value strictly less than x, or 0 if none.
    int pre(int x)  {
        find(x);
        int u = rt;
        if (tr[u].val < x) return u;
        u = tr[u].ch[0];
        while (tr[u].ch[1]) u = tr[u].ch[1];
        return u;
    }
    // Index of the node with the smallest value strictly greater than x, or 0 if none.
    int suc(int x) {
        find(x);
        int u = rt;
        if (tr[u].val > x) return u;
        u = tr[u].ch[1];
        while (tr[u].ch[0]) u = tr[u].ch[0];
        return u;
    }
    // Remove one copy of x (no-op if x is absent).
    void del(int x) {
        int lst = pre(x), nxt = suc(x);
        // Without both neighbours, splay(0, ...) would reset rt to 0 and lose the tree.
        if (!lst || !nxt) return;
        splay(lst, 0);
        splay(nxt, lst);
        // Nothing lies strictly between lst and nxt except x, so nxt's left child
        // is exactly the node holding x (or 0 when x is absent).
        int del = tr[nxt].ch[0];
        if (tr[del].cnt > 1) {
            tr[del].cnt --;
            splay(del, 0);
        }
        else {
            tr[nxt].ch[0] = 0;
            // The removed node was still counted in nxt and lst; refresh bottom-up.
            pushup(nxt);
            pushup(lst);
        }
    }
    // Value of the x-th smallest element (1-based, counting multiplicities);
    // returns 0 when x is out of range.
    int kth(int x) {
        int u = rt;
        if (x < 1 || tr[u].sz < x) return 0;
        while (1) {
            int lc = tr[u].ch[0];
            if (x > tr[lc].sz + tr[u].cnt) {
                x -= tr[lc].sz + tr[u].cnt;
                u = tr[u].ch[1];
            }
            else if (tr[lc].sz >= x) u = lc;
            else {
                // Splay the accessed node so repeated deep lookups keep the
                // amortised O(log n) bound.
                splay(u, 0);
                return tr[u].val;
            }
        }
    }
}sl;
// Reads n operations and applies them to the global splay tree sl:
//   1 x: insert x          2 x: delete one copy of x
//   3 x: rank of x         4 x: x-th smallest value
//   5 x: predecessor of x  6 x: successor of x
int main() {
    int n; read(n);
    // Guard values: every real key then has a strict predecessor and successor,
    // which Splay::del relies on. The -inf guard also shifts every rank by one.
    sl.ins(-inf);
    sl.ins(inf);
    for (int _t = 1; _t <= n; _t ++) {
        int opt, x; read(opt); read(x);
        if (opt == 1) sl.ins(x);
        if (opt == 2) sl.del(x);
        if (opt == 3) {
            // rank = (#keys < x) + 1. After find(x) the root is x or its nearest
            // neighbour, so every key < x sits in the root's left subtree. The only
            // exception is the root itself when it is x's predecessor. The -inf guard
            // inside the left subtree supplies the "+1".
            sl.find(x);
            int r = sl.rt;
            int rank = sl.tr[sl.tr[r].ch[0]].sz;
            if (sl.tr[r].val < x) rank += sl.tr[r].cnt;
            printf("%d\n", rank);
        }
        if (opt == 4) printf("%d\n", sl.kth(x + 1));  // +1 skips the -inf guard
        if (opt == 5) printf("%d\n", sl.tr[sl.pre(x)].val);
        if (opt == 6) printf("%d\n", sl.tr[sl.suc(x)].val);
    }
    return 0;
}
posted @ 2019-03-23 09:49 chhokmah 阅读(...) 评论(...) 编辑 收藏