Splay算法

前言

看了 \(ying-xue\) 大佬的博客,讲的太好了,以至于我都没有动力写这篇博客了T_T,所以想要学习 Splay 的请移步ying-xue cat的博客

这里仅记载我的一些理解和想法

算法理解

本质上还是维护一棵二叉搜索树

代码

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
struct tree{
    int s[2],siz,fa,key;
    tree(){s[0]=s[1]=siz=fa=key=0;}
}tr[N];
#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
#define fa(x) tr[x].fa
int rt,idx,T;
int newnode(int key){
    tr[++idx].key=key;
    tr[idx].siz=1;
    return idx;
}
void update(int x){
    tr[x].siz=tr[ls(x)].siz+tr[rs(x)].siz+1;
}
void clear(int x){
    ls(x)=rs(x)=fa(x)=tr[x].siz=tr[x].key=0;
}
bool get(int x){
    return x==rs(fa(x));
}
void rotate(int x){
	int y=fa(x),z=fa(y),xy=get(x),yz=get(y);//xy是x的方向,yz是y的方向 
	if(tr[x].s[xy^1])  fa(tr[x].s[xy^1])=y;tr[y].s[xy]=tr[x].s[xy^1];//将x的反方向子树移到y上 
	fa(y)=x;tr[x].s[xy^1]=y;//将y移到x上
	if(z)  tr[z].s[yz]=x;fa(x)=z;//将x移到z上
	update(y);//先更新y因为它是x的子树 
	update(x);
}
void splay(int x){
	for(int f=fa(x);f=fa(x);){//一直旋到x没有根为止 
		if(fa(f)){
			if(get(x)==get(f))  rotate(f);//同向先转父亲 
			else  rotate(x); //不同向x转两次 
		}
		rotate(x);
	}
	rt=x;//把根设为x 
}
void ins(int key){
	int now=rt,f=0;
	while(now){//得走到一个空节点加入 
		f=now;
		now=tr[now].s[key>tr[f].key]; //往左/右走 
	} 
	now=newnode(key);
	fa(now)=f;tr[f].s[key>tr[f].key]=now;
	splay(now);
}
void del(int key){
	int now=rt,p=0;
	while(tr[now].key!=key&&now){//找到now的位置 
		p=now;
		now=tr[now].s[key>tr[now].key];
	} 
	if(!now){//如果没找到 
		splay(p);
		return;
	}
	splay(now);
	int cur=ls(now);
	if(!cur){//如果没有左子树 
		rt=rs(now);fa(rs(now))=0;clear(now);//把右子树作为根 
		return;
	}
	while(rs(cur))  cur=rs(cur);//找到左子树的最右边 
	rs(cur)=rs(now);fa(rs(now))=cur;//将右子树接到左子树下边 
	fa(ls(now))=0;clear(now);
	update(cur);splay(cur);
} 
int pre(int key){
	int now=rt,ans=0,p=rt;//p一定要初始化 
	while(now){
		p=now; 
		if(tr[now].key>=key)  now=ls(now);
		else  ans=tr[now].key,now=rs(now);
	}
	splay(p);
	return ans;
} 
int nxt(int key){
	int now=rt,ans=0,p=rt;
	while(now){
		p=now;
		if(tr[now].key<=key)  now=rs(now);
		else  ans=tr[now].key,now=ls(now);
	}
	splay(p);
	return ans;
}
int kth(int rk){
	int now=rt;
	while(now){
		int sz=tr[ls(now)].siz+1;
		if(sz>rk)  now=ls(now);
		else if(sz==rk)  break;
		else  rk-=sz,now=rs(now); 
	}
	splay(now);
	return tr[now].key;
}
int rnk(int key){
	int res=1,now=rt,p=rt;
	while(now){
		p=now;
		if(tr[now].key<key)  res+=tr[ls(now)].siz+1,now=rs(now);
		else  now=ls(now);
	} 
	splay(p);
	return res;
}
int op,x;
int main()
{
    scanf("%d",&T);
    while(T--)
    {
    	scanf("%d%d",&op,&x);
        if(op==1) ins(x);
        if(op==2) del(x);
        if(op==3) printf("%d\n",rnk(x));
        if(op==4) printf("%d\n",kth(x));
        if(op==5) printf("%d\n",pre(x));
        if(op==6) printf("%d\n",nxt(x));
    }
    return 0;
}
posted @ 2025-05-09 21:33  daydreamer_zcxnb  阅读(23)  评论(0)    收藏  举报