【二逼平衡树】树套树模板,线段树套平衡树
树套树入门?第一份代码达到7k,orz orz orz,我的splay总是会断owo。
LOJ106
code:
这是一道模板题。
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
- 查询 x 在区间内的排名;
- 查询区间内排名为 k 的值;
- 修改某一位置上的数值;
- 查询 x 在区间内的前趋(前趋定义为小于 x,且最大的数);
- 查询 x 在区间内的后继(后继定义为大于 x ,且最小的数)。
输入格式
第一行两个数 n,m ,表示长度为 n 的有序序列和 m个操作。
第二行有 n 个数,表示有序序列。
下面有 m 行,每行第一个数表示操作类型:
- 之后有三个数 l,r,x 表示查询 x 在区间 [l,r]的排名;
- 之后有三个数 l,r,k 表示查询区间 [l,r]内排名为 k 的数;
- 之后有两个数 pos,x 表示将 pos \mathrm{pos} pos 位置的数修改为 x;
- 之后有三个数 l,r,x 表示查询区间 [l,r] 内 x 的前趋;
- 之后有三个数 l,r,x 表示查询区间 [l,r] 内 x 的后继。
输出格式
对于操作 1,2,4,5 各输出一行,表示查询结果。
样例
样例输入
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
样例输出
2
4
3
4
9
数据范围与提示
1≤n,m≤5×104,−108≤k,x≤108
题目很显然,就是一道毒瘤数据结构。
做法1(本蒟蒻的做法):树套树,线段树套平衡树(splay),对每个位置建线段树,线段树每个结点对应的l到r建立一颗平衡树。由于层数Logn,建树时间O(n log^2 n)
对于询问,我们把l到r区间内所有对应包含的线段树结点对应splay的root结点提取出来,然后
1,对每个splay找比他小的个数总和+1即可。 O(n log^2 n)
2.我们考虑二分,对mid找比他小的数总和然后调整即可,注意小心二分写法(蒟蒻太久没有手写过二分查找了)O(n log^3 n)
3,对包含这个位置所有的平衡树进行del后insert , O(nlog^2 n)
4.对所有区间内平衡树找前驱,取个max就好
5.同4,找所有后继,取个min就好
空间复杂度在于splay,由于线段树logn层,每层n个数,空间(nlogn)
时间 瓶颈在于操作2,时间(nlog^3n)
做法2:(考虑怎么优化操作2)
来自Oblack大佬%%%
同样是线段树套平衡树,只是将线段树维护的转化为数值(动态开点),然后对于每个线段树结点开平衡树维护这个结点的数值对应的位置。可以发现这样的话操作2就转化为了nlog^2n。
当然我们转过头来看原nlog^3n做法,我们会发现其实根本没法达到这么高的复杂度。因为对于一个区间得到的点也就是最多n个,我们均摊过来,每个平衡树内的结点并不会多,若结点多,那么对应区间需要扫的平衡树结点个数也就会相应减少。
数组不要开小了,血的教训,线段树一定空间要开两倍!!!
//by newuser
#include<stdio.h>
#include<bits/stdc++.h>
#define zig(x) zigzag(x,1)
#define zag(x) zigzag(x,2)
using namespace std;
const int maxn = 50005;
const int spmaxn = 2500005;
int n,m,MAX;
int A[maxn];
int sta[maxn],statop;
namespace spa
{
int rt[maxn*2],cnt[spmaxn],ls[spmaxn],rs[spmaxn],dat[spmaxn],siz[spmaxn],fa[spmaxn],tot;
inline void putup(int x) { siz[x] = siz[ls[x]] + siz[rs[x]] + cnt[x]; }
inline void zigzag(int x,int knd)
{
int y=fa[x],z=fa[y];
if(z)
{
if(ls[z]==y) ls[z]=x;
else rs[z]=x;
}
fa[x]=z; fa[y]=x;
if(knd==1)
{
ls[y]=rs[x];
fa[ls[y]]=y;
rs[x]=y;
}
else
{
rs[y]=ls[x];
fa[rs[y]]=y;
ls[x]=y;
}
putup(y); putup(x);
}
inline void splay(int x,int root)
{
int y,z;
while(fa[x])
{
y=fa[x]; z=fa[y];
if(z)
{
if(ls[z]==y)
{
if(ls[y]==x) { zig(y); zig(x); }
else { zag(x); zig(x); }
}
else
{
if(rs[y]==x) { zag(y); zag(x); }
else { zig(x); zag(x); }
}
}
else
{
if(ls[y]==x) zig(x);
else zag(x);
}
}
rt[root]=x;
}
inline void insert(int x,int root)
{
if(!rt[root]) { rt[root] = ++tot; cnt[tot]=siz[tot]=1; dat[tot]=x; ls[tot]=rs[tot]=fa[tot]=0; return ; }
++tot; int p = rt[root];
while(p)
{
siz[p]++;
if(x<dat[p])
{
if(!ls[p])
{
ls[p]=tot;
break;
}
p=ls[p];
}
else if(x>dat[p])
{
if(!rs[p])
{
rs[p]=tot;
break;
}
p=rs[p];
}
else { --tot; cnt[p]++; return; }
}
dat[tot]=x; cnt[tot]=siz[tot]=1;
fa[tot] = p; ls[tot]=rs[tot]=0;
splay(tot,root);
}
inline int getmax(int p)
{
while(rs[p]) p = rs[p];
return p;
}
inline void del(int x,int root)
{
splay(x,root);
--cnt[x];
siz[x]--;
if(cnt[x]) return;
int ll=ls[x]; int rr=rs[x];
fa[ll]=fa[rr]=ls[x]=rs[x]=siz[x]=0;
if(!ll) { rt[root]=rr; return; }
if(!rr) { rt[root]=ll; return; }
ll=getmax(ll);
splay(ll,root);
fa[ll]=0;
fa[rr]=ll; rs[ll]=rr;
putup(rr); putup(ll);
return;
}
inline int fi(int x,int root)
{
int p=rt[root];
while(p)
{
if(x==dat[p]) return p;
else if(x<dat[p]) p = ls[p];
else if(x>dat[p]) p = rs[p];
}
splay(p,root);
return p;
}
inline int getsmaller(int x,int root)
{
int p = rt[root],ans=0;
while(p)
{
if(dat[p]>x) p = ls[p];
else if(dat[p]<x)
{
ans+=siz[ls[p]]+cnt[p];
p = rs[p];
}
else
{
ans+=siz[ls[p]];
break;
}
}
if(p) splay(p,root);
return ans;
}
inline int getqianqu(int x,int root)
{
int p = rt[root],ans=-0x3f3f3f3f;
while(p)
{
if(dat[p]>=x)
{
p = ls[p];
}
else
{
if(dat[p]>ans) ans = dat[p];
p = rs[p];
}
}
if(p) splay(p,root);
return ans;
}
inline int gethouji(int x,int root)
{
int p = rt[root],ans=0x3f3f3f3f;
while(p)
{
if(dat[p]<=x)
{
p=rs[p];
}
else
{
if(dat[p]<ans) ans = dat[p];
p = ls[p];
}
}
if(p) splay(p,root);
return ans;
}
}
namespace seg
{
int ls[maxn*2],rs[maxn*2],dy[maxn*2],tot;
int maketree(int l,int r)
{
int p = ++tot;
if(l<r)
{
int mid = (l+r)>>1;
ls[p] = maketree(l,mid);
rs[p] = maketree(mid+1,r);
}
for(int i=l;i<=r;i++) spa::insert(A[i],p);
return p;
}
inline void getqujian(int p,int l,int r,int x,int y)
{
if(x<=l&&r<=y)
{
sta[++statop] = p;
return;
}
int mid = (l+r)>>1;
if(x>mid) getqujian(rs[p],mid+1,r,x,y);
else if(y<=mid) getqujian(ls[p],l,mid,x,y);
else getqujian(ls[p],l,mid,x,y),getqujian(rs[p],mid+1,r,x,y);
}
inline void getbaohan(int p,int l,int r,int k)
{
sta[++statop] = p;
if(l<r)
{
int mid = (l+r)>>1;
if(k<=mid) getbaohan(ls[p],l,mid,k);
else getbaohan(rs[p],mid+1,r,k);
}
return;
}
}
inline int solve1()
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
statop=0;
seg::getqujian(1,1,n,l,r);
int ans = 1;
for(int i=1;i<=statop;i++)
{
ans+=spa::getsmaller(k,sta[i]);
}
return ans;
}
inline int ggg(int mid)
{
int tmp=0;
for(int i=1;i<=statop;i++)
{
tmp+=spa::getsmaller(mid,sta[i]);
}
return tmp;
}
inline int solve2()
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
int L=0;int R=MAX; int mid;
statop=0;
seg::getqujian(1,1,n,l,r);
while(L<=R)
{
mid = (L+R)>>1;
int aha = ggg(mid);
if(aha<k) L = mid+1;
else R=mid-1;
}
return L-1;
}
inline void solve3()
{
int pos,k;
scanf("%d%d",&pos,&k);
statop=0;
seg::getbaohan(1,1,n,pos);
for(int i=1;i<=statop;i++)
{ int x = spa::fi(A[pos],sta[i]);
spa::del(x,sta[i]);
spa::insert(k,sta[i]);
}
A[pos] = k;
MAX = max(MAX,k);
}
inline int solve4()
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
statop=0;
seg::getqujian(1,1,n,l,r);
int ans = -0x3f3f3f3f;
for(int i=1;i<=statop;i++)
{
ans = max(ans,spa::getqianqu(k,sta[i]));
}
return ans;
}
inline int solve5()
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
statop=0;
seg::getqujian(1,1,n,l,r);
int ans = 0x3f3f3f3f;
for(int i=1;i<=statop;i++)
{
ans = min(ans,spa::gethouji(k,sta[i]));
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&A[i]),MAX=max(MAX,A[i]);
seg::maketree(1,n);
for(int i=1;i<=m;i++)
{
int opt;
scanf("%d",&opt);
if(opt==1) printf("%d\n",solve1());
else if(opt==2) printf("%d\n",solve2());
else if(opt==3) solve3();
else if(opt==4) printf("%d\n",solve4());
else printf("%d\n",solve5());
}
}

浙公网安备 33010602011771号