Splay 学习笔记

最近准备学习 LCT,因此先学习了 Splay。

前置知识

二叉搜索树

核心操作

基础操作

#define fa(x) t[x].fa
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int k,rt;//节点数,根
struct tree
{
    int ch[2],fa,val,sz;//左右儿子,父亲,值,子树大小
}t[N];
bool dir(int x)//判断x是它父亲的左儿子还是右儿子
{
    return x==rs(fa(x));
}
int newnode(int v)//新建节点
{
    t[++k].val=v;
    t[k].sz=1;
    return k;
}
void pushup(int x)//合并儿子信息
{
    t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}

旋转操作

旋转操作的本质是把指定节点上移一个位置,并保证树的中序遍历(即二叉搜索树的性质)不变。
旋转分为右旋(Zig) 和左旋(Zag),分别用于处理指定节点是左儿子和右儿子的情况。如下图,由上到下为右旋,由下到上为左旋。


代码按照旋转的定义模拟即可,需要注意的是必须保证 \(0\) 号节点的所有属性都为 \(0\)

void rotate(int x)
{
    int y=fa(x),z=fa(y);
    bool f=dir(x);
    t[y].ch[f]=t[x].ch[!f];
    t[x].ch[!f]=y;
    if(z)//判断0号节点
        t[z].ch[dir(y)]=x;
    if(t[y].ch[f])
        fa(t[y].ch[f])=y;
    fa(y)=x;
    fa(x)=z;
    pushup(y);//先更新儿子再更新父亲
    pushup(x);
}

Splay 操作

\(Splay(x)\) 的作用是把点 \(x\) 一路旋到根 \(rt\) 上,其由三种类型组成:

Zig / Zag

这种操作仅发生在 \(fa(x)=rt\) 时,将 \(x\) 旋转一次即可。

Zig-Zig / Zag-Zag

\(x\)\(fa(x)\) 同为它们父亲的左儿子或右儿子时,先将 \(fa(x)\) 旋转一次,再将 \(x\) 旋转一次。下图为对 \(3\) 号节点进行的一次 Zig-Zig 操作。

Zig-Zag / Zag-Zig

\(x\)\(fa(x)\) 相对于父亲是不同方向的儿子时,连续将 \(x\) 旋转两次。下图为对 \(3\) 号节点进行的一次 Zig-Zag 操作。

而 Splay 操作则就是这三种操作的组合。代码如下,为了便于理解(其实是我不会用三目运算符),这里使用较为复杂的 \(if/else\) 实现。

//在常规的平衡树操作中只需要旋转到树根,但是部分操作有旋转到其他祖先的要求,所以这里有一个z表示要旋转到的位置
void splay(int x,int &z=rt)
{
    int w=fa(z);//x和z的父亲相等,则表示到位置了
    while(fa(x)!=w && fa(fa(x))!=w)
    {
        if(dir(fa(x))==dir(x))
            rotate(fa(x));//Zig-Zig / Zag-Zag
        else
            rotate(x);//Zig-Zag / Zag-Zig
        rotate(x);
    }
    if(fa(x)!=w)
        rotate(x);//最后可能有一次Zig / Zag
    z=x;
}

时间复杂度

单次均摊复杂度是 \(O(\log n)\) 的,我不会证,想看证明可以去 oi-wiki

维护集合操作

需要注意的是,所有操作结束后都应进行 Splay 操作以保证时间复杂度。

插入

从根一直找到应该插入的位置。

void insert(int v)
{
    int x=rt,y=0;//y是x的父亲
    while(x)
    {
        y=x;
        x=t[x].ch[t[x].val<v];
    }
    x=newnode(v);
    fa(x)=y;
    t[y].ch[t[y].val<v]=x;
    splay(x);
}

删除

最复杂的操作,有不同的实现方法,这里采用的方法是找到要删除的节点后将其转到根上操作。

