bzoj 3196: Tyvj 1730 二逼平衡树

Time Limit: 10 Sec  Memory Limit: 128 MB
Submit: 2167  Solved: 907
[Submit][Status][Discuss]

Description

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)

Input

第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继

Output

对于操作1,2,4,5各输出一行,表示查询结果

Sample Input

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

Sample Output

2
4
3
4
9

HINT

1.n和m的数据范围:n,m<=50000

2.序列中每个数的数据范围:[0,1e8]

3.虽然原题没有,但事实上5操作的k可能为负数
 
题解:
  第一次写树套树,写了半天调了半天。。。
  首先,询问操作如果不加区间的话很容易看出来是平衡树,伸展树等数据结构就可以维护的。但是现在有区间,关系到区间就容易想到是线段树。所以就愉快地线段树套splay了。
  树套树的思路还是很好理解的,相当于建一棵线段树,线段树中每个节点的信息用splay来维护,splay是很多棵splay,不同的节点组成不同的splay来维护不同的线段树区间。假设有以root[x]为根的splay来维护[l,r],那么此splay中的节点就是线段树[l,r]里的所有值。
  下面的代码中findkth(),有两个,第一个用/* */括起来的是可能会WA的,得到的答案会比标准答案小一,但是WA的概率相当低,我并没有看出来为什么错,求解释啊。
  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstdlib>
  4 #include<cstring>
  5 #include<cmath>
  6 #include<algorithm>
  7 #include<queue>
  8 #include<vector>
  9 using namespace std;
 10 const int inf=1e9;
 11 const int maxn=500005,maxm=5000005;
 12 int N,M,ANS,a[maxn];
 13 int Siz,root[maxn],lc[maxm],rc[maxm],key[maxm],siz[maxm],fa[maxm],cnt[maxn];
 14 inline void update(int x){
 15     siz[x]=siz[lc[x]]+siz[rc[x]]+1;
 16 }
 17 inline void r_rotate(int x){
 18     int y=fa[x];
 19     lc[y]=rc[x];
 20     if(rc[x]) fa[rc[x]]=y;
 21     fa[x]=fa[y];
 22     if(y==lc[fa[y]]) lc[fa[y]]=x;
 23     else rc[fa[y]]=x;
 24     fa[y]=x; rc[x]=y;
 25     update(y); update(x);
 26 }
 27 inline void l_rotate(int x){
 28     int y=fa[x];
 29     rc[y]=lc[x];
 30     if(lc[x]) fa[lc[x]]=y;
 31     fa[x]=fa[y];
 32     if(y==lc[fa[y]]) lc[fa[y]]=x;
 33     else rc[fa[y]]=x;
 34     fa[y]=x; lc[x]=y;
 35     update(y); update(x);
 36 }
 37 inline void splay(int rt,int x,int s){//在维护线段树节点标号为 rt 的 splay 中 把 x 节点旋至 s 的下面 
 38     int p;
 39     while(fa[x]!=s){
 40         int p=fa[x];
 41         if(fa[p]==s){
 42             if(x==lc[p]) r_rotate(x);
 43             else l_rotate(x);
 44             break;
 45         }
 46         else if(x==lc[p]){
 47             if(p==lc[fa[p]]) r_rotate(x),r_rotate(x);
 48             else r_rotate(x),l_rotate(x);
 49         }
 50         else if(x==rc[p]){
 51             if(p==rc[fa[p]]) l_rotate(x),l_rotate(x);
 52             else l_rotate(x),r_rotate(x);
 53         }
 54     }
 55     if(s==0) root[rt]=x;
 56 }
 57 inline void insert(int rt,int v){//在维护线段树节点标号为rt的splay中插入权值为v的点 
 58     if(root[rt]==0){
 59         root[rt]=++Siz;
 60         siz[root[rt]]=1; key[root[rt]]=v;
 61         return ;
 62     }
 63     int tmp,x=root[rt];
 64     while(x!=0){
 65         tmp=x;
 66         if(v<=key[x]) siz[x]++,x=lc[x];
 67         else siz[x]++,x=rc[x];
 68     }
 69     if(v<=key[tmp]){
 70          lc[tmp]=++Siz;
 71           fa[Siz]=tmp; key[Siz]=v; siz[Siz]=1;
 72     }
 73     else{
 74         rc[tmp]=++Siz;
 75         fa[Siz]=tmp; key[Siz]=v; siz[Siz]=1;
 76     }
 77     splay(rt,Siz,0);
 78 }
 79 inline void build(int rt,int l,int r,int x,int num){//在线段树里的标号是rt维护序列 l~r  在序列中是第x个数字 数字大小是num 
 80     insert(rt,num);
 81     if(l==r) return ;
 82     int mid=(l+r)>>1;
 83     if(x<=mid) build(rt<<1,l,mid,x,num);
 84     else build(rt<<1|1,mid+1,r,x,num);
 85 }
 86 
 87 inline int ask_rank(int x,int v){//在以x为根的splay中找 v 的排名 
 88     if(x==0) return 0;
 89     if(v<key[x]) return ask_rank(lc[x],v);
 90     else return siz[lc[x]]+1+ask_rank(rc[x],v);
 91 }
 92 
 93 inline int get_rank(int rt,int l,int r,int x,int y,int num){
 94     if(l==x&&r==y){
 95         return ask_rank(root[rt],num);
 96     }
 97     int mid=(l+r)>>1,ans=0;
 98     if(mid>=y) return get_rank(rt<<1,l,mid,x,y,num);//待询问区间完全存在于左子树 
 99     else if(mid<x) return get_rank(rt<<1|1,mid+1,r,x,y,num);
100     else{
101         ans+=get_rank(rt<<1,l,mid,x,mid,num);
102         ans+=get_rank(rt<<1|1,mid+1,r,mid+1,y,num);
103         return ans;
104     }
105 }
106 
107 inline int findkth(int l,int r,int x,int y,int k){//二分区间在[l,r]之间的数字,线段树中是序列的[x,y]位,要找第 k 小 
108     /*
109     while(l+1<r){
110         int mid=(l+r)>>1;
111         if(get_rank(1,1,N,x,y,mid-1)+1<=k) l=mid;
112         else r=mid-1;
113     }
114     if(get_rank(1,1,N,x,y,r-1)+1==k) return r;
115     return l;
116     */
117     int ans;
118     while(l<=r){
119         int mid=(l+r)>>1;
120         if(get_rank(1,1,N,x,y,mid-1)+1<=k){
121             l=mid+1; ans=mid;
122         }
123         else r=mid-1;
124     }
125     return ans;
126 }
127 inline int findv(int rt,int v){
128     int x=root[rt];
129     while(x!=0){
130         if(v<key[x]) x=lc[x];
131         else if(v>key[x]) x=rc[x];
132         else{
133             splay(rt,x,0);
134             return x;
135         }
136     }
137     return -1;
138 }
139 inline int getmax(int x){
140     int tmp;
141     while(x!=0){
142         tmp=x; x=rc[x];
143     }
144     return tmp;
145 }
146 inline int getmin(int x){
147     int tmp;
148     while(x!=0){
149         tmp=x; x=lc[x];
150     }
151     return tmp;
152 }
153 inline void Delete(int rt,int num){
154     int x=findv(rt,num);
155     if(x==-1) return ;
156     int pp=getmax(lc[x]),nn=getmin(rc[x]);
157     if(lc[x]==0||rc[x]==0){
158         if(lc[x]==0&&rc[x]==0){
159             root[rt]=0; return ;    
160         }
161         else if(lc[x]==0){
162             fa[rc[x]]=0; root[rt]=rc[x];
163             siz[x]=1;
164             return ;
165         }
166         else if(rc[x]==0){
167             fa[lc[x]]=0; root[rt]=lc[x];
168             siz[x]=1;
169             return ;
170         }
171     }
172     splay(rt,pp,0);
173     splay(rt,nn,root[rt]);
174     fa[lc[nn]]=0; siz[lc[nn]]=1; lc[nn]=0;
175     update(nn); update(pp);
176 }
177 inline void change(int rt,int l,int r,int pos,int num,int last){//修改操作 
178     Delete(rt,last); insert(rt,num);
179     if(l==r) return ;
180     int mid=(l+r)>>1;
181     if(pos<=mid) change(rt<<1,l,mid,pos,num,last);
182     else change(rt<<1|1,mid+1,r,pos,num,last);
183 }
184 inline void pre(int x,int num){
185     if(x==0) return ;
186     if(key[x]<num){
187         ANS=max(ANS,key[x]);
188         pre(rc[x],num);
189     }
190     else pre(lc[x],num);
191 }
192 inline void get_pre(int rt,int l,int r,int x,int y,int num){
193     if(l==x&&r==y){
194         pre(root[rt],num);
195         return ;
196     }
197     int mid=(l+r)>>1;
198     if(mid>=y) get_pre(rt<<1,l,mid,x,y,num);
199     else if(mid<x) get_pre(rt<<1|1,mid+1,r,x,y,num);
200     else{
201         get_pre(rt<<1,l,mid,x,mid,num);
202         get_pre(rt<<1|1,mid+1,r,mid+1,y,num);
203     }
204 }
205 inline void succ(int x,int num){
206     if(x==0) return ;
207     if(key[x]>num){
208         ANS=min(ANS,key[x]);
209         succ(lc[x],num);
210     }
211     else succ(rc[x],num);
212 }
213 inline void get_succ(int rt,int l,int r,int x,int y,int num){
214     if(l==x&&r==y){
215         succ(root[rt],num);
216         return ;
217     }
218     int mid=(l+r)>>1;
219     if(mid>=y) get_succ(rt<<1,l,mid,x,y,num);
220     else if(mid<x) get_succ(rt<<1|1,mid+1,r,x,y,num);
221     else{
222         get_succ(rt<<1,l,mid,x,mid,num);
223         get_succ(rt<<1|1,mid+1,r,mid+1,y,num);
224     }
225 }
226 
227 inline void insertinf(int x){
228     if(x>=50) return ;
229     else{
230         insert(x,-inf); insert(x,inf);
231         insertinf(x<<1); insertinf(x<<1|1);
232     }
233 }
234 int main(){
235 //    freopen("psh.in","r",stdin);
236 //    freopen("psh.out","w",stdout);
237     scanf("%d%d",&N,&M);
238     for(int i=1;i<=N;i++) scanf("%d",&a[i]);
239     for(int i=1;i<=N;i++) 
240         build(1,1,N,i,a[i]);
241     while(M--){
242         int f; scanf("%d",&f);
243         int x,y,k,pos;
244         switch(f){
245             case 1:scanf("%d%d%d",&x,&y,&k);//表示查询 k 在区间 [l,r] 的排名 
246                    printf("%d\n",get_rank(1,1,N,x,y,k-1)+1);
247                    break;
248             case 2:scanf("%d%d%d",&x,&y,&k);//查询区间内排名为 k 的值
249                    printf("%d\n",findkth(0,inf,x,y,k));
250                    break;
251             case 3:scanf("%d%d",&pos,&k);//把序列中的第pos位数改成k 
252                    change(1,1,N,pos,k,a[pos]);
253                    a[pos]=k;
254                    break;
255             case 4:scanf("%d%d%d",&x,&y,&k);
256                    ANS=0; get_pre(1,1,N,x,y,k);
257                    printf("%d\n",ANS);
258                    break;
259             case 5:scanf("%d%d%d",&x,&y,&k);
260                    ANS=inf; get_succ(1,1,N,x,y,k);
261                    printf("%d\n",ANS);
262                    break;
263         }
264     }
265     return 0;
266 }

 

posted @ 2016-04-08 21:49  CXCXCXC  阅读(356)  评论(1编辑  收藏  举报