BZOJ3224 洛谷3369 Tyvj 1728 普通平衡树 splay

欢迎访问~原文出处——博客园-zhouzhendong

去博客园看该题解


题目传送门 - BZOJ3224


题意概括

  您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)


题解

  splay模板题。

  


代码

 UPD(2018-03-20):省选一试前夜,抱佛脚打板子,发现之前的代码真丑。注意,在splay中预先放一个非常大的数可以有效的排除平衡树空的特殊情况。我在细节处理上也增加了大量改进,代码量爆缩,运行效率略高了一些,老的那份代码在后面。

#include <bits/stdc++.h>
using namespace std;
const int N=100005;
int n,root=1,size=1,val[N],cnt[N],son[N][2],fa[N],tot[N];
int wson(int x){
	return son[fa[x]][1]==x;
}
void pushup(int x){
	tot[x]=cnt[x]+tot[son[x][0]]+tot[son[x][1]];
}
void rotate(int x){
	if (!x)
		return;
	int y=fa[x],z=fa[y],L=wson(x),R=L^1;
	if (z)
		son[z][wson(y)]=x;
	fa[x]=z,fa[y]=x,fa[son[x][R]]=y;
	son[y][L]=son[x][R],son[x][R]=y;
	pushup(y),pushup(x);
}
void splay(int x,int k){
	if (!x)
		return;
	if (!k)
		root=x;
	for (int y=fa[x];fa[x]!=k;rotate(x),y=fa[x])
		if (fa[y]!=k)
			rotate(wson(x)==wson(y)?y:x);
}
int find(int x,int v){
	return val[x]==v?x:find(son[x][v>val[x]],v);
}
int findkth(int x,int k){
	if (k<=tot[son[x][0]])
		return findkth(son[x][0],k);
	k-=tot[son[x][0]];
	if (k<=cnt[x])
		return x;
	k-=cnt[x];
	return findkth(son[x][1],k);
}
int findnxt(int x,int v){
	if (!x)
		return 0;
	if (val[x]<=v)
		return findnxt(son[x][1],v);
	else {
		int res=findnxt(son[x][0],v);
		return res?res:x;
	}
}
int findpre(int x,int v){
	if (!x)
		return 0;
	if (val[x]>=v)
		return findpre(son[x][0],v);
	else {
		int res=findpre(son[x][1],v);
		return res?res:x;
	}
}
void insert(int &x,int pre,int v){
	if (!x){
		x=++size;
		val[x]=v,cnt[x]=tot[x]=1,fa[x]=pre;
		splay(x,0);
		return;
	}
	tot[x]++;
	if (val[x]==v){
		cnt[x]++;
		return;
	}
	insert(son[x][v>val[x]],x,v);
}
void Insert(int v){insert(root,0,v);}
void Delete(int v){
	int x;
	splay(x=find(root,v),0);
	if (--cnt[x])
		return;
	splay(findnxt(root,v),root);
	root=son[x][1];
	son[root][0]=son[x][0];
	fa[son[x][0]]=root;
	fa[root]=son[x][0]=son[x][1]=0;
	pushup(root);
}
int Rank(int v){
	splay(find(root,v),0);
	return tot[son[root][0]]+1;
}
int main(){
	val[1]=2147483647;
	cnt[1]=tot[1]=1;
	scanf("%d",&n);
	while (n--){
		int opt,x;
		scanf("%d%d",&opt,&x);
		if (opt==1) Insert(x);
		if (opt==2) Delete(x);
		if (opt==3) printf("%d\n",Rank(x));
		if (opt==4) printf("%d\n",val[findkth(root,x)]);
		if (opt==5) printf("%d\n",val[findpre(root,x)]);
		if (opt==6) printf("%d\n",val[findnxt(root,x)]);
	}
	return 0;
}

  

 

