splay

阅读前请先了解二叉搜索树。(有 左儿子 \(<=\) 自己 \(<=\) 右儿子 的性质的一棵二叉树) splay的用途即在于保持其平衡,从而实现一次操作均摊 \(O(\log n)\)


用途:

维护各种奇怪的操作。如支持区间翻转等。本来用途是当更强大的 set/multiset 使用。

核心操作: splay()

该操作可以将一个点转到根。

转的方式如下,其中 \(a\) 是要被转的节点:

前两种应该是好理解的,大家手模几组小的例子即可。我们重点分析,为何我们要先转 \(f\) 呢?

我们不妨构造一条如下的链:

这也太不平衡了!那先转父亲呢?

所以我们要特判这种情况,并先转父亲。

实现

有了 splay 操作,我们基本做完了。只需要在二叉搜索树的代码基础上稍作修改。

我觉得删除是最抽象的操作,所以细说一下。原理是先将要删的节点转到根,然后我们找到他的左子树里的最大值。
再把左子树里的这个最大值转到根。

把这个最大值转到根的最后一步之前树长这样:

再转一步就变成

然后我们只要把要删的点的父亲和右儿子连起来就可以了。

下面给出代码。

#include<cstdio>
using namespace std;
const int mn=1e5+5;
struct SPLAY
{
    int rev[mn],tp,ncnt;
    int alloc()
    {
        if(tp)return rev[tp--];
        return ++ncnt;
    }
    struct node
    {
        int ch[2],fa;
        int sz,v;
        node(int x=0){v=x;sz=1;fa=ch[0]=ch[1]=0;}
    }a[mn];
    int root;
    bool isr(int x){return (a[a[x].fa].ch[1]==x);}
    void upd(int x){a[x].sz=1+a[a[x].ch[0]].sz+a[a[x].ch[1]].sz;}
    int rnk(int x){return a[a[x].ch[0]].sz+1;}
    void rot(int x)
    {
        bool k=isr(x);
        int y=a[x].fa,z=a[y].fa,w=a[x].ch[!k];
        if(z)a[z].ch[isr(y)]=x;
        else root=x;
        a[x].fa=z;
        a[y].fa=x;
        a[x].ch[!k]=y;
        a[y].ch[k]=w;
        if(w)a[w].fa=y;
        upd(x);upd(y);
    }
    void splay(int x)
    {
        while(root!=x)
        {
            if(a[x].fa!=root)rot((isr(a[x].fa)^isr(x))?x:a[x].fa);
            rot(x);
        }
    }
    void ins(int v)
    {
        int x=alloc();
        a[x]=node(v);
        if(!root)return void(root=x);
        int t=root;
        while(a[t].ch[a[t].v<v])t=a[t].ch[a[t].v<v];
        a[t].ch[a[t].v<v]=x;
        a[x].fa=t;
        upd(t);
        splay(x);
    }
    void del(int v)
    {
        int t=root;
        while(a[t].v!=v)t=a[t].ch[a[t].v<v];
        splay(t);
        int o=a[t].ch[0];
        while(a[o].ch[1])o=a[o].ch[1];
        splay(o);
        a[o].ch[1]=a[t].ch[1];
        a[a[t].ch[1]].fa=o;
        rev[++tp]=t;
    }
    int rank(int v)
    {
        int t=root,res=0,lst=root;
        while(t)
        {
            lst=t;
            if(a[t].v<v)res+=rnk(t);
            t=a[t].ch[a[t].v<v];
        }
        splay(lst);
        return res;
    }
    int kth(int k)
    {
        ++k;
        int t=root;
        while(rnk(t)!=k)
        {
            perr("%d %d %d\n",t,rnk(t),k);
            if(rnk(t)>k)t=a[t].ch[0];
            else
            {
                k-=rnk(t);
                t=a[t].ch[1];
            }
        }
        splay(t);
        return a[t].v;
    }
    int pre(int v)
    {
        int t=root,res=0,lst=root;
        while(t)
        {
            lst=t;
            if(a[t].v<v)res=a[t].v;
            t=a[t].ch[a[t].v<v];
        }
        splay(lst);
        return res;
    }
    int nxt(int v)
    {
        int t=root,res=0,lst=root;
        while(t)
        {
            lst=t;
            if(a[t].v>v)res=a[t].v;
            t=a[t].ch[a[t].v<=v];//注意这里与别的地方不同。当a[t].v==v时我们应去右子节点。
        }
        splay(lst);
        return res;
    }
    SPLAY()
    {
        root=ncnt=tp=0;
        ins(0x3f3f3f3f);
        ins(-0x3f3f3f3f);
        a[0].sz=0;
    }
}spl;
int n;
int main()
{
    int op,x;
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%d %d",&op,&x);
        if(op==1)spl.ins(x);
        if(op==2)spl.del(x);
        if(op==3)printf("%d\n",spl.rank(x));
        if(op==4)printf("%d\n",spl.kth(x));
        if(op==5)printf("%d\n",spl.pre(x));
        if(op==6)printf("%d\n",spl.nxt(x));
    }
}
posted @ 2025-02-25 22:02  ikusiad  阅读(15)  评论(0)    收藏  举报