void erase(int v)
{
    int x=rt,y=0;
    while(t[x].val!=v && x)
    {
        y=x;
        x=t[x].ch[t[x].val<v];
    }
    if(!x)//找不到节点,直接退出
    {
        splay(y);
        return;
    }
    splay(x);
    if(!ls(x) || !rs(x))//如果要删除节点只有一个儿子,将儿子设为根即可
    {
        rt=ls(x)+rs(x);
        fa(ls(x)+rs(x))=0;
        return;
    }
    int p=rt=ls(x);
    fa(p)=0;
    while(rs(p))
        p=rs(p);
    rs(p)=rs(x);//将右儿子接在左子树中最大的节点下面
    fa(rs(x))=p;
    pushup(p);//改变了结构,要额外pushup一次
    splay(p);
}

查询排名

在树上搜索的时候统计比 \(v\) 小的节点数量。

int getrnk(int v)
{
    int x=rt,y=0,ans=1;
    while(x)
    {
        y=x;
        if(t[x].val<v)
        {
            ans+=t[ls(x)].sz+1;
            x=rs(x);
        }
        else
            x=ls(x);
    }
    splay(y);
    return ans;
}

查询第 k 大值

类似线段树二分。

int getkth(int v)
{
    int x=rt;
    while(1)
    {
        int now=t[ls(x)].sz+1;
        if(now==v)
            break;
        if(now<v)
        {
            v-=now;
            x=rs(x);
        }
        else
            x=ls(x);
    }
    splay(x);
    return t[x].val;
}

查询前驱后继

查询前驱类似于查询排名,只是改为纪录比 \(v\) 小的节点数值;查询后继就是前驱的做法反过来。

int getpre(int v)//前驱
{
    int x=rt,y=0,ans=0;
    while(x)
    {
        y=x;
        if(t[x].val<v)
        {
            ans=t[x].val;
            x=rs(x);
        }
        else
            x=ls(x);
    }
    splay(y);
    return ans;
}
int getnxt(int v)//后继
{
    int x=rt,y=0,ans=0;
    while(x)
    {
        y=x;
        if(t[x].val>v)
        {
            ans=t[x].val;
            x=ls(x);
        }
        else
            x=rs(x);
    }
    splay(y);
    return ans;
}

完整代码

例题:P3369 【模板】普通平衡树

#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
#define fa(x) t[x].fa
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]

int k,rt,n;
struct tree
{
    int ch[2],fa,val,sz;
}t[N];

bool dir(int x)
{
    return x==rs(fa(x));
}

int newnode(int v)
{
    t[++k].val=v;
    t[k].sz=1;
    return k;
}

void pushup(int x)
{
    t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}

void rotate(int x)
{
    int y=fa(x),z=fa(y);
    bool f=dir(x);
    t[y].ch[f]=t[x].ch[!f];
    t[x].ch[!f]=y;
    if(z)
        t[z].ch[dir(y)]=x;
    if(t[y].ch[f])
        fa(t[y].ch[f])=y;
    fa(y)=x;
    fa(x)=z;
    pushup(y);
    pushup(x);
}

void splay(int x,int &z=rt)
{
    int w=fa(z);
    while(fa(x)!=w && fa(fa(x))!=w)
    {
        if(dir(fa(x))==dir(x))
            rotate(fa(x));
        else
            rotate(x);
        rotate(x);
    }
    if(fa(x)!=w)
        rotate(x);
    z=x;
}

void insert(int v)
{
    int x=rt,y=0;
    while(x)
    {
        y=x;
        x=t[x].ch[t[x].val<v];
    }
    x=newnode(v);
    fa(x)=y;
    t[y].ch[t[y].val<v]=x;
    splay(x);
}

void erase(int v)
{
    int x=rt,y=0;
    while(t[x].val!=v && x)
    {
        y=x;
        x=t[x].ch[t[x].val<v];
    }
    if(!x)
    {
        splay(y);
        return;
    }
    splay(x);
    if(!ls(x) || !rs(x))
    {
        rt=ls(x)+rs(x);
        fa(ls(x)+rs(x))=0;
        return;
    }
    int p=rt=ls(x);
    fa(p)=0;
    while(rs(p))
        p=rs(p);
    rs(p)=rs(x);
    fa(rs(x))=p;
    pushup(p);
    splay(p);
}

