平衡树 Splay

//平衡树 Splay
//利用双旋 防止树退化成链
//时间比Treap慢log(n)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
using namespace std;
struct uio{
    int son[2],fa,num,tim,siz;//左右儿子,父节点,点值,出现次数,子树大小+自己大小 
}spl[100001];
int n,size,root;//size树的大小 
void clear(int x)//清零 
{
    spl[x].son[0]=spl[x].son[1]=0;
    spl[x].fa=0;
    spl[x].siz=0;
    spl[x].tim=0;
    spl[x].num=0;
}
int get(int x)//判断左右儿子 
{
    return spl[spl[x].fa].son[1]==x;
}
void update(int x)//更新子树大小 
{
    if(x)
    {
        spl[x].siz=spl[x].tim;
        if(spl[x].son[0])
            spl[x].siz+=spl[spl[x].son[0]].siz;
        if(spl[x].son[1])
            spl[x].siz+=spl[spl[x].son[1]].siz;
    }
}
void rotate(int x)
{
    int old=spl[x].fa,oldf=spl[old].fa,which=get(x);
    spl[old].son[which]=spl[x].son[which^1];
    spl[spl[old].son[which]].fa=old;
    spl[old].fa=x;
    spl[x].son[which^1]=old;
    spl[x].fa=oldf;
    if(oldf)
        spl[oldf].son[spl[oldf].son[1]==old]=x;
    update(old);
    update(x);
}
void splay(int x)
{
    for(int f;f=spl[x].fa;rotate(x))
        if(spl[f].fa)
            rotate((get(x)==get(f)? f : x));
    root=x;
}
void insert(int x)
{
    if(!root)//插入第一个节点 
    {
        size++;
        spl[size].son[0]=spl[size].son[1]=0;
        spl[size].fa=0;
        spl[size].num=x;
        spl[size].tim=1;
        spl[size].siz=1;
        root=size;
        return;
    }
    int now=root,f=0;
    while(1)
    {
        if(spl[now].num==x)//已有此节点 
        {
            spl[now].tim++;
            update(now);
            update(f);
            splay(now);
            break;
        }
        f=now;
        now=spl[now].son[x>spl[now].num];
        if(now==0)//无此节点
        {
            size++;
            spl[size].son[0]=spl[size].son[1]=0;
            spl[size].num=x;
            spl[size].tim=1;
            spl[size].siz=1;
            spl[size].fa=f;
            spl[f].son[x>spl[f].num]=size;
            update(f);
            splay(size);
            break;
        } 
    }
}
int get_no(int x)
{
    int now=root,ans=0;
    while(1)
    {
        if(x<spl[now].num)//在左子树 
            now=spl[now].son[0];
        else//在右子树 
        {
            ans+=(spl[now].son[0]? spl[spl[now].son[0]].siz : 0);//判断左子树是否为空 
            if(x==spl[now].num)
            {
                splay(now);
                return ans+1;
            }
            ans+=spl[now].tim;
            now=spl[now].son[1];
        }
    }
}
int get_num(int x)
{
    int now=root;
    while(1)
    {
        if(spl[now].son[0]&&x<=spl[spl[now].son[0]].siz)//在左子树 
            now=spl[now].son[0];
        else
        {
            int temp=(spl[now].son[0]? spl[spl[now].son[0]].siz : 0)+spl[now].tim;
            if(x<=temp)//结果就为此点 
                return spl[now].num;
            //在右子树 
            x-=temp;
            now=spl[now].son[1]; 
        }
    }
}
int pre()//此操作将夹在一次插入和一次删除之间 插入时已将此点旋转到根 因此前驱即为左子树最右子节点
{
    int now=spl[root].son[0];
    while(spl[now].son[1])
        now=spl[now].son[1];
    return now;
}
int nxt()//此操作将夹在一次插入和一次删除之间 插入时已将此点旋转到根 因此后继即为右子树最左子节点
{
    int now=spl[root].son[1];
    while(spl[now].son[0])
        now=spl[now].son[0];
    return now;
}
void del(int x)
{
    int useless=get_no(x);//把x旋转到根 
    if(spl[root].tim>1)//出现次数大于一 
    {
        spl[root].tim--;
        update(root);
        return;
    }
    if(!spl[root].son[0]&&!spl[root].son[1])//无子节点
    {
        clear(root);//直接清零 
        root=0;
        return;
    }
    if(!spl[root].son[0])//仅有右子节点
    {
        int oldroot=root;
        root=spl[root].son[1];//右子节点接到根上 
        spl[root].fa=0;
        clear(oldroot);
        return;
    }
    if(!spl[root].son[1])//仅有左子节点
    {
        int oldroot=root;
        root=spl[root].son[0];//左子节点接到根上 
        spl[root].fa=0;
        clear(oldroot);
        return;
    }
    //左右子节点均有 
    //找到x的前驱,把它旋转到根,将x的右子树接到新根上 
    int leftbig=pre(),oldroot=root;
    splay(leftbig);
    spl[spl[oldroot].son[1]].fa=root;
    spl[root].son[1]=spl[oldroot].son[1];
    clear(oldroot);
    update(root);
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        if(u==1)
            insert(v);
        if(u==2)
            del(v);
        if(u==3)
            printf("%d\n",get_no(v));
        if(u==4)
            printf("%d\n",get_num(v));
        if(u==5)
        {
            insert(v);
            printf("%d\n",spl[pre()].num);
            del(v);
        }
        if(u==6)
        {
            insert(v);
            printf("%d\n",spl[nxt()].num);
            del(v);
        }
    }
    return 0;
}

 

posted @ 2018-07-08 18:40  radishえらい  阅读(139)  评论(0编辑  收藏  举报