// 平衡二叉树板子(转载) — Splay tree (self-balancing BST) template, reposted.

#include <iostream>
#include <cstdio>
// Capacity of the static node pool. Index 0 is reserved as the "no node"
// sentinel, so at most MAXN - 1 distinct nodes can ever be allocated.
const int MAXN = 100010;
using namespace std;

// root: pool index of the current tree root (0 means the tree is empty).
// tot:  highest pool index handed out so far; slots are never recycled.
int root, tot;

// One tree node. All links are indices into t[]; 0 stands for "none".
struct Splay
{
    int fa;    // parent index (0 for the root)
    int ch[2]; // ch[0] = left child, ch[1] = right child
    int val;   // key stored in this node
    int cnt;   // how many copies of val are stored (duplicates share a node)
    int size;  // total number of copies in this node's subtree
};
Splay t[MAXN];
// Recompute node x's subtree size: its own multiplicity plus the sizes of its
// two children. The children's sizes must already be correct.
void maintain(int x)
{
    const int left = t[x].ch[0];
    const int right = t[x].ch[1];
    t[x].size = t[x].cnt + t[left].size + t[right].size;
}
// Which side of its parent node x hangs on: true for the right child,
// false otherwise.
bool get(int x)
{
    const int parent = t[x].fa;
    return t[parent].ch[1] == x;
}
// Zero every field of node x, detaching it from the tree. The slot itself is
// not returned to any free list; insert() always takes a fresh index (++tot).
void clear(int x)
{
    t[x].fa = 0;
    t[x].ch[0] = 0;
    t[x].ch[1] = 0;
    t[x].val = 0;
    t[x].cnt = 0;
    t[x].size = 0;
}
void rotate(int x)
{
    int y = t[x].fa, z = t[y].fa, chk = get(x);
    t[y].ch[chk] = t[x].ch[chk ^ 1];
    if (t[x].ch[chk ^ 1])
        t[t[x].ch[chk ^ 1]].fa = y;
    t[x].ch[chk ^ 1] = y;
    t[y].fa = x;
    t[x].fa = z;
    if (z)
        t[z].ch[y == t[z].ch[1]] = x;
    maintain(y);
    maintain(x);
}
// Move node x all the way to the root with zig / zig-zig / zig-zag steps and
// record it as the new root.
void splay(int x)
{
    while (int f = t[x].fa)
    {
        // With a grandparent present: rotate the parent first when x and its
        // parent lean the same way (zig-zig), otherwise rotate x (zig-zag).
        if (t[f].fa)
            rotate(get(x) == get(f) ? f : x);
        rotate(x);
    }
    root = x;
}
// Insert one copy of k. An existing node holding k just has its multiplicity
// bumped; otherwise a fresh node is attached as a leaf. Either way the
// affected node is splayed to the root afterwards.
void insert(int k)
{
    if (!root)
    {
        // Empty tree: the new node simply becomes the root.
        t[++tot].val = k;
        t[tot].cnt++;
        root = tot;
        maintain(root);
        return;
    }

    // Walk down until we find k or step off the tree below `parent`.
    int parent = 0, node = root;
    while (node && t[node].val != k)
    {
        parent = node;
        node = t[node].ch[t[node].val < k];
    }

    if (node)
    {
        // k already present.
        t[node].cnt++;
        maintain(node);
        maintain(parent);
        splay(node);
        return;
    }

    // k absent: hang a new leaf on the side of `parent` where k belongs.
    const int fresh = ++tot;
    t[fresh].val = k;
    t[fresh].cnt++;
    t[fresh].fa = parent;
    t[parent].ch[t[parent].val < k] = fresh;
    maintain(fresh);
    maintain(parent);
    splay(fresh);
}
int rnk(int k)
{
    int res = 0, cur = root;
    while (1)
    {
        if (k < t[cur].val)
            cur = t[cur].ch[0];
        else
        {
            res += t[t[cur].ch[0]].size;
            if (k == t[cur].val)
            {
                splay(cur);
                return res + 1;
            }
            res += t[cur].cnt;
            cur = t[cur].ch[1];
        }
    }
}
// Return the k-th smallest stored value (1-based, counting duplicates) and
// splay its node to the root. A request with k <= 0 yields the minimum, as
// before.
// If the tree is empty or k exceeds the number of stored values, returns 0
// without touching the tree. Previously the search ran off the tree into
// sentinel node 0 and looped forever in that case.
int kth(int k)
{
    if (!root || k > t[root].size)
        return 0;
    int cur = root;
    while (1)
    {
        const int left = t[cur].ch[0];
        if (left && k <= t[left].size)
        {
            cur = left;
            continue;
        }
        // Skip the left subtree and this node's copies.
        k -= t[left].size + t[cur].cnt;
        if (k <= 0)
        {
            splay(cur);
            return t[cur].val;
        }
        cur = t[cur].ch[1];
    }
}
// Predecessor of the current root: the rightmost node of the root's left
// subtree, i.e. the largest stored value below the root's value.
// If it exists, it is splayed to the root and its index returned. Otherwise
// 0 is returned and the tree is left as is.
int pre()
{
    int best = t[root].ch[0];
    if (best == 0)
        return 0;
    while (int next = t[best].ch[1])
        best = next;
    splay(best);
    return best;
}
// Successor of the current root: the leftmost node of the root's right
// subtree, i.e. the smallest stored value above the root's value.
// If it exists, it is splayed to the root and its index returned. Otherwise
// 0 is returned and the tree is left as is.
int nxt()
{
    int best = t[root].ch[1];
    if (best == 0)
        return 0;
    while (int next = t[best].ch[0])
        best = next;
    splay(best);
    return best;
}
void del(int k)
{
    rnk(k);
    if (t[root].cnt > 1)
    {
        t[root].cnt--;
        maintain(root);
        return;
    }
    if (!t[root].ch[0] && !t[root].ch[1])
    {
        clear(root);
        root = 0;
        return;
    }
    if (!t[root].ch[0])
    {
        int cur = root;
        root = t[root].ch[1];
        t[root].fa = 0;
        clear(cur);
        return;
    }
    if (!t[root].ch[1])
    {
        int cur = root;
        root = t[root].ch[0];
        t[root].fa = 0;
        clear(cur);
        return;
    }
    int cur = root;
    int x = pre();
    t[t[cur].ch[1]].fa = root;
    t[root].ch[1] = t[cur].ch[1];
    clear(cur);
    maintain(root);
}
int n, op, x;
// Driver: reads n, then n lines "op x", and applies them to the splay tree:
//   1 insert x          2 delete one x        3 print rank of x
//   4 print k-th (x-th) smallest value
//   5 print predecessor of x (largest value < x)
//   6 (any other op) print successor of x (smallest value > x)
// For 5/6, x is inserted temporarily so the search can start from it, then
// removed again. If no predecessor/successor exists, t[0].val (0) is printed.
int main()
{
    if (scanf("%d", &n) != 1)
        return 0;
    // `n-- > 0` (not `n--`) so a negative count does not loop indefinitely.
    while (n-- > 0)
    {
        // Stop on truncated/malformed input instead of replaying the last op.
        if (scanf("%d%d", &op, &x) != 2)
            break;
        switch (op)
        {
        case 1:
            insert(x);
            break;
        case 2:
            del(x);
            break;
        case 3:
            printf("%d\n", rnk(x));
            break;
        case 4:
            printf("%d\n", kth(x));
            break;
        case 5:
            insert(x);
            printf("%d\n", t[pre()].val);
            del(x);
            break;
        default:
            insert(x);
            printf("%d\n", t[nxt()].val);
            del(x);
            break;
        }
    }
    return 0;
}
// posted @ 2022-10-06 20:07  梦歌  阅读(24)  评论(0)    收藏  举报  (original blog footer)