#include <cstring>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <cstdlib>
using namespace std;
const int N=100005;
struct Splay{
	int fa[N],lc[N],rc[N],size[N],cnt[N],val[N],root,tot;
	void clear(){//splay³õʼ»¯ 
		root=tot=0,memset(size,0,sizeof size),memset(cnt,0,sizeof cnt);
	}
	void update(int x){//¼ÆËã(¸üÐÂ)ijһ¸ö½ÚµãµÄsize 
		if (x){
			size[x]=cnt[x];
			size[x]+=lc[x]?size[lc[x]]:0;
			size[x]+=rc[x]?size[rc[x]]:0;
		}
	}
	void zig(int x){//ÓÒÐý 
		if (!fa[x])
			return;
		int y=fa[x],z=fa[y];
		if (z)
			if (lc[z]==y)
				lc[z]=x;
			else
				rc[z]=x;
		fa[x]=z;
		fa[y]=x;
		fa[rc[x]]=y;
		lc[y]=rc[x];
		rc[x]=y;
		update(y);
		update(x);
	}
	void zag(int x){//×óÐý 
		if (!fa[x])
			return;
		int y=fa[x],z=fa[y];
		if (z)
			if (lc[z]==y)
				lc[z]=x;
			else
				rc[z]=x;
		fa[x]=z;
		fa[y]=x;
		fa[lc[x]]=y;
		rc[y]=lc[x];
		lc[x]=y;
		update(y);
		update(x);
	}
	void splay(int rt,int x){//splayÖ÷¹ý³Ì(ÔÚÒÔrtΪ¸ùµÄ×ÓÊ÷ÖÐsplay£¬×îÖÕx³ÉΪµ±Ç°×ÓÊ÷µÄÐÂ×æ×Ú)
		rt=fa[rt];
		int y,z;
		while (fa[x]!=rt){
			y=fa[x];
			z=fa[y];
			if (z&&z!=rt){
				if (lc[z]==y&&lc[y]==x)
					zig(y),zig(x);
				else if (lc[z]==y&&rc[y]==x)
					zag(x),zig(x);
				else if (rc[z]==y&&lc[y]==x)
					zig(x),zag(x);
				else
					zag(y),zag(x);
			}
			else if (lc[y]==x)
				zig(x);
			else
				zag(x);
		}
		if (!fa[x])
			root=x;
	}
	void splay(int x){
		splay(root,x);
	}
	void ins(int &k,int key,int p){//ÔÚkÕâ¸ö×ÓÊ÷ÖвåÈëֵΪkeyµÄÊý£¬µ±Ç°½ÚµãµÄ¸¸Ç×Ϊp 
		if (!k){
			size[k=++tot]=cnt[k]=1;
			lc[k]=rc[k]=0;
			val[k]=key;
			fa[k]=p;
			splay(k);
			return;
		}
		size[k]++;
		if (key==val[k]){
			cnt[k]++;
			return;
		}
		if (key<val[k])
			ins(lc[k],key,k);
		else
			ins(rc[k],key,k);
	}
	void ins(int key){//ÔÚÕû¸ösplayÀïÃæ²åÈëָΪkeyµÄÊý 
		if (!root){
			size[root=++tot]=cnt[root]=1;
			lc[root]=rc[root]=0;
			val[root]=key;
			fa[root]=0;
			return;
		}
		ins(root,key,0);
	}
	int find(int k,int key){//ÔÚÒÔkΪ¸ùµÄsplayÀïÃæ²éÕÒֵΪkeyµÄÊýµÄ´æ´¢±àºÅ  
		int tmp=k;
		while (k)
			if (val[k]==key)
				break;
			else
				k=key<val[k]?lc[k]:rc[k];
		if (k)
			splay(k);
		return k;
	}
	int find(int key){//splayÀïÃæ²éÕÒֵΪkeyµÄÊýµÄ´æ´¢±àºÅ  
		return find(root,key);
	}
	void del(int key){//ɾ³ýֵΪkeyµÄÊý£¨µ«ÊÇtot²»¼õ£¬Òâζ×ÅÓÐ1¸ö¿Õ¼ä»áÀË·Ñ£© 
		int x=find(key),ls=lc[x],rs=rc[x];
		if (!x)
			return;
		if (--cnt[x])
			return;
		if (!ls&&!rs){
			clear();
			return;
		}
		if (!ls)
			root=rs,fa[rs]=0;
		else if (!rs)
			root=ls,fa[ls]=0;
		else {
			int lson=getmax(ls);
			swap(lson,ls);
			fa[lson]=0;
			splay(ls);
			rc[ls]=rs;
			fa[rs]=ls;
			update(ls);
		}
	}
	int getmin(int k){//ÔÚ¸ùΪkµÄsplayÖÐÕÒ×îСֵµÄ´æ´¢±àºÅ 
		return lc[k]?getmin(lc[k]):k;
	}
	int getmax(int k){//ÔÚ¸ùΪkµÄsplayÖÐÕÒ×î´óÖµµÄ´æ´¢±àºÅ 
		return rc[k]?getmax(rc[k]):k;
	}
	int findkth(int k){//Ñ°ÕÒµÚkСµÄ 
		int t=root;
		while (t){
			if (size[lc[t]]<k&&k<=size[lc[t]]+cnt[t])
				break;
			if (k<=size[lc[t]])
				t=lc[t];
			else
				k-=size[lc[t]]+cnt[t],t=rc[t];
		}
		splay(t);
		return t;
	}
	int findpre(int k,int key){//ÔÚ¸ùΪkµÄsplayÖÐÕÒkeyµÄÇ°Çý£¨¼´Öµ<keyµÄ×î´óÖµ£©µÄ´æ´¢±àºÅ 
		if (!k)
			return 0;
		if (key<=val[k])
			return findpre(lc[k],key);
		else {
			int tmp=findpre(rc[k],key);
			return tmp?tmp:k;
		}
	}
	int findpre(int key){return findpre(root,key);}//ÕÒkeyµÄÇ°Çý£¨¼´Öµ<keyµÄ×î´óÖµ£©µÄ´æ´¢±àºÅ 
	int findnxt(int k,int key){//ÔÚ¸ùΪkµÄsplayÖÐÕÒkeyµÄºó¼Ì£¨¼´Öµ>keyµÄ×îСֵ£©µÄ´æ´¢±àºÅ 
		if (!k)
			return 0;
		if (key>=val[k])
			return findnxt(rc[k],key);
		else {
			int tmp=findnxt(lc[k],key);
			return tmp?tmp:k;
		}
	}
	int findnxt(int key){return findnxt(root,key);}//ÕÒkeyµÄºó¼Ì£¨¼´Öµ>keyµÄ×îСֵ£©µÄ´æ´¢±àºÅ 
}s;
int n,opt,x;
int main(){
	scanf("%d",&n);
	s.clear();
	for (int i=1;i<=n;i++){
		scanf("%d%d",&opt,&x);
		if (opt==1)	s.ins(x);
		if (opt==2)	s.del(x);
		if (opt==3){//ÇóֵΪxµÄÃû´Î 
			s.splay(s.find(x));
			printf("%d\n",s.size[s.lc[s.root]]+1);
		}
		if (opt==4)	printf("%d\n",s.val[s.findkth(x)]);
		if (opt==5)	printf("%d\n",s.val[s.findpre(x)]);
		if (opt==6)	printf("%d\n",s.val[s.findnxt(x)]);
	}
	return 0;
}

 

  

 

posted @ 2017-11-21 23:20  zzd233  阅读(408)  评论(0编辑  收藏  举报