Splay 学习笔记
最近准备学习 LCT,因此先学习了 Splay。
前置知识
核心操作
基础操作
#define fa(x) t[x].fa
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int k,rt;//节点数,根
struct tree
{
int ch[2],fa,val,sz;//左右儿子,父亲,值,子树大小
}t[N];
bool dir(int x)//判断x是它父亲的左儿子还是右儿子
{
return x==rs(fa(x));
}
int newnode(int v)//新建节点
{
t[++k].val=v;
t[k].sz=1;
return k;
}
void pushup(int x)//合并儿子信息
{
t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}
旋转操作
旋转操作的本质是把指定节点上移一个位置,并保证树的中序遍历(即二叉搜索树的性质)不变。
旋转分为右旋(Zig) 和左旋(Zag),分别用于处理指定节点是左儿子和右儿子的情况。如下图,由上到下为右旋,由下到上为左旋。


代码按照旋转的定义模拟即可,需要注意的是必须保证 \(0\) 号节点的所有属性都为 \(0\)。
void rotate(int x)
{
int y=fa(x),z=fa(y);
bool f=dir(x);
t[y].ch[f]=t[x].ch[!f];
t[x].ch[!f]=y;
if(z)//判断0号节点
t[z].ch[dir(y)]=x;
if(t[y].ch[f])
fa(t[y].ch[f])=y;
fa(y)=x;
fa(x)=z;
pushup(y);//先更新儿子再更新父亲
pushup(x);
}
Splay 操作
\(Splay(x)\) 的作用是把点 \(x\) 一路旋到根 \(rt\) 上,其由三种类型组成:
Zig / Zag
这种操作仅发生在 \(fa(x)=rt\) 时,将 \(x\) 旋转一次即可。
Zig-Zig / Zag-Zag
当 \(x\) 和 \(fa(x)\) 同为它们父亲的左儿子或右儿子时,先将 \(fa(x)\) 旋转一次,再将 \(x\) 旋转一次。下图为对 \(3\) 号节点进行的一次 Zig-Zig 操作。

Zig-Zag / Zag-Zig
当 \(x\) 和 \(fa(x)\) 相对于父亲是不同方向的儿子时,连续将 \(x\) 旋转两次。下图为对 \(3\) 号节点进行的一次 Zig-Zag 操作。