int getrnk(int v)
{
    int x=rt,y=0,ans=1;
    while(x)
    {
        y=x;
        if(t[x].val<v)
        {
            ans+=t[ls(x)].sz+1;
            x=rs(x);
        }
        else
            x=ls(x);
    }
    splay(y);
    return ans;
}

int getkth(int v)
{
    int x=rt;
    while(1)
    {
        int now=t[ls(x)].sz+1;
        if(now==v)
            break;
        if(now<v)
        {
            v-=now;
            x=rs(x);
        }
        else
            x=ls(x);
    }
    splay(x);
    return t[x].val;
}

int getpre(int v)
{
    int x=rt,y=0,ans=0;
    while(x)
    {
        y=x;
        if(t[x].val<v)
        {
            ans=t[x].val;
            x=rs(x);
        }
        else
            x=ls(x);
    }
    splay(y);
    return ans;
}

int getnxt(int v)
{
    int x=rt,y=0,ans=0;
    while(x)
    {
        y=x;
        if(t[x].val>v)
        {
            ans=t[x].val;
            x=ls(x);
        }
        else
            x=rs(x);
    }
    splay(y);
    return ans;
}

int main()
{
    scanf("%d",&n);
    while(n--)
    {
        int op,x;
        scanf("%d%d",&op,&x);
        if(op==1)
            insert(x);
        else if(op==2)
            erase(x);
        else if(op==3)
            printf("%d\n",getrnk(x));
        else if(op==4)
            printf("%d\n",getkth(x));
        else if(op==5)
            printf("%d\n",getpre(x));
        else
            printf("%d\n",getnxt(x));
    }

    return 0;
}

维护序列操作

平衡树的另一个重要用途就是维护序列。
这里拿一道例题说明:P3391 【模板】文艺平衡树
题意简述:给一个序列,支持多次区间翻转,求最终的序列。

序列操作中平衡树不是依据大小关系,而是依据排列顺序维护元素的,即树的中序遍历就是当前序列,其他的基本操作和上述的没有区别。

而本题要求实现的区间翻转,容易想到可以先把询问的区间集中到一棵子树上,再交换该子树的所有节点的左右儿子来实现。

建树

也可以直接插入 \(n\) 个元素,这里给一种易懂的建树方法。

void build(int &x,int l,int r)
{
    int mid=(l+r)>>1;
    x=newnode(mid);
    if(mid>l)
    {
        build(ls(x),l,mid-1);
        fa(ls(x))=x;
    }
    if(mid<r)
    {
        build(rs(x),mid+1,r);
        fa(rs(x))=x;
    }
    pushup(x);
}

如何把区间集中到一棵子树?

根据序列平衡树的性质可以得到两个简单结论:

  1. 每一棵子树都对应序列上一个区间。
  2. 设某子树对应的区间是 \([l,r]\),根对应的位置是 \(k\),那么如果根有左儿子,左子树对应的区间为 \([l,k-1]\);如果根有右儿子,右子树对应的区间为 \([k+1,r]\)

以上都可以用中序遍历的性质简单证明。
对于要翻转的区间 \([l,r]\),把 \(l-1\) 对应的点 Splay 到根,右子树对应的区间就是 \([l,n]\) 了;此时对于右子树,再把 \(r+1\) 对应的点(肯定在子树中) Splay 到根,这棵右子树的左子树对应的区间就正好是 \([l,r]\),即我们要求的区间了。
可以结合下图帮助理解(图中均用序列中位置代指节点)。

需要注意的是,如果 \(l=1\)\(r=n\),则无法找到 \(l-1\)\(r+1\) 对应的点,一个简单的弥补方法是把 \(0\)\(n+1\) 也作为节点加入树中。

