关于Treap平衡树的一点总结

算法原理

Treap

一种好用的数据结构,支持插入(\(insert\)),删除(\(remove\)),查前驱(\(pre\))后继(\(suf\)),查树的排名(\(get rank by val\)),据排名查数(\(getvalbyrank\))。

前置知识:BST二叉查找树

二叉树,点带权,左子树上的点都比根小,右子树上的点都比根大。

但是,如果插入的是一个单调的序列,每次对树进行检索操作都是\(O(n)\) ,总复杂度变成\(O(n^2)\) ,不能接受。

随机权值&左旋右旋

这是Treap的核心操作。通过对原\(BST\)树进行旋转操作,使树的高度减小,形状更平衡

那么,什么样的操作才是合理的旋转操作呢?由于普通的\(BST\)在随机的数据下是趋近平衡的,所以我们给每个点一个随机权值,满足大根堆性质,尽量平衡数。

右旋zig

\(p\)的左子节点绕着\(p\)向右旋转。

inline void zig(int &p){
  int q=t[p].l; //q为原左儿子
  t[p].l=t[q].r;//左儿子的右儿子变成左儿子
  t[q].r=p;
  p=q;
  return;
}

左旋同理。

代码详解

Build

初始状态设为一个\(INF\) ,一个\(-INF\) ,防止溢出。

根节点编号设为\(1\) ,初始权值为\(INF\)

记得先新建\(-INF\) ,因为我是初始化根和根的右节点

我就是因为忘了\(build\)才多调了半小时

inline void build(){
	New(-INF),New(INF);
	rt=1,t[1].r=2;
	update(rt);
	return;
}

Get Rank By Value

inline int getrank(int p,int val){
	if(!p) return 0;
	if(val==t[p].val) return t[t[p].l].sz+1;//找到了
	if(val<t[p].val) return getrank(t[p].l,val);//往左子树找
	if(val>t[p].val) return getrank(t[p].r,val)+t[p].cnt+t[t[p].l].sz;// 往右子树找
}

Get Value By Rank

inline int getval(int p,int rank){
	if(!p) return INF;
	if(t[t[p].l].sz>=rank) return getval(t[p].l,rank);// 左子树的大小已经大于等于了,那这个排名的数肯定在左子树
	if(t[t[p].l].sz+t[p].cnt>=rank) return t[p].val;
  // 左子树的大小+父节点的大小才大于等于这个排名,那就是它了
	return getval(t[p].r,rank-t[t[p].l].sz-t[p].cnt);
  // 如果还小了,那就是在右子树上找排名为(rank-左子树大小-父节点大小)的数
}

Insert

如果以前没有这个点-->新建

如果以前有这个值了-->\(cnt++\)


inline void insert(int &p,int val){
	if(!p){
		p=New(val);
		return;
	}
	if(val==t[p].val){
		t[p].cnt++;
		update(p);
		return;
	}
	if(val<t[p].val){
		insert(t[p].l,val);
		if(t[p].data<t[t[p].l].data) zig(p);
	}
  // 找值,找对地方后进行旋转
	if(val>t[p].val){
		insert(t[p].r,val);
		if(t[p].data<t[t[p].r].data) zag(p);
	}
	update(p);
	return;
}

Remove

inline void remove(int &p,int val){
	if(!p) return;
	if(val==t[p].val){
		if(t[p].cnt>1){
			t[p].cnt--,update(p);
			return;
          //找到了
		}
		if(t[p].l||t[p].r){
			if(!t[p].r||t[t[p].l].data>t[t[p].r].data) zig(p),remove(t[p].r,val);
			else zag(p),remove(t[p].l,val);
			update(p);
		}
		else p=0;
		return;
	}
	val<t[p].val?remove(t[p].l,val):remove(t[p].r,val);
  //递归查找要删的数
	update(p);
	return;
}

Pre&Suf

\(pre\):左子树上一直往右找

\(suf\):右子树上一直往左找

注意边界条件。

Code

