Splay随手记

Splay(伸展树)

注意:Splay只有初始的时候满足BST的结构,翻转之后即不满足,但中序遍历依旧不变。

1.旋转方式:同Treap

2.核心:每操作一个节点,均将该节点旋转到树根

意义:每次操作平均时间复杂度是 \(O(logn)\) 的。(可严格证明:可能用势能分析?

3.核心操作实现(Splay函数)

Splay(x,k):将点x旋转到点k下面

①当其形成一条直线的时候

先转y,再转x(把y向上转,然后把x向上转)

每转一次,x会向上走两格

②当其存在拐点的时候

转x,再转x(就是把x转两次)

只有这样转才能保证 \(O(log n)\) 的复杂度!!!不可以瞎转!!!

写的时候就是一个迭代(递归),不停向上转即可。

4.其余操作

①插入

一般情况:把x插到叶节点,然后选择到根即可

特殊:将一个序列插到y的后面(保证中序遍历在y的后面): 找到y的后继z,将y转到根(Splay(y,0)),将z转到y的下面(Splay(z,y))-> z此刻左子树为空。直接把序列插到z的左子树即可。

②删除

删除一段:删除[l,r],找到前驱l-1,后继r+1。先将l转到根节点(Splay(l-1,0)),再把r+1转到l-1下面(Splay(r+1,l-1)),此时r+1的左子树即为所要删除的段。直接置为空即可。

③其他操作

与其他平衡树无区别

5.Splay如何去维护信息

找第k个数->维护size(每棵子树大小)

翻转区间->维护lazytag(近似于线段树的懒标记)

操作:

pushup():维护值:放到旋转函数的最后

root->size=root->right->size+root->left->size+1

pushdown():下传懒标记:所有递归之前都要下传

swap(root->left,root->right)
下传标记
清空当前标记

6.模版题:文艺平衡树

[link](P3391 【模板】文艺平衡树 - 洛谷)

struct Splay{
    struct NODE{
        int s[2];//左右二字
        int p;//父节点
        int v;//编号
        int flag;//懒标记:有没有翻转
        int size;//大小

        void init(int _v,int _p)
        {
            v=_v,p=_p;
            size=1;
        }
    }tr[N];

    int root=0,idx=0;
    //根节点 动态分配空间

    void pushup(int x)
    {
        tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
    }//维护值

    void pushdown(int x)
    {
        if(tr[x].flag)
        {
            swap(tr[x].s[0],tr[x].s[1]);//交换左右儿子

            tr[tr[x].s[0]].flag^=1;
            tr[tr[x].s[1]].flag^=1;//下传标记

            tr[x].flag=0;//当前标记清空
        }
    }//下传标记

    void rotate(int x)
    {
        int y=tr[x].p,z=tr[y].p;//找到祖先
        int k=tr[y].s[1]==x;//k表示x是y的左儿子还是右儿子:k=0->x是左儿子

        tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
        tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
        tr[x].s[k^1]=y,tr[y].p=x;

        pushup(y),pushup(x);
        //先更新y,再更新x:因为y是x的儿子
        //z的信息不用变
    }//把左右旋合在一起写

    void splay(int x,int k)//核心函数
    {
        while(tr[x].p!=k)
        {
            int y=tr[x].p,z=tr[y].p;
            if(z!=k)
                if((tr[y].s[1]==x)^(tr[z].s[1]==y))
                    rotate(x);//折线关系
                else    rotate(y);//直线关系
                rotate(x);
          //  cout<<tr[x].p<<" "<<k<<" ";
        }
        if(!k)  root=x;//如果是的话更新根
    }//保证了Splay的时间复杂度

    void insert(int v)
    {
        int u=root,p=0;
        while(u)    p=u,u=tr[u].s[v>tr[u].v];//快速判断在左/右儿子
        u=++idx;
        if(p)   tr[p].s[v>tr[p].v]=u;//如果有父亲,更新
        tr[u].init(v,p);//初始化新点
        splay(u,0);//为了保证log(n)的时间复杂度,必须转到根
    }//插入

    int get_k(int k)
    {
        int u=root;
        while(true)
        {
            pushdown(u);
            if(tr[tr[u].s[0]].size>=k)    u=tr[u].s[0];//在左子树里找
            else if(tr[tr[u].s[0]].size+1==k)   return u;//这里就是答案
            else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];//在右子树里面找
        }
        return -1;//找不到返回-1
    }//找第k个数

    void output(int u)
    {
        pushdown(u);
        if(tr[u].s[0])  output(tr[u].s[0]);//先输出左儿子
        if(tr[u].v>=1 && tr[u].v<=n)    cout<<tr[u].v<<" ";//输出,判断是不是哨兵
        if(tr[u].s[1])  output(tr[u].s[1]);
    }//输出中序遍历

}Tr;