如何交换某子树所有节点的左右儿子?

直接一个个交换的复杂度肯定是过高的,所以这里引入线段树中的懒标记思想,即对每个节点维护懒标记,在遍历儿子节点时下传。
想必各位都很熟悉线段树,这里就不仔细讲了,细节见代码。

struct tree
{
    int ch[2],fa,sz,val,tag;
}t[N];
void change(int x)//修改一个区间
{
    if(!x)
        return;
    t[x].tag^=1;
    swap(ls(x),rs(x));
}
void pushdown(int x)//下传标记
{
    if(!t[x].tag)
        return;
    t[x].tag=0;
    change(ls(x));
    change(rs(x));
}

查找

即找到要 Splay 的位置,和集合操作中的找第 k 大基本一致,只是要记得下传标记,以及不能 Splay(需要确保 \(r+1\)\(l-1\) 的子树里)。

翻转

很简单,没什么好说的。

void reverse(int l,int r)
{
    int x=getkth(l-1);
    splay(x);
    int y=getkth(r+1);
    splay(y,rs(x));
    change(ls(y));
}

完整代码

#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
#define fa(x) t[x].fa

int rt,k,n,q;
struct tree
{
    int ch[2],fa,sz,val,tag;
}t[N];

int newnode(int v)
{
    t[++k].val=v;
    t[k].sz=1;
    return k;
}

bool dir(int x)
{
    return x==rs(fa(x));
}

void pushup(int x)
{
    t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}

void change(int x)
{
    if(!x)
        return;
    t[x].tag^=1;
    swap(ls(x),rs(x));
}

void pushdown(int x)
{
    if(!t[x].tag)
        return;
    t[x].tag=0;
    change(ls(x));
    change(rs(x));
}

void rotate(int x)
{
    int y=fa(x),z=fa(y);
    bool f=dir(x);
    t[y].ch[f]=t[x].ch[!f];
    t[x].ch[!f]=y;
    if(z)
        t[z].ch[dir(y)]=x;
    if(t[y].ch[f])
        fa(t[y].ch[f])=y;
    fa(y)=x;
    fa(x)=z;
    pushup(y);
    pushup(x);
}

void splay(int x,int &z=rt)
{
    int w=fa(z);
    while(fa(x)!=w && fa(fa(x))!=w)
    {
        if(dir(fa(x))==dir(x))
            rotate(fa(x));
        else
            rotate(x);
        rotate(x);
    }
    if(fa(x)!=w)
        rotate(x);
    z=x;
}

void build(int &x,int l,int r)
{
    int mid=(l+r)>>1;
    x=newnode(mid);
    if(mid>l)
    {
        build(ls(x),l,mid-1);
        fa(ls(x))=x;
    }
    if(mid<r)
    {
        build(rs(x),mid+1,r);
        fa(rs(x))=x;
    }
    pushup(x);
}

int getkth(int v)
{
    v++;
    int x=rt;
    while(1)
    {
        pushdown(x);//记得下传标记!
        int now=t[ls(x)].sz+1;
        if(now==v)
            break;
        if(now<v)
        {
            v-=now;
            x=rs(x);
        }
        else
            x=ls(x);
    }//不能Splay
    return x;
}

void reverse(int l,int r)
{
    int x=getkth(l-1);
    splay(x);
    int y=getkth(r+1);
    splay(y,rs(x));
    change(ls(y));
}

void output(int x)//输出答案
{
    if(!x)
        return;
    pushdown(x);
    output(ls(x));
    if(t[x].val>=1 && t[x].val<=n)
        printf("%d ",t[x].val);
    output(rs(x));
}

int main( void )
{
    scanf("%d%d",&n,&q);
    build(rt,0,n+1);
    while(q--)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        reverse(x,y);
    }
    output(rt);
    return 0;
}
posted @ 2026-05-21 17:01  ShanLing3  阅读(18)  评论(0)    收藏  举报
//雪花飘落效果