芝士:splay

背景

因为BST本身存在一定的缺陷,

还有毒瘤出题人故意卡

导致BST极容易退化成一条单链,时间复杂度从优秀的\(O(log_n)\)到了\(O(n)\)

各种各样的巨佬就开始优化BST

所以才产生了我们所熟悉(心态爆炸)的各式各样的数据结构

巨佬Sleator和Tarjan就是其中的两位人,用旋转的方案解决了这个缺陷

主要思想

对于一个数列,我们由此生成的BST不止一种,

所以每一种的BST的高度都不一样,导致效率不一样

但是我们可以通过旋转操作使得整个树自平衡,

但是又不改变它本身BST的性质

操作

旋转

思路

我们要做的就是将当前节点的深度-1,只改变当前节点和当前节点的父亲的位置关系,但是又不改变整棵树的BST的性质,

也就是说,我们现在有一棵树,当前节点为x,现在我们想让x的深度-1

成为这样一颗树

一共有四种情况,笔者将其中的一种拿出来,另外3中也都是大同小异,读者可以自己手玩

代码

void rotate(int x)
{
	int y=tre[x].fa;
	int z=tre[y].fa;
	int k=tre[y].ch[1]==x;
	tre[z].ch[tre[z].ch[1]==y]=x;
	tre[x].fa=z;
	tre[y].ch[k]=tre[x].ch[k^1];
	tre[tre[x].ch[k^1]].fa=y;
	tre[x].ch[k^1]=y;
	tre[y].fa=x;
	push_up(y);
	push_up(x);
}

splay

思路

我们发现旋转操作实际上并不能使整棵树自平衡,

而实际上splay的自平衡不是来源于单纯的旋转

而是splay

我们如果我们将当前节点的祖父(父亲的父亲)节点考虑进来,

我们就会发现只有两种情况,

第一种是当前节点两次都旋转一样的方向

第二种是当前节点两次旋转的方向不一样

这里的方向定义为他是父亲节点的哪一个儿子

巨佬Sleator和Tarjan告诉我们要想自平衡,必须要遵守一个策略

如果是情况1,两次旋转都转他自己

如果是情况2,就先转他的父亲,再转他自己

通过势函数中所说的@#$!@#

可以明白splay的均摊时间复杂度为\(O(log_n)\)

在每次操作后都要用splay,因为splay有着上传的作用,并且势函数的时间复杂度的分析也是基于此的

代码

void splay(int x,int goal)
{
	while(tre[x].fa!=goal)
	{
		int y=tre[x].fa;
		int z=tre[y].fa;
		if(z!=goal)
		{
			if((tre[z].ch[0]==y)^(tre[y].ch[0]==x))
				rotate(x);
			else
				rotate(y);
		}
		rotate(x);
	}
	if(goal==0)
		rt=x;
	push_up(x);
}

将某个节点提到根节点

思路

直接在BST中找到这个点

再用splay转上去就好了

代码

void prepare(int x)
{
	int u=rt;
	if(!u)
		return;
	while(x!=tre[u].val&&tre[u].ch[x>tre[u].val])
	{
		if(x>tre[u].val)
		{
			u=tre[u].ch[1];
		}
		else
		{
			u=tre[u].ch[0];
		}
	}
	splay(u,0);
}

初始化

思路

因为蒟蒻笔者写的操作很多都是基于前驱与后继写的

所以最开始先插入极小值和极大值会好写很多

代码

void init()
	{
		rt=1;
		cnt=2;
		tre[0].fa=0;
		tre[0].ch[0]=tre[0].ch[1]=0;
		tre[0].tot=0;
		tre[0].val=0;
		tre[0].siz=0;

		tre[1].fa=0;
		tre[1].ch[0]=0;
		tre[1].ch[1]=2;
		tre[1].tot=1;
		tre[1].val=INT_MIN;
		tre[1].siz=2;

		tre[2].fa=1;
		tre[2].ch[0]=0;
		tre[2].ch[1]=0;
		tre[2].tot=1;
		tre[2].val=INT_MAX;
		tre[2].siz=1;
	}