而 Splay 操作则就是这三种操作的组合。代码如下,为了便于理解(其实是我不会用三目运算符),这里使用较为复杂的 \(if/else\) 实现。
//在常规的平衡树操作中只需要旋转到树根,但是部分操作有旋转到其他祖先的要求,所以这里有一个z表示要旋转到的位置
void splay(int x,int &z=rt)
{
int w=fa(z);//x和z的父亲相等,则表示到位置了
while(fa(x)!=w && fa(fa(x))!=w)
{
if(dir(fa(x))==dir(x))
rotate(fa(x));//Zig-Zig / Zag-Zag
else
rotate(x);//Zig-Zag / Zag-Zig
rotate(x);
}
if(fa(x)!=w)
rotate(x);//最后可能有一次Zig / Zag
z=x;
}
时间复杂度
单次均摊复杂度是 \(O(\log n)\) 的,我不会证,想看证明可以去 oi-wiki。
维护集合操作
需要注意的是,所有操作结束后都应进行 Splay 操作以保证时间复杂度。
插入
从根一直找到应该插入的位置。
void insert(int v)
{
int x=rt,y=0;//y是x的父亲
while(x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
x=newnode(v);
fa(x)=y;
t[y].ch[t[y].val<v]=x;
splay(x);
}
删除
最复杂的操作,有不同的实现方法,这里采用的方法是找到要删除的节点后将其转到根上操作。
void erase(int v)
{
int x=rt,y=0;
while(t[x].val!=v && x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
if(!x)//找不到节点,直接退出
{
splay(y);
return;
}
splay(x);
if(!ls(x) || !rs(x))//如果要删除节点只有一个儿子,将儿子设为根即可
{
rt=ls(x)+rs(x);
fa(ls(x)+rs(x))=0;
return;
}
int p=rt=ls(x);
fa(p)=0;
while(rs(p))
p=rs(p);
rs(p)=rs(x);//将右儿子接在左子树中最大的节点下面
fa(rs(x))=p;
pushup(p);//改变了结构,要额外pushup一次
splay(p);
}
查询排名
在树上搜索的时候统计比 \(v\) 小的节点数量。
int getrnk(int v)
{
int x=rt,y=0,ans=1;
while(x)
{
y=x;
if(t[x].val<v)
{
ans+=t[ls(x)].sz+1;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
查询第 k 大值
类似线段树二分。
int getkth(int v)
{
int x=rt;
while(1)
{
int now=t[ls(x)].sz+1;
if(now==v)
break;
if(now<v)
{
v-=now;
x=rs(x);
}
else
x=ls(x);
}
splay(x);
return t[x].val;
}
查询前驱后继
查询前驱类似于查询排名,只是改为纪录比 \(v\) 小的节点数值;查询后继就是前驱的做法反过来。
int getpre(int v)//前驱
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val<v)
{
ans=t[x].val;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
int getnxt(int v)//后继
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val>v)
{
ans=t[x].val;
x=ls(x);
}
else
x=rs(x);
}
splay(y);
return ans;
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
#define fa(x) t[x].fa
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int k,rt,n;
struct tree
{
int ch[2],fa,val,sz;
}t[N];
bool dir(int x)
{
return x==rs(fa(x));
}
int newnode(int v)
{
t[++k].val=v;
t[k].sz=1;
return k;
}
void pushup(int x)
{
t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}
void rotate(int x)
{
int y=fa(x),z=fa(y);
bool f=dir(x);
t[y].ch[f]=t[x].ch[!f];
t[x].ch[!f]=y;
if(z)
t[z].ch[dir(y)]=x;
if(t[y].ch[f])
fa(t[y].ch[f])=y;
fa(y)=x;
fa(x)=z;
pushup(y);
pushup(x);
}
void splay(int x,int &z=rt)
{
int w=fa(z);
while(fa(x)!=w && fa(fa(x))!=w)
{
if(dir(fa(x))==dir(x))
rotate(fa(x));
else
rotate(x);
rotate(x);
}
if(fa(x)!=w)
rotate(x);
z=x;
}
void insert(int v)
{
int x=rt,y=0;
while(x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
x=newnode(v);
fa(x)=y;
t[y].ch[t[y].val<v]=x;
splay(x);
}
void erase(int v)
{
int x=rt,y=0;
while(t[x].val!=v && x)
{
y=x;
x=t[x].ch[t[x].val<v];
}
if(!x)
{
splay(y);
return;
}
splay(x);
if(!ls(x) || !rs(x))
{
rt=ls(x)+rs(x);
fa(ls(x)+rs(x))=0;
return;
}
int p=rt=ls(x);
fa(p)=0;
while(rs(p))
p=rs(p);
rs(p)=rs(x);
fa(rs(x))=p;
pushup(p);
splay(p);
}
int getrnk(int v)
{
int x=rt,y=0,ans=1;
while(x)
{
y=x;
if(t[x].val<v)
{
ans+=t[ls(x)].sz+1;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
int getkth(int v)
{
int x=rt;
while(1)
{
int now=t[ls(x)].sz+1;
if(now==v)
break;
if(now<v)
{
v-=now;
x=rs(x);
}
else
x=ls(x);
}
splay(x);
return t[x].val;
}
int getpre(int v)
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val<v)
{
ans=t[x].val;
x=rs(x);
}
else
x=ls(x);
}
splay(y);
return ans;
}
int getnxt(int v)
{
int x=rt,y=0,ans=0;
while(x)
{
y=x;
if(t[x].val>v)
{
ans=t[x].val;
x=ls(x);
}
else
x=rs(x);
}
splay(y);
return ans;
}
int main()
{
scanf("%d",&n);
while(n--)
{
int op,x;
scanf("%d%d",&op,&x);
if(op==1)
insert(x);
else if(op==2)
erase(x);
else if(op==3)
printf("%d\n",getrnk(x));
else if(op==4)
printf("%d\n",getkth(x));
else if(op==5)
printf("%d\n",getpre(x));
else
printf("%d\n",getnxt(x));
}
return 0;
}
维护序列操作
平衡树的另一个重要用途就是维护序列。
这里拿一道例题说明:P3391 【模板】文艺平衡树
题意简述:给一个序列,支持多次区间翻转,求最终的序列。
序列操作中平衡树不是依据大小关系,而是依据排列顺序维护元素的,即树的中序遍历就是当前序列,其他的基本操作和上述的没有区别。
而本题要求实现的区间翻转,容易想到可以先把询问的区间集中到一棵子树上,再交换该子树的所有节点的左右儿子来实现。
建树
也可以直接插入 \(n\) 个元素,这里给一种易懂的建树方法。
void build(int &x,int l,int r)
{
int mid=(l+r)>>1;
x=newnode(mid);
if(mid>l)
{
build(ls(x),l,mid-1);
fa(ls(x))=x;
}
if(mid<r)
{
build(rs(x),mid+1,r);
fa(rs(x))=x;
}
pushup(x);
}
如何把区间集中到一棵子树?
根据序列平衡树的性质可以得到两个简单结论:
- 每一棵子树都对应序列上一个区间。
- 设某子树对应的区间是 \([l,r]\),根对应的位置是 \(k\),那么如果根有左儿子,左子树对应的区间为 \([l,k-1]\);如果根有右儿子,右子树对应的区间为 \([k+1,r]\)。
以上都可以用中序遍历的性质简单证明。
对于要翻转的区间 \([l,r]\),把 \(l-1\) 对应的点 Splay 到根,右子树对应的区间就是 \([l,n]\) 了;此时对于右子树,再把 \(r+1\) 对应的点(肯定在子树中) Splay 到根,这棵右子树的左子树对应的区间就正好是 \([l,r]\),即我们要求的区间了。
可以结合下图帮助理解(图中均用序列中位置代指节点)。

需要注意的是,如果 \(l=1\) 或 \(r=n\),则无法找到 \(l-1\) 或 \(r+1\) 对应的点,一个简单的弥补方法是把 \(0\) 和 \(n+1\) 也作为节点加入树中。
如何交换某子树所有节点的左右儿子?
直接一个个交换的复杂度肯定是过高的,所以这里引入线段树中的懒标记思想,即对每个节点维护懒标记,在遍历儿子节点时下传。
想必各位都很熟悉线段树,这里就不仔细讲了,细节见代码。
struct tree
{
int ch[2],fa,sz,val,tag;
}t[N];
void change(int x)//修改一个区间
{
if(!x)
return;
t[x].tag^=1;
swap(ls(x),rs(x));
}
void pushdown(int x)//下传标记
{
if(!t[x].tag)
return;
t[x].tag=0;
change(ls(x));
change(rs(x));
}
查找
即找到要 Splay 的位置,和集合操作中的找第 k 大基本一致,只是要记得下传标记,以及不能 Splay(需要确保 \(r+1\) 在 \(l-1\) 的子树里)。
翻转
很简单,没什么好说的。
void reverse(int l,int r)
{
int x=getkth(l-1);
splay(x);
int y=getkth(r+1);
splay(y,rs(x));
change(ls(y));
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
#define fa(x) t[x].fa
int rt,k,n,q;
struct tree
{
int ch[2],fa,sz,val,tag;
}t[N];
int newnode(int v)
{
t[++k].val=v;
t[k].sz=1;
return k;
}
bool dir(int x)
{
return x==rs(fa(x));
}
void pushup(int x)
{
t[x].sz=t[ls(x)].sz+t[rs(x)].sz+1;
}
void change(int x)
{
if(!x)
return;
t[x].tag^=1;
swap(ls(x),rs(x));
}
void pushdown(int x)
{
if(!t[x].tag)
return;
t[x].tag=0;
change(ls(x));
change(rs(x));
}
void rotate(int x)
{
int y=fa(x),z=fa(y);
bool f=dir(x);
t[y].ch[f]=t[x].ch[!f];
t[x].ch[!f]=y;
if(z)
t[z].ch[dir(y)]=x;
if(t[y].ch[f])
fa(t[y].ch[f])=y;
fa(y)=x;
fa(x)=z;
pushup(y);
pushup(x);
}
void splay(int x,int &z=rt)
{
int w=fa(z);
while(fa(x)!=w && fa(fa(x))!=w)
{
if(dir(fa(x))==dir(x))
rotate(fa(x));
else
rotate(x);
rotate(x);
}
if(fa(x)!=w)
rotate(x);
z=x;
}
void build(int &x,int l,int r)
{
int mid=(l+r)>>1;
x=newnode(mid);
if(mid>l)
{
build(ls(x),l,mid-1);
fa(ls(x))=x;
}
if(mid<r)
{
build(rs(x),mid+1,r);
fa(rs(x))=x;
}
pushup(x);
}
int getkth(int v)
{
v++;
int x=rt;
while(1)
{
pushdown(x);//记得下传标记!
int now=t[ls(x)].sz+1;
if(now==v)
break;
if(now<v)
{
v-=now;
x=rs(x);
}
else
x=ls(x);
}//不能Splay
return x;
}
void reverse(int l,int r)
{
int x=getkth(l-1);
splay(x);
int y=getkth(r+1);
splay(y,rs(x));
change(ls(y));
}
void output(int x)//输出答案
{
if(!x)
return;
pushdown(x);
output(ls(x));
if(t[x].val>=1 && t[x].val<=n)
printf("%d ",t[x].val);
output(rs(x));
}
int main( void )
{
scanf("%d%d",&n,&q);
build(rt,0,n+1);
while(q--)
{
int x,y;
scanf("%d%d",&x,&y);
reverse(x,y);
}
output(rt);
return 0;
}

浙公网安备 33010602011771号