【模板】二逼平衡树(树套树)
Description
维护一个可以支持查询区间某值的排名,查询区间某排名的值,修改某值,查询某值在区间内的前驱和后继的数据结构。
区间长度为 \(n\) , \(m\) 次询问。
\(1\leq n\ m\leq5\cdot 10^4\ \ \ \ 0\leq a[i]\leq 10^8\)
Solution
刚学了平衡树,发现硬上平衡树仿佛不行。(全是区间内呀)
正好老师也大概讲了一下树套树是什么玩意。
看到大部分操作其实都都是平衡树的基本操作,只不过都在区间内,然后就很自然地想到了线段树。
于是我们就可以考虑线段树套平衡树(虽说最优的还是树状数组套值域线段树)
直接在线段树每个节点上建一颗平衡树。
那么对于这五个操作,除了第三个,其他的直接在线段树上找到合法区间硬上平衡树就行。
第三个操作的话从根节点开始找,碰上一个节点更新哪里平衡树对应的值(删掉 \(a[pos]\) ,加入 \(k\) ),直到找到最后 \(l=r=pos\) 时,更新后,把 \(a[pos]\) 改成 \(k\) 就行了。
#include<bits/stdc++.h>
#define ls(i) spl[i].ch[0]
#define rs(i) spl[i].ch[1]
#define reg register
using namespace std;
typedef long long ll;
const int N=5e4+10;
const int INF=2147483647;
int n,m,a[N],tot;
struct Splay{int ch[2],fa,cnt,val,siz;}spl[N<<6];
struct Seg_tree{int lt,rt,root;}seg[N<<2];
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
//splay start
inline void pushup(int now){
spl[now].siz=spl[now].cnt+spl[ls(now)].siz+spl[rs(now)].siz;
}
inline void rotate(int now){
int nxt=spl[now].fa;
int nnt=spl[nxt].fa;
int k1=rs(nxt)==now;
int k2=rs(nnt)==nxt;
int pre=spl[now].ch[k1^1];
spl[nnt].ch[k2]=now; spl[now].fa=nnt;
spl[nxt].ch[k1]=pre; spl[pre].fa=nxt;
spl[now].ch[k1^1]=nxt; spl[nxt].fa=now;
pushup(nxt);pushup(now);
}//虽然肯定会跑得慢,但真tm好看
inline void splay(int now,int S,int it){
while(spl[now].fa!=S){
int nxt=spl[now].fa;
int nnt=spl[nxt].fa;
int k1=rs(nxt)==now;
int k2=rs(nnt)==nxt;
if(nnt!=S)(k1^k2)?rotate(now):rotate(nxt);
rotate(now);
}
if(!S)seg[it].root=now;
}
inline void insert(int now,int it){
int u=seg[it].root,fth=0;
while(u&&spl[u].val!=now){
fth=u;
u=spl[u].ch[now>spl[u].val];
}
if(u)++spl[u].cnt;
else {
u=++tot;ls(u)=rs(u)=0;
if(fth)spl[fth].ch[now>spl[fth].val]=u;
spl[u].cnt=spl[u].siz=1;
spl[u].fa=fth;spl[u].val=now;
}
splay(u,0,it);
}
inline void find(int now,int it){
int u=seg[it].root;
if(!u)return ;
while(spl[u].ch[now>spl[u].val]&&now!=spl[u].val){
u=spl[u].ch[now>spl[u].val];
}
splay(u,0,it);
}
inline int pre_nxt(int now,int k,int it){
find(now,it);
int u=seg[it].root;
if(spl[u].val<now&&(!k))return u;
if(spl[u].val>now&&k)return u;
u=spl[u].ch[k];
while(spl[u].ch[k^1])u=spl[u].ch[k^1];
return u;
}
inline void delet(int now,int it){
int pre=pre_nxt(now,0,it);
int nxt=pre_nxt(now,1,it);
splay(pre,0,it);splay(nxt,pre,it);
int del=ls(nxt);
if(spl[del].cnt>1){
--spl[del].cnt;
splay(del,0,it);
}
else {
ls(nxt)=0;//del
}
pushup(pre);
}
//splay end
/*inline int kth_find(int now,int it){
int u=seg[it].root;
if(spl[u].siz<now)return 0;
while(1){
if(now>spl[ls(u)].siz+spl[u].cnt){
now-=spl[ls(u)].siz+spl[u].cnt;
u=rs(u);
}
else if(now<=spl[ls(u)].siz)u=ls(u);
else return spl[u].val;
}
}*/
//seg_tree start
inline void sbuild(int lt,int rt,int it){
insert(-INF,it);insert(INF,it);
if(lt==rt)return ;
int mid=(lt+rt)>>1;
sbuild(lt,mid,it<<1);
sbuild(mid+1,rt,it<<1|1);
}
inline void sinsert(int lt,int rt,int pos,int val,int it){
insert(val,it);
if(lt==rt)return ;
int mid=(lt+rt)>>1;
if(mid>=pos)sinsert(lt,mid,pos,val,it<<1);
else sinsert(mid+1,rt,pos,val,it<<1|1);
}
inline int sfind(int lt,int rt,int LT,int RT,int pos,int it){
if(RT<lt||rt<LT)return 0;
if(LT<=lt&&rt<=RT){
find(pos,it);
int u=seg[it].root;
if(spl[u].val>=pos)return spl[ls(u)].siz-1;
else return spl[ls(u)].siz-1+spl[u].cnt;
}
int mid=(lt+rt)>>1;
int ans1=sfind(lt,mid,LT,RT,pos,it<<1);
int ans2=sfind(mid+1,rt,LT,RT,pos,it<<1|1);
return ans1+ans2;
}
inline int skth_find(int lt,int rt,int LT,int RT,int val){
int mid,now,ans;
while(lt<=rt){
mid=(lt+rt)>>1;
now=sfind(1,n,LT,RT,mid,1)+1;
if(now>val)rt=mid-1;
else lt=mid+1,ans=mid;
}
return ans;
}
inline void smodify(int lt,int rt,int pos,int val,int it){
delet(a[pos],it);insert(val,it);
if(lt==rt&&rt==pos){
a[pos]=val;
return ;
}
int mid=(lt+rt)>>1;
if(mid>=pos)smodify(lt,mid,pos,val,it<<1);
else smodify(mid+1,rt,pos,val,it<<1|1);
}
inline int spre_nxt(int sig,int lt,int rt,int LT,int RT,int pos,int it){
if(RT<lt||rt<LT){
if(!sig)return -INF;
return INF;
}
if(LT<=lt&&rt<=RT){
int u=pre_nxt(pos,sig,it);
return spl[u].val;
}
int mid=(lt+rt)>>1;
int ans1=spre_nxt(sig,lt,mid,LT,RT,pos,it<<1);
int ans2=spre_nxt(sig,mid+1,rt,LT,RT,pos,it<<1|1);
if(!sig)return max(ans1,ans2);
return min(ans1,ans2);
}
//seg_tree end
int main(){
n=read();m=read();
sbuild(1,n,1);
for(int i=1;i<=n;++i){
a[i]=read();
sinsert(1,n,i,a[i],1);
}
while(m--){
int opt=read(),l=read(),r=read(),k;
if(opt==1){
k=read();
printf("%d\n",sfind(1,n,l,r,k,1)+1);
}
else if(opt==2){
k=read();
printf("%d\n",skth_find(0,1e8,l,r,k));
}
else if(opt==3){
smodify(1,n,l,r,1);
}
else if(opt==4){
k=read();
printf("%d\n",spre_nxt(0,1,n,l,r,k,1));
}
else if(opt==5){
k=read();
printf("%d\n",spre_nxt(1,1,n,l,r,k,1));
}
}
return 0;
}
(这玩意考场上打得出来??

浙公网安备 33010602011771号