【二逼平衡树】树套树模板,线段树套平衡树
树套树入门?第一份代码达到7k,orz orz orz,我的splay总是会断owo。
LOJ106
code:
这是一道模板题。
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
- 查询 x 在区间内的排名;
- 查询区间内排名为 k 的值;
- 修改某一位置上的数值;
- 查询 x 在区间内的前趋(前趋定义为小于 x,且最大的数);
- 查询 x 在区间内的后继(后继定义为大于 x ,且最小的数)。
输入格式
第一行两个数 n,m ,表示长度为 n 的有序序列和 m个操作。
第二行有 n 个数,表示有序序列。
下面有 m 行,每行第一个数表示操作类型:
- 之后有三个数 l,r,x 表示查询 x 在区间 [l,r]的排名;
- 之后有三个数 l,r,k 表示查询区间 [l,r]内排名为 k 的数;
- 之后有两个数 pos,x 表示将 pos \mathrm{pos} pos 位置的数修改为 x;
- 之后有三个数 l,r,x 表示查询区间 [l,r] 内 x 的前趋;
- 之后有三个数 l,r,x 表示查询区间 [l,r] 内 x 的后继。
输出格式
对于操作 1,2,4,5 各输出一行,表示查询结果。
样例
样例输入
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
样例输出
2
4
3
4
9
数据范围与提示
1≤n,m≤5×104,−108≤k,x≤108
题目很显然,就是一道毒瘤数据结构。
做法1(本蒟蒻的做法):树套树,线段树套平衡树(splay),对每个位置建线段树,线段树每个结点对应的l到r建立一颗平衡树。由于层数Logn,建树时间O(n log^2 n)
对于询问,我们把l到r区间内所有对应包含的线段树结点对应splay的root结点提取出来,然后
1,对每个splay找比他小的个数总和+1即可。 O(n log^2 n)
2.我们考虑二分,对mid找比他小的数总和然后调整即可,注意小心二分写法(蒟蒻太久没有手写过二分查找了)O(n log^3 n)
3,对包含这个位置所有的平衡树进行del后insert , O(nlog^2 n)
4.对所有区间内平衡树找前驱,取个max就好
5.同4,找所有后继,取个min就好
空间复杂度在于splay,由于线段树logn层,每层n个数,空间(nlogn)
时间 瓶颈在于操作2,时间(nlog^3n)
做法2:(考虑怎么优化操作2)
来自Oblack大佬%%%
同样是线段树套平衡树,只是将线段树维护的转化为数值(动态开点),然后对于每个线段树结点开平衡树维护这个结点的数值对应的位置。可以发现这样的话操作2就转化为了nlog^2n。
当然我们转过头来看原nlog^3n做法,我们会发现其实根本没法达到这么高的复杂度。因为对于一个区间得到的点也就是最多n个,我们均摊过来,每个平衡树内的结点并不会多,若结点多,那么对应区间需要扫的平衡树结点个数也就会相应减少。
数组不要开小了,血的教训,线段树一定空间要开两倍!!!
//by newuser #include<stdio.h> #include<bits/stdc++.h> #define zig(x) zigzag(x,1) #define zag(x) zigzag(x,2) using namespace std; const int maxn = 50005; const int spmaxn = 2500005; int n,m,MAX; int A[maxn]; int sta[maxn],statop; namespace spa { int rt[maxn*2],cnt[spmaxn],ls[spmaxn],rs[spmaxn],dat[spmaxn],siz[spmaxn],fa[spmaxn],tot; inline void putup(int x) { siz[x] = siz[ls[x]] + siz[rs[x]] + cnt[x]; } inline void zigzag(int x,int knd) { int y=fa[x],z=fa[y]; if(z) { if(ls[z]==y) ls[z]=x; else rs[z]=x; } fa[x]=z; fa[y]=x; if(knd==1) { ls[y]=rs[x]; fa[ls[y]]=y; rs[x]=y; } else { rs[y]=ls[x]; fa[rs[y]]=y; ls[x]=y; } putup(y); putup(x); } inline void splay(int x,int root) { int y,z; while(fa[x]) { y=fa[x]; z=fa[y]; if(z) { if(ls[z]==y) { if(ls[y]==x) { zig(y); zig(x); } else { zag(x); zig(x); } } else { if(rs[y]==x) { zag(y); zag(x); } else { zig(x); zag(x); } } } else { if(ls[y]==x) zig(x); else zag(x); } } rt[root]=x; } inline void insert(int x,int root) { if(!rt[root]) { rt[root] = ++tot; cnt[tot]=siz[tot]=1; dat[tot]=x; ls[tot]=rs[tot]=fa[tot]=0; return ; } ++tot; int p = rt[root]; while(p) { siz[p]++; if(x<dat[p]) { if(!ls[p]) { ls[p]=tot; break; } p=ls[p]; } else if(x>dat[p]) { if(!rs[p]) { rs[p]=tot; break; } p=rs[p]; } else { --tot; cnt[p]++; return; } } dat[tot]=x; cnt[tot]=siz[tot]=1; fa[tot] = p; ls[tot]=rs[tot]=0; splay(tot,root); } inline int getmax(int p) { while(rs[p]) p = rs[p]; return p; } inline void del(int x,int root) { splay(x,root); --cnt[x]; siz[x]--; if(cnt[x]) return; int ll=ls[x]; int rr=rs[x]; fa[ll]=fa[rr]=ls[x]=rs[x]=siz[x]=0; if(!ll) { rt[root]=rr; return; } if(!rr) { rt[root]=ll; return; } ll=getmax(ll); splay(ll,root); fa[ll]=0; fa[rr]=ll; rs[ll]=rr; putup(rr); putup(ll); return; } inline int fi(int x,int root) { int p=rt[root]; while(p) { if(x==dat[p]) return p; else if(x<dat[p]) p = ls[p]; else if(x>dat[p]) p = rs[p]; } splay(p,root); return p; } inline int getsmaller(int x,int root) { int p = rt[root],ans=0; while(p) { if(dat[p]>x) p = ls[p]; else if(dat[p]<x) { ans+=siz[ls[p]]+cnt[p]; p = rs[p]; } else { ans+=siz[ls[p]]; break; } } if(p) splay(p,root); return ans; } inline int getqianqu(int x,int root) { int p = rt[root],ans=-0x3f3f3f3f; while(p) { if(dat[p]>=x) { p = ls[p]; } else { if(dat[p]>ans) ans = dat[p]; p = rs[p]; } } if(p) splay(p,root); return ans; } inline int gethouji(int x,int root) { int p = rt[root],ans=0x3f3f3f3f; while(p) { if(dat[p]<=x) { p=rs[p]; } else { if(dat[p]<ans) ans = dat[p]; p = ls[p]; } } if(p) splay(p,root); return ans; } } namespace seg { int ls[maxn*2],rs[maxn*2],dy[maxn*2],tot; int maketree(int l,int r) { int p = ++tot; if(l<r) { int mid = (l+r)>>1; ls[p] = maketree(l,mid); rs[p] = maketree(mid+1,r); } for(int i=l;i<=r;i++) spa::insert(A[i],p); return p; } inline void getqujian(int p,int l,int r,int x,int y) { if(x<=l&&r<=y) { sta[++statop] = p; return; } int mid = (l+r)>>1; if(x>mid) getqujian(rs[p],mid+1,r,x,y); else if(y<=mid) getqujian(ls[p],l,mid,x,y); else getqujian(ls[p],l,mid,x,y),getqujian(rs[p],mid+1,r,x,y); } inline void getbaohan(int p,int l,int r,int k) { sta[++statop] = p; if(l<r) { int mid = (l+r)>>1; if(k<=mid) getbaohan(ls[p],l,mid,k); else getbaohan(rs[p],mid+1,r,k); } return; } } inline int solve1() { int l,r,k; scanf("%d%d%d",&l,&r,&k); statop=0; seg::getqujian(1,1,n,l,r); int ans = 1; for(int i=1;i<=statop;i++) { ans+=spa::getsmaller(k,sta[i]); } return ans; } inline int ggg(int mid) { int tmp=0; for(int i=1;i<=statop;i++) { tmp+=spa::getsmaller(mid,sta[i]); } return tmp; } inline int solve2() { int l,r,k; scanf("%d%d%d",&l,&r,&k); int L=0;int R=MAX; int mid; statop=0; seg::getqujian(1,1,n,l,r); while(L<=R) { mid = (L+R)>>1; int aha = ggg(mid); if(aha<k) L = mid+1; else R=mid-1; } return L-1; } inline void solve3() { int pos,k; scanf("%d%d",&pos,&k); statop=0; seg::getbaohan(1,1,n,pos); for(int i=1;i<=statop;i++) { int x = spa::fi(A[pos],sta[i]); spa::del(x,sta[i]); spa::insert(k,sta[i]); } A[pos] = k; MAX = max(MAX,k); } inline int solve4() { int l,r,k; scanf("%d%d%d",&l,&r,&k); statop=0; seg::getqujian(1,1,n,l,r); int ans = -0x3f3f3f3f; for(int i=1;i<=statop;i++) { ans = max(ans,spa::getqianqu(k,sta[i])); } return ans; } inline int solve5() { int l,r,k; scanf("%d%d%d",&l,&r,&k); statop=0; seg::getqujian(1,1,n,l,r); int ans = 0x3f3f3f3f; for(int i=1;i<=statop;i++) { ans = min(ans,spa::gethouji(k,sta[i])); } return ans; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&A[i]),MAX=max(MAX,A[i]); seg::maketree(1,n); for(int i=1;i<=m;i++) { int opt; scanf("%d",&opt); if(opt==1) printf("%d\n",solve1()); else if(opt==2) printf("%d\n",solve2()); else if(opt==3) solve3(); else if(opt==4) printf("%d\n",solve4()); else printf("%d\n",solve5()); } }