7.简单应用

①:郁闷的出纳员

link

#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N=1e5+100,INF=1e9;

int n,m;

struct Spaly{
    struct NODE{
        int s[2],p,v;
        int size;

        void init(int _v,int _p)
        {
            v=_v,p=_p;
            size=1;
        }
    }tr[N];

    int root=0,idx=0,delta=0;

    void pushup(int x)
    {
        tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
    }

    void rotate(int x)
    {
        int y=tr[x].p,z=tr[y].p;
        int k=tr[y].s[1]==x;
        tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
        tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
        tr[x].s[k^1]=y,tr[y].p=x;
        pushup(y),pushup(x);
    }

    void splay(int x,int k)
    {
        while(tr[x].p!=k)
        {
            int y=tr[x].p,z=tr[y].p;
            if(z!=k)
                if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
                else    rotate(y);
            rotate(x);
        }
        if(!k)  root=x;
    }

    int insert(int v)
    {
        int u=root,p=0;
        while(u) p=u,u=tr[u].s[v>tr[u].v];
        u=++idx;
        if(p)   tr[p].s[v>tr[p].v]=u;
        tr[u].init(v,p);
        splay(u,0);
        return u;
    }

    int get(int v)
    {
        int u=root,res;
        while(u)
        {
            if(tr[u].v>=v)  res=u,u=tr[u].s[0];
            else u=tr[u].s[1];
        }
        return res;
    }

    int get_k(int k)
    {
        int u=root;
        while(u)
        {
            if(tr[tr[u].s[0]].size>=k)  u=tr[u].s[0];
            else if(tr[tr[u].s[0]].size+1==k)   return tr[u].v;
            else    k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
        }
        return -1;
    }
}Tr;

signed main()
{
    cin>>n>>m;
    int L=Tr.insert(-INF),R=Tr.insert(INF);//建两个哨兵

    int tot=0;
    while(n--)
    {
        char op;
        int k;
        cin>>op>>k;
        if(op=='I')
            if(k>=m)    k-=Tr.delta,Tr.insert(k),tot++;
        if(op=='A') Tr.delta+=k;
        if(op=='S')
        {
            Tr.delta-=k;
            R=Tr.get(m-Tr.delta);
            Tr.splay(R,0),Tr.splay(L,R);
            Tr.tr[L].s[1]=0;
            Tr.pushup(L),Tr.pushup(R);
        }
        if(op=='F')
        {
            if(Tr.tr[Tr.root].size-2<k) cout<<-1<<endl;
            else    cout<<Tr.get_k(Tr.tr[Tr.root].size-k)+Tr.delta<<endl;
        }
    }

    cout<<tot-(Tr.tr[Tr.root].size-2)<<endl;

    return 0;
}

②:永无乡

link

