[BZOJ3224]普通平衡树
题目描述 Description
|
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数 2. 删除x数(若有多个相同的数,因只删除一个) 3. 查询x数的排名(若有多个相同的数,因输出最小的排名) 4. 查询排名为x的数 5. 求x的前驱(前驱定义为小于x,且最大的数) 6. 求x的后继(后继定义为大于x,且最小的数) |
输入描述 Input Description
|
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6) |
输出描述 Output Description
|
对于操作3,4,5,6每行输出一个数,表示对应答案
|
样例输入 Sample Input
|
10
1 106465 4 1 1 317721 1 460929 1 644985 1 84185 1 89851 6 81968 1 492737 5 493598 |
样例输出 Sample Output
|
106465
84185 492737 |
数据范围及提示 Data Size & Hint
|
1.n的数据范围:n<=100000
2.每个数的数据范围:[-2e9,2e9]
|
最裸的平衡树题。用treap来实现,由于是可重复的,我这里默认如果键值相同就往左子树插,然后一些经典操作例如什么找前驱后继,排名什么的上面都写了,希望当作以后的模板。
#include<iostream> #include<algorithm> #include<cstdio> #include<queue> #include<cmath> #include<cstring> #include<ctime> using namespace std; typedef long long LL; #define mem(a,b) memset(a,b,sizeof(a)) inline int read() { int x=0,f=1;char c=getchar(); while(!isdigit(c)){if(c=='-')f=-1;c=getchar();} while(isdigit(c)){x=x*10+c-'0';c=getchar();} return x*f; } const int maxn=100010,oo=2147483647; struct node { int val,rnd; int cmp(int v)const { if(val==v)return -1; return v<val ? 0 : 1; } }ns[maxn]; int tot,rt,ch[2][maxn],size[maxn]; int New(int v) { int o=++tot;ns[o].val=v;ns[o].rnd=rand(); ch[0][o]=ch[1][o]=0;return o; } void del(int o){ns[o].val=ns[o].rnd=ch[0][o]=ch[1][o]=size[o]=0;o=0;} void maintain(int o) { if(!o)return; size[o]=1; if(ch[0][o])size[o]+=size[ch[0][o]]; if(ch[1][o])size[o]+=size[ch[1][o]]; } void rotate(int &o,int d) { int k=ch[d^1][o];ch[d^1][o]=ch[d][k];ch[d][k]=o; maintain(o);o=k;return maintain(o); } void insert(int &o,int v) { if(!o){o=New(v);return maintain(o);} int d= v<=ns[o].val ? 0 : 1; insert(ch[d][o],v); if(ns[ch[d][o]].rnd>ns[o].rnd)rotate(o,d^1); return maintain(o); } void remove(int &o,int v) { if(!o)return; int d=ns[o].cmp(v); if(d==-1) { int t=o; if(ch[0][o] && ch[1][o]) { int d2= ns[ch[0][o]].rnd>ns[ch[1][o]].rnd ? 1 : 0; rotate(o,d2);remove(ch[d2][o],v); } else if(!ch[0][o])o=ch[1][o]; else o=ch[0][o]; del(t); } else remove(ch[d][o],v); return maintain(o); } int kth(int o,int k) { if(!o || k<0 || k>size[o])return 0; int s=size[ch[0][o]]; if(k==s+1)return ns[o].val; else if(k<=s)return kth(ch[0][o],k); else return kth(ch[1][o],k-s-1); } int rank(int o,int v) { if(!o)return 0; if(v>ns[o].val)return size[ch[0][o]]+1+rank(ch[1][o],v); return rank(ch[0][o],v); } int lower(int o,int v) { if(!o)return -oo; if(ns[o].val<v)return max(ns[o].val,lower(ch[1][o],v)); else return lower(ch[0][o],v); } int upper(int o,int v) { if(!o)return oo; if(ns[o].val>v)return min(ns[o].val,upper(ch[0][o],v)); else return upper(ch[1][o],v); } int n,tp,x; int main() { n=read(); while(n--) { tp=read();x=read(); if(tp==1)insert(rt,x); else if(tp==2)remove(rt,x); else if(tp==3)printf("%d\n",rank(rt,x)+1); else if(tp==4)printf("%d\n",kth(rt,x)); else if(tp==5)printf("%d\n",lower(rt,x)); else if(tp==6)printf("%d\n",upper(rt,x)); } return 0; }