看八股408数据结构中平衡树有感而发,直接手撸了10h Splay终于撸出来了

首先八股中的平衡树是AVL,他在手撸的时候实在是太难写了,因为要找到最小不平衡子树,然后判断四种情况也就是LL,LR,RR,RL,旋转后还得维护各种信息。

但是为了能够现在手撸平衡树,我直接找到了Splay,这个看起来好撸一点,因为旋转无非就两个了,也就是左旋或右旋了,顶多就是Splay的时候再判断是否爷孙三个一条线罢了。

但是毕竟第一次用指针写这样比较复杂的数据结构,还是撸了好久好久,得两天花了10h吧,最终洛谷T了最后一个点,Acwing过了,500多ms,用洛谷的treap似乎是400多ms,差不多,毕竟这个指针他就是慢啊,但是除了判断NULL之外其他的都比数组好写啊。

对于Splay实际上真的是很容易的平衡树了,首先要弄明白的就是两种旋转右旋左旋,这个跟AVL一模一样的,然后就是独特的Splay操作,也就是把一个节点旋转到根节点,这个复杂度大约是\(logN\)的因为保持平衡树状态旋转的嘛,然后根据一些别人总结出来的什么规律,这样也不慢,也很快。

对于平衡树来说,无非就是插入删除了。

Splay操作

(1)旋转: 旋转分为左旋或者右旋,跟AVL的是一样的。不过只需要知道他是左孩子还是右孩子直接转就好了,这可比四种AVL旋转算什么平衡因子要好写多了。
(2)Splay: 实际上就是把一个节点通过旋转给他旋转到根上,这里需要注意一下就是要判断是否是爷孙在一条线上,也就是是否三点一线,如果是就旋转爹,否则就旋转自己。

插入

Splay插入跟AVL一模一样,但是不要忘记最后把插入的节点Splay到根节点上去

删除

Splay删除就比较麻烦了,首先是找到要删除的节点,把他Splay到根上去,然后进行判断。
(1) 如果当前节点还有剩余的个数,也就是有重复值,重复值--,不变。
(2) 如果说这个节点的左孩子或者右孩子是NULL的话,那么直接删除即可,唯一的孩子继承成为根。
(3) 如果左右孩子都存在的话,那么先把根删除,现在就是两颗Splay树了,接着再把左子树的最右节点Splay到左子树根上,这样这个根节点的右孩子一定是NULL然后用它去连接右子树。

查询一个数的排名

BST操作,往右就加上左子树的size,最后+1

查询rank为k的数

BST操作,拿着去跟左子树的size以及节点的cnt比较,然后走着判断就行

查询前驱

因为这个前驱可能不在树里面,先把这个数插入到Splay中,然后把这个节点Splay到根节点,接着去找左子树最右节点最后对这个根节点进行一次del操作

查询后继

因为这个前驱可能不在树里面,先把这个数插入到Splay中,然后把这个节点Splay到根节点,接着去找右子树最左节点最后对这个根节点进行一次del操作

最终代码(手撸的第一个版本,比较粗糙,之后会重新整合一下)