#include<bits/stdc++.h>
#define N (400010)
#define INF (998244353)
using namespace std;
struct xbk{
	int l,r,data,val,cnt,sz;
}t[N];
int n,tot,rt;
inline int read(){
	int w=0;
	bool f=0;
	char ch=getchar();
	while(ch>'9'||ch<'0'){
		if(ch=='-') f=1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		w=(w<<3)+(w<<1)+(ch^48);
		ch=getchar();
	}
	return f?-w:w;
}
inline int New(int val){
	t[++tot].val=val;
	t[tot].data=rand();
	t[tot].sz=t[tot].cnt=1;
	return tot;
}
inline void update(int p){
	t[p].sz=t[t[p].l].sz+t[t[p].r].sz+t[p].cnt;
}
inline void build(){
	New(-INF),New(INF);
	rt=1,t[1].r=2;
	update(rt);
	return;
}
inline int getrank(int p,int val){
	if(!p) return 0;
	if(val==t[p].val) return t[t[p].l].sz+1;
	if(val<t[p].val) return getrank(t[p].l,val);
	if(val>t[p].val) return getrank(t[p].r,val)+t[p].cnt+t[t[p].l].sz;
}
inline int getval(int p,int rank){
	if(!p) return INF;
	if(t[t[p].l].sz>=rank) return getval(t[p].l,rank);
	if(t[t[p].l].sz+t[p].cnt>=rank) return t[p].val;
	return getval(t[p].r,rank-t[t[p].l].sz-t[p].cnt);
}
inline void zig(int &p){
	int q=t[p].l;
	t[p].l=t[q].r,t[q].r=p,p=q;
	update(p),update(t[p].r);
}
inline void zag(int &p){
	int q=t[p].r;
	t[p].r=t[q].l,t[q].l=p,p=q;
	update(p),update(t[p].l);
}
inline void insert(int &p,int val){
	if(!p){
		p=New(val);
		return;
	}
	if(val==t[p].val){
		t[p].cnt++;
		update(p);
		return;
	}
	if(val<t[p].val){
		insert(t[p].l,val);
		if(t[p].data<t[t[p].l].data) zig(p);
	}
	if(val>t[p].val){
		insert(t[p].r,val);
		if(t[p].data<t[t[p].r].data) zag(p);
	}
	update(p);
	return;
}
inline int getpre(int val){
	int ans=1;
	int p=rt;
	while(p){
		if(val==t[p].val){
			if(t[p].l){
				p=t[p].l;
				while(t[p].r) p=t[p].r;
				ans=p;
			}
			break;
		}
		if(t[p].val<val&&t[p].val>t[ans].val) ans=p;
		p=val<t[p].val?t[p].l:t[p].r;
	}
	return t[ans].val;
}
inline int getsuf(int val){
	int ans=2;
	int p=rt;
	while(p){
		if(val==t[p].val){
			if(t[p].r){
				p=t[p].r;
				while(t[p].l) p=t[p].l;
				ans=p;
			}
			break;
		}
		if(t[p].val>val&&t[p].val<t[ans].val) ans=p;
		p=val<t[p].val?t[p].l:t[p].r;
	}
	return t[ans].val;
}
inline void remove(int &p,int val){
	if(!p) return;
	if(val==t[p].val){
		if(t[p].cnt>1){
			t[p].cnt--,update(p);
			return;
		}
		if(t[p].l||t[p].r){
			if(!t[p].r||t[t[p].l].data>t[t[p].r].data) zig(p),remove(t[p].r,val);
			else zag(p),remove(t[p].l,val);
			update(p);
		}
		else p=0;
		return;
	}
	val<t[p].val?remove(t[p].l,val):remove(t[p].r,val);
	update(p);
	return;
}
int main(){
	build();
	n=read();
	while(n--){
		int opt=read(),val=read();
		if(opt==1) insert(rt,val);
		if(opt==2) remove(rt,val);
		if(opt==3) printf("%d\n",getrank(rt,val)-1);
		if(opt==4) printf("%d\n",getval(rt,val+1));
		if(opt==5) printf("%d\n",getpre(val));
		if(opt==6) printf("%d\n",getsuf(val));
	}
	return 0;
}

完结撒花❀

posted @ 2021-01-30 16:14  xxbbkk  阅读(113)  评论(0编辑  收藏  举报