Splay学习笔记 & P3369 【模板】普通平衡树

传送门


Splay

细节是真他妈多,写了一天,写吐了。

后悔没先学treap了。

需要实现以下函数:

  • void Init():重置整棵树(删除了整棵树的时候用)
  • int New(int val,int fa):新建一个节点,权值为val,父亲为fa,返回节点编号
  • void Delete(int x):清空节点x的信息
  • void update(int x):更新x节点的siz大小(类似线段树的push up)
  • void rotate(int x):把节点x旋转到x父亲的位置,左旋右旋可以写在一起
  • void splay(int x,int goal):把x通过不断旋转,直到父亲是goal
  • bool find(int val):找到值为val的节点,并旋转到根,返回是否成功找到
  • int pre():找到第一个小于根节点val的节点的编号
  • int nxt():找到第一个大于根节点val的节点的编号
  • void insert(int val):插入值为val的点
  • void del(int val):删除值为val的点,注意若整棵树就一个节点需要Init()一下
  • int getrk(int val):返回值为val的节点的排名
  • int getval(int x):返回排名为x的节点权值

Q:要注意啥?
A:呵呵,真的没啥要注意的,就注意别写挂了行了/kx/kx/kx/kx

AC代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
    x=0;register char c=getchar();register bool f=0;
    while(!isdigit(c))f^=c=='-',c=getchar();
    while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
    if(f)x=-x;
}
template<class T>inline void print(T x)
{
    if(x<0)putchar('-'),x=-x;
    if(x>9)print(x/10);
    putchar('0'+x%10);
}
const int maxn=1e5+5;
int n,rt,cnt;
struct node{
	int fa,son[2],siz,num,val;
}tr[maxn];
inline void Init(){
	tr[rt].fa=tr[rt].son[0]=tr[rt].son[1]=tr[rt].siz=tr[rt].num=0;
	rt=0;
}
inline int New(int val,int fa){
	cnt++;
	tr[cnt].fa=fa;
	tr[cnt].val=val;
	tr[cnt].num=tr[cnt].siz=1;
	return cnt;
}
inline void Delete(int x){
	tr[x].fa=tr[x].son[0]=tr[x].son[1]=tr[x].siz=tr[x].num=0;
}
inline void update(int x){
	if(!x) return;
	tr[x].siz=tr[x].num;
	if(tr[x].son[0]) tr[x].siz+=tr[tr[x].son[0]].siz;
	if(tr[x].son[1]) tr[x].siz+=tr[tr[x].son[1]].siz;
}
void rotate(int x){
	int y=tr[x].fa,z=tr[y].fa;
	int c=(tr[y].son[1]==x);
	tr[y].son[c]=tr[x].son[!c];
	tr[tr[x].son[!c]].fa=y;
	tr[x].son[!c]=y;
	if(z) tr[z].son[tr[z].son[1]==y]=x;
	tr[y].fa=x;
	tr[x].fa=z;
	update(y);
	update(x);
}
void splay(int x,int goal){
	while(tr[x].fa!=goal){
		int y=tr[x].fa,z=tr[y].fa;
		if(z!=goal) ((tr[y].son[1]==x)^(tr[z].son[1]==y))?rotate(x):rotate(y);
		rotate(x);
	}
	if(!goal) rt=x;
}
bool find(int val){
	int x=rt;
	while(1){
		if(tr[x].val==val) return splay(x,0),1;
		if(tr[x].son[tr[x].val<val]) x=tr[x].son[tr[x].val<val];
		else return 0;
	}
}
int pre(){
	int x=tr[rt].son[0];
	while(tr[x].son[1]) x=tr[x].son[1];
	return x;
}
int nxt(){
	int x=tr[rt].son[1];
	while(tr[x].son[0]) x=tr[x].son[0];
	return x;
}
void insert(int val){
	if(!rt){
		rt=New(val,0);
		return;
	}
	if(find(val)){
		tr[rt].num++;
		tr[rt].siz++;
		return;
	}
	int x=rt;
	while(1){
		if(tr[x].son[tr[x].val<val]) x=tr[x].son[tr[x].val<val];
		else{
			tr[x].son[tr[x].val<val]=New(val,x);
			update(x);
			splay(tr[x].son[tr[x].val<val],0);
			return;
		}
	}
}
void del(int val){
	find(val);
	int x=rt;
	tr[rt].num--;
	tr[rt].siz--;
	if(tr[rt].num) return;
	if(!tr[rt].son[0]&&!tr[rt].son[1]){
		Init();
		return;
	}
	if(!tr[rt].son[0]){
		rt=tr[rt].son[1];
		tr[rt].fa=0;
		Delete(x);
		return;
	}
	if(!tr[rt].son[1]){
		rt=tr[rt].son[0];
		tr[rt].fa=0;
		Delete(x);
		return;
	}
	find(tr[pre()].val);
	splay(x,rt);
	if(tr[x].son[1]){
		tr[rt].son[1]=tr[x].son[1];
		tr[tr[x].son[1]].fa=rt;	
	}
	Delete(x);
}
int getrk(int val){
	int x=rt,res=0;
	while(1){
		if(val==tr[x].val){
			res+=tr[x].son[0]?tr[tr[x].son[0]].siz+1:1;
			splay(x,0);
			return res;
		}
		if(val<tr[x].val){
			if(!tr[x].son[0]) return res;
			x=tr[x].son[0];
		}else{
			res+=tr[x].num;
			if(tr[x].son[0]) res+=tr[tr[x].son[0]].siz;
			if(!tr[x].son[1]) return res;
			x=tr[x].son[1];
		}
	}
}
int getval(int tot){
	int x=rt;
	while(1){
		if(tr[x].son[0]&&tot<=tr[tr[x].son[0]].siz) x=tr[x].son[0];
		else{
			tot-=tr[tr[x].son[0]].siz;
			if(tot<=tr[x].num) return tr[x].val;
			tot-=tr[x].num;
			x=tr[x].son[1];
		}
	}
}
int main(){
	read(n);
	for(int i=1;i<=n;i++){
		int op,x;
		read(op);read(x);
		if(op==1) insert(x);
		if(op==2) del(x);
		if(op==3) print(getrk(x)),puts("");
		if(op==4) print(getval(x)),puts("");
		if(op==5) insert(x),print(tr[pre()].val),puts(""),del(x);
		if(op==6) insert(x),print(tr[nxt()].val),puts(""),del(x);
	}
	return 0;
}
posted @ 2021-11-02 21:33  尹昱钦  阅读(40)  评论(0编辑  收藏  举报