#include <bits/stdc++.h>
using namespace std;
struct Splay
{
	Splay *l=NULL,*r=NULL,*parent=NULL;
	int sz=0,cnt=0,num=0;
	Splay(int n)
	{
		num=n;
		cnt=1;
	}
}*root=NULL;
void up(Splay *p)
{
	p->sz=p->cnt;
	if(p->l!=NULL) p->sz+=p->l->sz;
	if(p->r!=NULL) p->sz+=p->r->sz;
}
int get(Splay *p)
{
	return p->parent==NULL?-1:p->parent->r==p;
}
void zig(Splay *p)
{
	int now=get(p->parent);
	p->parent->l=p->r;
	if(p->r!=NULL) p->r->parent=p->parent;
	p->r=p->parent;
	if(p->parent->parent!=NULL) 
	{
		p->parent=p->parent->parent;
		if(!now) p->parent->l=p;
		else p->parent->r=p;
	}
	else p->parent=NULL;
	p->r->parent=p;
	up(p->r);up(p);
}
void zag(Splay *p)
{
	int now=get(p->parent);
	p->parent->r=p->l;
	if(p->l!=NULL) p->l->parent=p->parent;
	p->l=p->parent;
	if(p->parent->parent!=NULL) 
	{
		p->parent=p->parent->parent;
		if(!now) p->parent->l=p;
		else p->parent->r=p;
	}
	else p->parent=NULL;
	p->l->parent=p;
	up(p->l);up(p); 
}
void rotate(Splay *p)
{
	if(!get(p)) zig(p);
	else zag(p);
}
Splay *splay(Splay* p)
{
	while(p->parent!=NULL) rotate(get(p)==get(p->parent)?p->parent:p);
	return p;
}
Splay *ins(int num)
{
	Splay *parent=NULL;
	Splay *p=root;
	while(p!=NULL)
	{
		parent=p;
		if(p->num==num) 
		{
			p->cnt++;
			up(p);
			if(p->parent!=NULL) up(p->parent);
			return splay(p);
		}
		else if(p->num>num) p=p->l;
		else p=p->r;	
	}
	p=new Splay(num);
	p->parent=parent;
	if(p->parent!=NULL)
	{
		if(p->num>p->parent->num) p->parent->r=p;
		else p->parent->l=p;
	}
	up(p);
	if(p->parent!=NULL) up(p->parent);
	return splay(p);	
}
int find(int x)
{
	Splay *p=root;
	int ans=0;
	while(p!=NULL)
	{
		if(p->num==x) 
		{
			if(p->l!=NULL)ans+=p->l->sz;
			break;
		}
		else if(p->num>x) p=p->l;
		else ans+=(p->l==NULL?0:p->l->sz)+p->cnt,p=p->r;
	}
	return ans+1;
}
Splay *del(int x)
{
	Splay *p=root;
	int f=0;
	while(1)
	{
		if(p->num==x) p=splay(p),f=1;
		else if(p->num>x) p=p->l;
		else p=p->r;
		if(f) break;
	}
	if(p->cnt>1) 
	{
		p->cnt--;up(p);return p;
	}
	Splay *now=p;
	if(p->l==NULL) 
	{
		if(p->r==NULL) return NULL;
		else p=p->r,p->parent=NULL,delete(now);return p;	
	} 
	if(p->r==NULL)
	{
		if(p->l==NULL) return NULL;
		else p=p->l,p->parent=NULL,delete(now);return p;	
	}
	Splay *r=p->r;
	p=p->l;delete(now);
	while(p->r!=NULL) p=p->r;
	p=splay(p);
	p->r=r;
	p->r->parent=p;
	up(p);
	return p;
}
int srank(int x)
{
	Splay *p=root;
	while(1)
	{
		int ls=0; 
		if(p->l!=NULL) ls+=p->l->sz;
		if(x>ls+p->cnt) x-=ls+p->cnt,p=p->r;
		else if(x>ls&&x<=p->cnt+ls) return p->num;
		else p=p->l;
	} 
}
int pre(int x)
{
	root=ins(x);
	Splay *p=root->l;
	while(p->r!=NULL) p=p->r;
	root=del(x);
	return p->num;
}
int after(int x)
{
	root=ins(x);
	Splay *p=root->r;
	while(p->l!=NULL) p=p->l;
	root=del(x);
	return p->num;
}
void solve()
{
	int op,x;
	scanf("%d%d",&op,&x);
	if(op==1) root=ins(x);
	if(op==2) root=del(x);
	if(op==3) printf("%d\n",find(x));
	if(op==4) printf("%d\n",srank(x));
	if(op==5) printf("%d\n",pre(x));
	if(op==6) printf("%d\n",after(x));
}
int main()
{
	int T;
	scanf("%d",&T);
	while(T--) solve();
}
posted @ 2021-07-15 13:35  baccano!  阅读(183)  评论(0编辑  收藏  举报