前驱&后继

思路

暴力,反正时间复杂度也是\(O(log_n)\)

代码

int next(int x,int f)
{
	prepare(x);
	int u=rt;
	if(tre[u].val>x&&f)
		return u;
	if(tre[u].val<x&&!f)
		return u;
	u=tre[u].ch[f];
	while(tre[u].ch[!f])
		u=tre[u].ch[!f];
	return u;
}

插入

思路

通过旋转插入这个点的前驱和后继,

将插入的点的位置确定下来

代码

void insert(int x)
{
	int p=next(x,0);
	int s=next(x,1);
	splay(p,0);
	splay(s,p);
	if(tre[s].ch[0]==0)
	{
		tre[s].ch[0]=newnode(x,s);
		splay(cnt,0);
	}
	else
	{
		tre[tre[s].ch[0]].tot++;
		tre[tre[s].ch[0]].siz++;
		splay(tre[s].ch[0],0);	
	}
}

删除

思路

同插入一样的思路

代码

void delet(int x)
{
	int p=next(x,0);
	int s=next(x,1);
	splay(s,0);
	splay(p,s);
	if(tre[tre[p].ch[1]].tot>1)
	{
		tre[tre[p].ch[1]].tot--;
		tre[tre[p].ch[1]].siz--;
		splay(tre[p].ch[1],0);
	}
	else
	{
		tre[p].ch[1]=0;
		splay(p,0);
	}
}

排名

思路

将要询问的点转到根节点

之后查询儿子的大小就行了

代码

int rank(int x)
{
	prepare(x);
	return tre[tre[rt].ch[0]].siz;
}

第k大

思路

暴力

代码

int kth(int x)
{
	x++;
	int u=rt;
	while(1)
	{
		if(tre[tre[u].ch[0]].siz<x&&x<=tre[tre[u].ch[0]].siz+tre[u].tot)
		{
			return tre[u].val;
		}
		if(tre[tre[u].ch[0]].siz+tre[u].tot<x)
		{
			x=x-tre[tre[u].ch[0]].siz-tre[u].tot;
			u=tre[u].ch[1];
		}
		else
			u=tre[u].ch[0];
	}
}

板子

这里以洛谷P3369为例