本题涉及到了Splay的合并,强调Splay不可以直接做合并,要用启发式合并,用Splay维护集合大小,合并集合。(启发式合并理论:把节点个数少的合并到节点个数多的,时间复杂度为 \(O(nlog n)\)

关于合并暴力即可。每合并一个元素的时间复杂度为 \(O(log n)\) (插入),启发式合并的时间复杂度是 \(O(nlog n)\) ,总的操作时间复杂度就是 \(O(n log ^2 n)\)

思路很简单,代码很难调。。。

#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N=5e5+100;
//原本的大小是1e5,但是有插入,所以要加上m的大小

int n,m;

struct Splay{
    struct NODE{
        int s[2],p,v,id;
        int size;

        void init(int _v,int _id,int _p)
        {
            v=_v,id=_id,p=_p;
            size=1;
        }
    }tr[N];

    int root[N],idx;

    void pushup(int x)
    {
        tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
    }

    void rotate(int x)
    {
        int y=tr[x].p,z=tr[y].p;
        int k=tr[y].s[1]==x;
        tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
        tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
        tr[x].s[k^1]=y,tr[y].p=x;
        pushup(y),pushup(x);
    }

    void splay(int x,int k,int b)
    {
        while(tr[x].p!=k)
        {
            int y=tr[x].p,z=tr[y].p;
            if(z!=k)
                if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
                else    rotate(y);
            rotate(x);
        }
        if(!k)  root[b]=x;
    }

    void insert(int v,int id,int b)
    {
        int u=root[b],p=0;
        while(u) p=u,u=tr[u].s[v>tr[u].v];
        u=++idx;
        if(p)   tr[p].s[v>tr[p].v]=u;
        tr[u].init(v,id,p);
        splay(u,0,b);
    }

    int get_k(int k,int b)
    {
        int u=root[b];
        while(u)
        {
            if(tr[tr[u].s[0]].size>=k)  u=tr[u].s[0];
            else if(tr[tr[u].s[0]].size+1==k)
                return splay(u,0,b),tr[u].id;
            else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
        }
        return -1;
    }

    void merge(int u,int b)
    {
        if(tr[u].s[0])  merge(tr[u].s[0],b);
        if(tr[u].s[1])  merge(tr[u].s[1],b);
        insert(tr[u].v,tr[u].id,b);
    }
}Tr;

int f[N];
int find(int x)
{
    if(f[x]!=x) f[x]=find(f[x]);
    return f[x];
}

signed main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++)
    {
        f[i]=Tr.root[i]=i;
        int v;
        cin>>v;
        Tr.tr[i].init(v,i,0);
    }
    Tr.idx=n;
    while(m--)
    {
        int a,b;
        cin>>a>>b;
        a=find(a),b=find(b);
        if(a!=b)
        {
            if(Tr.tr[Tr.root[a]].size>Tr.tr[Tr.root[b]].size) swap(a,b);
            Tr.merge(Tr.root[a],b);
            f[a]=b;
        }
    }

    cin>>m;
    while(m--)
    {
        char op;
        int a,b;
        cin>>op>>a>>b;
        if(op=='B')
        {
            a=find(a),b=find(b);
            if(a!=b)
            {
                if(Tr.tr[Tr.root[a]].size>Tr.tr[Tr.root[b]].size)    swap(a,b);
                Tr.merge(Tr.root[a],b);
                f[a]=b;
            }
        }
        else
        {
            a=find(a);
            if(Tr.tr[Tr.root[a]].size<b)    cout<<-1<<endl;
            else    cout<<Tr.get_k(b,a)<<endl;
        }
    }

    return 0;
}

③:维护数列

link

内存回收:

把删掉的点再利用:建一个数组(回收站),把被删的点记录下来(其实是记录可用下标),要用的时候从数组里面找一个即可。提高空间利用率。

#include<bits/stdc++.h>
#define int long long

using namespace std;

const int N=5e5+100,INF=1e18;

int n,m;
int w[N];

struct Splay{
    struct NODE{
        int s[2],p,v;
        int rev,same;
        int size,sum,ms,ls,rs;

        void init(int _v,int _p)
        {
            s[0]=s[1]=0,p=_p,v=_v;
            rev=same=0;
            size=1,sum=ms=v;
            ls=rs=max(v,0ll);
        }
    }tr[N];

    int root;
    int nodes[N],tt;//回收站

    void init_nodes()
    {
        for(int i=1;i<N;i++)
            nodes[++tt]=i;
    }

    void pushup(int x)
    {
        auto &u=tr[x],&l=tr[u.s[0]],&r=tr[u.s[1]];
        u.size=l.size+r.size+1;
        u.sum=l.sum+r.sum+u.v;
        u.ls=max(l.ls,l.sum+u.v+r.ls);
        u.rs=max(r.rs,r.sum+u.v+l.rs);
        u.ms=max({l.ms,r.ms,l.rs+u.v+r.ls});
    }

    void pushdown(int x)
    {
        auto &u=tr[x],&l=tr[u.s[0]],&r=tr[u.s[1]];
        if(u.same)
        {
            u.same=u.rev=0;
            if(u.s[0])  l.same=1,l.v=u.v,l.sum=l.v*l.size;
            if(u.s[1])  r.same=1,r.v=u.v,r.sum=r.v*r.size;
            if(u.v>0)
            {
                if(u.s[0])  l.ms=l.ls=l.rs=l.sum;
                if(u.s[1])  r.ms=r.ls=r.rs=r.sum;
            }
            else
            {
                if(u.s[0])  l.ms=l.v,l.ls=l.rs=0;
                if(u.s[1])  r.ms=r.v,r.ls=r.rs=0;
            }
        }
        else if(u.rev)
        {
            u.rev=0,l.rev^=1,r.rev^=1;
            swap(l.ls,l.rs),swap(r.ls,r.rs);
            swap(l.s[0],l.s[1]),swap(r.s[0],r.s[1]);
        }
    }