#include<iostream>
#include<climits>
using namespace std;
struct Splay
{	
	#define MAXN 100005
	struct node
	{
		int fa;
		int ch[2];
		int tot;//相同键值的数的个数
		int val;//键值
		int siz;//个数
	}tre[MAXN];
	int rt;
	int cnt;
	#undef MAXN
	void init()
	{
		rt=1;
		cnt=2;
		tre[0].fa=0;
		tre[0].ch[0]=tre[0].ch[1]=0;
		tre[0].tot=0;
		tre[0].val=0;
		tre[0].siz=0;

		tre[1].fa=0;
		tre[1].ch[0]=0;
		tre[1].ch[1]=2;
		tre[1].tot=1;
		tre[1].val=INT_MIN;
		tre[1].siz=2;

		tre[2].fa=1;
		tre[2].ch[0]=0;
		tre[2].ch[1]=0;
		tre[2].tot=1;
		tre[2].val=INT_MAX;
		tre[2].siz=1;
	}
	int newnode(int x,int fa)
	{
		++cnt;
		tre[cnt].fa=fa;
		tre[cnt].ch[0]=tre[cnt].ch[1]=0;
		tre[cnt].tot=1;
		tre[cnt].val=x;
		tre[cnt].siz=1;
		return cnt;
	}
	void push_up(int k)
	{
		tre[k].siz=tre[tre[k].ch[0]].siz+tre[tre[k].ch[1]].siz+tre[k].tot;
	}
	void rotate(int x)
	{
		int y=tre[x].fa;
		int z=tre[y].fa;
		int k=tre[y].ch[1]==x;
		tre[z].ch[tre[z].ch[1]==y]=x;
		tre[x].fa=z;
		tre[y].ch[k]=tre[x].ch[k^1];
		tre[tre[x].ch[k^1]].fa=y;
		tre[x].ch[k^1]=y;
		tre[y].fa=x;
		push_up(y);
		push_up(x);
	}
	void splay(int x,int goal)
	{
		while(tre[x].fa!=goal)
		{
			int y=tre[x].fa;
			int z=tre[y].fa;
			if(z!=goal)
			{
				if((tre[z].ch[0]==y)^(tre[y].ch[0]==x))
					rotate(x);
				else
					rotate(y);
			}
			rotate(x);
		}
		if(goal==0)
			rt=x;
		push_up(x);
	}
	void prepare(int x)
	{
		int u=rt;
		if(!u)
			return;
		while(x!=tre[u].val&&tre[u].ch[x>tre[u].val])
		{
			if(x>tre[u].val)
			{
				u=tre[u].ch[1];
			}
			else
			{
				u=tre[u].ch[0];
			}
		}
		splay(u,0);
	}
	int next(int x,int f)
	{
		prepare(x);
		int u=rt;
		if(tre[u].val>x&&f)
			return u;
		if(tre[u].val<x&&!f)
			return u;
		u=tre[u].ch[f];
		while(tre[u].ch[!f])
			u=tre[u].ch[!f];
		return u;
	}
	void insert(int x)
	{
		int p=next(x,0);
		int s=next(x,1);
		splay(p,0);
		splay(s,p);
		if(tre[s].ch[0]==0)
		{
			tre[s].ch[0]=newnode(x,s);
			splay(cnt,0);
		}
		else
		{
			tre[tre[s].ch[0]].tot++;
			tre[tre[s].ch[0]].siz++;
			splay(tre[s].ch[0],0);	
		}
	}
	void delet(int x)
	{
		int p=next(x,0);
		int s=next(x,1);
		splay(s,0);
		splay(p,s);
		if(tre[tre[p].ch[1]].tot>1)
		{
			tre[tre[p].ch[1]].tot--;
			tre[tre[p].ch[1]].siz--;
			splay(tre[p].ch[1],0);
		}
		else
		{
			tre[p].ch[1]=0;
			splay(p,0);
		}
	}
	int rank(int x)
	{
		prepare(x);
		return tre[tre[rt].ch[0]].siz;
	}
	int kth(int x)
	{
		x++;
		int u=rt;
		while(1)
		{
			if(tre[tre[u].ch[0]].siz<x&&x<=tre[tre[u].ch[0]].siz+tre[u].tot)
			{
				return tre[u].val;
			}
			if(tre[tre[u].ch[0]].siz+tre[u].tot<x)
			{
				x=x-tre[tre[u].ch[0]].siz-tre[u].tot;
				u=tre[u].ch[1];
			}
			else
				u=tre[u].ch[0];
		}
	}
}tre;
int n;
int opt,x;
int main()
{
	ios::sync_with_stdio(false);
	tre.init();
	cin>>n;
	for(int i=1;i<=n;i++)
	{
		cin>>opt>>x;
		if(opt==1)
		{
			tre.insert(x);
		}
		if(opt==2)
		{
			tre.delet(x);
		}
		if(opt==3)
		{
			cout<<tre.rank(x)<<'\n';
		}
		if(opt==4)
		{
			cout<<tre.kth(x)<<'\n';	
		}
		if(opt==5)
		{
			cout<<tre.tre[tre.next(x,0)].val<<'\n';	
		}
		if(opt==6)
		{
			cout<<tre.tre[tre.next(x,1)].val<<'\n';
		}
	}
	return 0;
}
posted @ 2020-01-09 15:53  loney_s  阅读(229)  评论(0)    收藏  举报