    void rotate(int x)
    {
        int y=tr[x].p,z=tr[y].p;
        int k=tr[y].s[1]==x;
        tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
        tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
        tr[x].s[k^1]=y,tr[y].p=x;
        pushup(y),pushup(x);
    }

    void splay(int x,int k)
    {
        while(tr[x].p!=k)
        {
            int y=tr[x].p,z=tr[y].p;
            if(z!=k)
                if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
                else    rotate(y);
            rotate(x);
        }
        if(!k)  root=x;
    }

    int get_k(int k)
    {
        int u=root;
        while(u)
        {
            pushdown(u);
            if(tr[tr[u].s[0]].size>=k)  u=tr[u].s[0];
            else if(tr[tr[u].s[0]].size+1==k)   return u;
            else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
        }
        return -1;
    }

    int build(int p,int l,int r)
    {
        int mid=l+r>>1;
        int u=nodes[tt--];
        tr[u].init(w[mid],p);
        if(l<mid)   tr[u].s[0]=build(u,l,mid-1);
        if(mid<r)   tr[u].s[1]=build(u,mid+1,r);
        pushup(u);
        return u;
    }

    void dfs(int u)
    {
        if(tr[u].s[0])  dfs(tr[u].s[0]);
        if(tr[u].s[1])  dfs(tr[u].s[1]);
        nodes[++tt]=u;
    }
}Tr;

signed main()
{
    Tr.init_nodes();
    cin>>n>>m;
    Tr.tr[0].ms=w[0]=w[n+1]=-INF;
    for(int i=1;i<=n;i++)   cin>>w[i];
    Tr.root=Tr.build(0,0,n+1);

    char op[30];
    while(m--)
    {
        cin>>op;
        if(!strcmp(op,"INSERT"))
        {
            int posi,tot;
            cin>>posi>>tot;
            for(int i=0;i<tot;i++)  cin>>w[i];
            int l=Tr.get_k(posi+1),r=Tr.get_k(posi+2);
            Tr.splay(l,0),Tr.splay(r,l);
            int u=Tr.build(r,0,tot-1);
            Tr.tr[r].s[0]=u;
            Tr.pushup(r),Tr.pushup(l);
        }
        else if(!strcmp(op,"DELETE"))
        {
            int posi,tot;
            cin>>posi>>tot;
            int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
            Tr.splay(l,0),Tr.splay(r,l);
            Tr.dfs(Tr.tr[r].s[0]);
            Tr.tr[r].s[0]=0;
            Tr.pushup(r),Tr.pushup(l);
        }
        else if(!strcmp(op,"MAKE-SAME"))
        {
            int posi,tot,c;
            cin>>posi>>tot>>c;
            int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
            Tr.splay(l,0),Tr.splay(r,l);
            auto &son=Tr.tr[Tr.tr[r].s[0]];
            son.same=1,son.v=c,son.sum=c*son.size;
            if(c>0) son.ms=son.ls=son.rs=son.sum;
            else    son.ms=c,son.ls=son.rs=0;
            Tr.pushup(r),Tr.pushup(l);
        }
        else if(!strcmp(op,"REVERSE"))
        {
            int posi,tot;
            cin>>posi>>tot;
            int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
            Tr.splay(l,0),Tr.splay(r,l);
            auto &son=Tr.tr[Tr.tr[r].s[0]];
            son.rev^=1;
            swap(son.ls,son.rs);
            swap(son.s[0],son.s[1]);
            Tr.pushup(r),Tr.pushup(l);
        }
        else if(!strcmp(op,"GET-SUM"))
        {
            int posi,tot;
            cin>>posi>>tot;
            int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
            Tr.splay(l,0),Tr.splay(r,l);
            cout<<Tr.tr[Tr.tr[r].s[0]].sum<<endl;
        }
        else    cout<<Tr.tr[Tr.root].ms<<endl;
    }
    return 0;
}

posted @ 2025-05-23 21:16  袍蚤  阅读(19)  评论(0)    收藏  举报