二逼平衡树 题解(树套树)
我 想 扇 死 自 己
void up(int x) { if(x) { size[x]=cnt[x];//我TM这行忘了 if(son[x][0])size[x]+=size[son[x][0]]; if(son[x][1])size[x]+=size[son[x][1]]; } }
4个小时!调一道模板!我敲里码!
上道splay刚因为细节打错浪费了3个小时时间,这次就又**重现了
不多说了,先把splay抄上10遍,手写!
-----------以下是正经题解----------------
第一道树套树:线段树套splay
对于线段树的每一段区间建splay维护这段的信息
在合并时:
排名相加;
前驱取max;
后继取min;
比较麻烦的是查询数值,需要二分答案.
以数值为值域进行二分,不断询问mid的排名来缩小范围。
#include<cstdio> #include<iostream> #include<algorithm> #include<cstring> using namespace std; const int N=4000005,inf=1e9; int n,m,a[N]; int root[N],son[N][3],fa[N],key[N],size[N],type,cnt[N]; void clear(int x) { if(!x)return ; fa[x]=cnt[x]=son[x][0]=son[x][1]=size[x]=key[x]=0; } int pre(int k) { int now=son[root[k]][0]; while(son[now][1])now=son[now][1]; return now; } bool judge(int x) { return son[fa[x]][1]==x; } void up(int x) { if(x) { size[x]=cnt[x]; if(son[x][0])size[x]+=size[son[x][0]]; if(son[x][1])size[x]+=size[son[x][1]]; } } void rotate(int x) { int old=fa[x],oldf=fa[old],lr=judge(x); son[old][lr]=son[x][lr^1]; fa[son[old][lr]]=old; son[x][lr^1]=old; fa[old]=x; fa[x]=oldf; if(oldf)son[oldf][son[oldf][1]==old]=x; up(old);up(x); } void splay(int k,int x) { for(int f;f=fa[x];rotate(x)) if(fa[f])rotate(judge(x)==judge(f)?f:x); root[k]=x; } void ins(int k,int x) { if(!root[k]) { type++; key[type]=x; root[k]=type; cnt[type]=size[type]=1; fa[type]=son[type][0]=son[type][1]=0; return ; } int now=root[k],f=0; while(1) { if(x==key[now]) { cnt[now]++; up(now); up(f); splay(k,now); return ; } f=now;now=son[now][key[now]<x]; if(!now) { type++; size[type]=cnt[type]=1; son[type][0]=son[type][1]=0; son[f][x>key[f]]=type; fa[type]=f; key[type]=x; up(f);splay(k,type); return ; } } } int getrank(int k,int x) { int now=root[k],ans=0; while(1) { if(!now)return ans; if(x==key[now])return (son[now][0]?size[son[now][0]]:0)+ans; else if(x>key[now]) { ans+=(son[now][0]?size[son[now][0]]:0)+cnt[now]; now=son[now][1]; } else if(x<key[now])now=son[now][0]; } } int findpos(int k,int x) { int now=root[k]; while(1) { if(x==key[now])return now; else if(x<key[now])now=son[now][0]; else now=son[now][1]; } } int findpre(int k,int x) { int now=root[k],ans=0; while(now) { if(key[now]<x) { if(ans<key[now])ans=key[now]; now=son[now][1]; } else now=son[now][0]; } return ans; } int findnxt(int k,int x) { int now=root[k],ans=inf; while(now) { if(key[now]>x) { if(ans>key[now])ans=key[now]; now=son[now][0]; } else now=son[now][1]; } return ans; } void del(int k,int x) { int now=findpos(k,x); splay(k,now); if(cnt[root[k]]>1) { cnt[root[k]]--; up(root[k]); return ; } else if(!son[root[k]][0]&&(!son[root[k]][1])) { clear(root[k]); root[k]=0; return ; } int old=root[k]; if(son[root[k]][0]*son[root[k]][1]==0) { if(!son[root[k]][0])root[k]=son[root[k]][1]; else root[k]=son[root[k]][0]; fa[root[k]]=0; clear(old); return ; } int L=pre(k); splay(k,L); son[root[k]][1]=son[old][1]; fa[son[old][1]]=root[k]; clear(old); up(root[k]); } #define ls(k) k<<1 #define rs(k) k<<1|1 void update(int k,int l,int r,int pos,int val) { ins(k,val); if(l==r)return ; int mid=l+r>>1; if(pos<=mid)update(ls(k),l,mid,pos,val); else update(rs(k),mid+1,r,pos,val); return ; } int rank(int k,int l,int r,int L,int R,int val) { if(l>=L&&r<=R) { int res=getrank(k,val); return res; } int mid=l+r>>1,res=0; if(L<=mid)res+=rank(ls(k),l,mid,L,R,val); if(R>mid)res+=rank(rs(k),mid+1,r,L,R,val); return res; } void modify(int k,int l,int r,int pos,int val) { del(k,a[pos]); ins(k,val); if(l==r)return ; int mid=l+r>>1; if(pos<=mid)modify(ls(k),l,mid,pos,val); else modify(rs(k),mid+1,r,pos,val); } int getpre(int k,int l,int r,int L,int R,int val) { if(l>=L&&r<=R)return findpre(k,val); int mid=l+r>>1,res=0; if(L<=mid)res=max(res,getpre(ls(k),l,mid,L,R,val)); if(R>mid)res=max(res,getpre(rs(k),mid+1,r,L,R,val)); return res; } int getnxt(int k,int l,int r,int L,int R,int val) { if(l>=L&&r<=R)return findnxt(k,val); int mid=l+r>>1,res=inf; if(L<=mid)res=min(res,getnxt(ls(k),l,mid,L,R,val)); if(R>mid)res=min(res,getnxt(rs(k),mid+1,r,L,R,val)); return res; } inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9') {if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return x*f; } int main() { n=read();m=read(); int op,maxx=0; for(int i=1;i<=n;i++) { a[i]=read(); update(1,1,n,i,a[i]); maxx=max(maxx,a[i]); } while(m--) { op=read(); if(op==1) { int l=read(),r=read(),val=read(); printf("%d\n",rank(1,1,n,l,r,val)+1); } else if(op==2) { int l=read(),r=read(),val=read(); int L=0,R=maxx+1; while(L!=R) { int mid=L+R>>1; int res=rank(1,1,n,l,r,mid); //cout<<"***"<<res<<endl; if(res<val)L=mid+1; else R=mid; } printf("%d\n",L-1); } else if(op==3) { int pos=read(),val=read();modify(1,1,n,pos,val); a[pos]=val; maxx=max(maxx,val); } else if(op==4) { int l=read(),r=read(),val=read(); printf("%d\n",getpre(1,1,n,l,r,val)); } else if(op==5) { int l=read(),r=read(),val=read(); printf("%d\n",getnxt(1,1,n,l,r,val)); } } return 0; }
好了。
从上面那段简短而狗屁不通的“题解”和几乎是抄来的代码可以看出来,是什么让当时的我那么垃圾。
不求甚解、生搬硬套、懒于思考、依赖题解。
装模作样打个Splay,考场上没板子真的写得出来?
如果像本题一样,把普通平衡树的操作放到区间上,显然是无法只用平衡树维护的。解决区间问题最有力的武器就是线段树,所以考虑线段树套平衡树解决。
对每个线段树区间建一棵平衡树。建树时直接把所有区间都插入该区间的所有元素,单点修改时把沿路的所有线段树区间上的平衡树都进行改动(删除再插入)。
对于剩下的查询操作,求排名显然可以转化为所有区间小于该数的元素个数之和+1,即$( \sum (每个区间求排名结果-1)) +1$,前驱应当是所有区间结果的最大值,同理后继就是最小值。
但用相同的方式求K大是不太可行的,考虑牺牲一下时间复杂度进行二分答案,每次二分出一个数check它的排名即可。这样的话是3个$log$。
平衡树使用的是替罪羊树,一是确实好写且容易封装,二是动态开点删点可以避免内存超限。这样就可以直接粗暴地扔到结构体里而不用像Splay一样使用$root[]$数组了。
上面瞎写的东西我没有删。给自己和大家一个警示以及反面典型。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return x*f;
}
const int N=1e5+5,inf=2147483647;
const double al=0.7;
int n,m,a[N];
struct Scapegoat
{
struct node
{
node *l,*r;
int val,size,cnt;
bool del;
bool bad()
{
return l->cnt>al*cnt+5||r->cnt>al*cnt+5;
}
void up()
{
size=!del+l->size+r->size;
cnt=1+l->cnt+r->cnt;
}
};
node *null,**badtag;
void dfs(node *k,vector<node*> &v)
{
if(k==null)return ;
dfs(k->l,v);
if(!k->del)v.push_back(k);
dfs(k->r,v);
if(k->del)delete k;
}
node *build(vector<node*> &v,int l,int r)
{
if(l>=r)return null;
int mid=l+r>>1;
node *k=v[mid];
k->l=build(v,l,mid);
k->r=build(v,mid+1,r);
k->up();
return k;
}
void rebuild(node* &k)
{
vector<node*> v;
dfs(k,v);
k=build(v,0,v.size());
}
void insert(int x,node* &k)
{
if(k==null)
{
k=new node;
k->l=k->r=null;
k->del=0;
k->size=k->cnt=1;
k->val=x;
return ;
}
++k->size;++k->cnt;
if(x>=k->val)insert(x,k->r);
else insert(x,k->l);
if(k->bad())badtag=&k;
else if(badtag!=&null)
k->cnt-=(*badtag)->cnt-(*badtag)->size;
}
void ins(int x,node* &k)
{
badtag=&null;
insert(x,k);
if(badtag!=&null)rebuild(*badtag);
}
int getrk(node *now,int x)
{
int ans=1;
while(now!=null)
{
if(now->val>=x)now=now->l;
else
{
ans+=now->l->size+!now->del;
now=now->r;
}
}
return ans;
}
int kth(node *now,int x)
{
while(now!=null)
{
if(!now->del&&now->l->size+1==x)
return now->val;
if(now->l->size>=x)now=now->l;
else
{
x-=now->l->size+!now->del;
now=now->r;
}
}
return -1;
}
void erase(node *k,int rk)
{
if(!k->del&&rk==k->l->size+1)
{
k->del=1;
--k->size;
return ;
}
--k->size;
if(rk<=k->l->size+!k->del)erase(k->l,rk);
else erase(k->r,rk-k->l->size-!k->del);
}
node* root;
Scapegoat()
{
null=new node;
root=null;
}
}s[N<<3];
#define ls(k) (k)<<1
#define rs(k) (k)<<1|1
void build(int k,int l,int r)
{
for(int i=l;i<=r;i++)
s[k].ins(a[i],s[k].root);
if(l==r)return ;
int mid=l+r>>1;
build(ls(k),l,mid);
build(rs(k),mid+1,r);
}
int askrk(int k,int l,int r,int L,int R,int val)
{
if(L<=l&&R>=r)return s[k].getrk(s[k].root,val)-1;
int mid=l+r>>1,res=0;
if(L<=mid)res+=askrk(ls(k),l,mid,L,R,val);
if(R>mid)res+=askrk(rs(k),mid+1,r,L,R,val);
return res;
}
void update(int k,int l,int r,int pos,int val)
{
s[k].erase(s[k].root,s[k].getrk(s[k].root,a[pos]));
s[k].ins(val,s[k].root);
if(l==r)return ;
int mid=l+r>>1;
if(pos<=mid)update(ls(k),l,mid,pos,val);
else update(rs(k),mid+1,r,pos,val);
}
int askpre(int k,int l,int r,int L,int R,int val)
{
if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val)-1);
int res=-inf,mid=l+r>>1;
if(L<=mid)
{
int ret=askpre(ls(k),l,mid,L,R,val);
if(ret==-1)res=max(res,-inf);
else res=max(res,ret);
}
if(R>mid)
{
int ret=askpre(rs(k),mid+1,r,L,R,val);
if(ret==-1)res=max(res,-inf);
else res=max(res,ret);
}
return res;
}
int asknxt(int k,int l,int r,int L,int R,int val)
{
if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val+1));
int res=inf,mid=l+r>>1;
if(L<=mid)
{
int ret=asknxt(ls(k),l,mid,L,R,val);
if(ret==-1)res=min(res,inf);
else res=min(res,ret);
}
if(R>mid)
{
int ret=asknxt(rs(k),mid+1,r,L,R,val);
if(ret==-1)res=min(res,inf);
else res=min(res,ret);
}
return res;
}
int askth(int L,int R,int val)
{
int l=0,r=1e8,res;
while(l<=r)
{
int mid=l+r>>1;
if(askrk(1,1,n,L,R,mid)+1<=val)res=mid,l=mid+1;
else r=mid-1;
}
return res;
}
int main()
{
n=read();m=read();
for(int i=1;i<=n;i++)
a[i]=read();
build(1,1,n);
while(m--)
{
int op=read();
if(op==1){int l=read(),r=read(),K=read();printf("%d\n",askrk(1,1,n,l,r,K)+1);}
if(op==2){int l=read(),r=read(),K=read();printf("%d\n",askth(l,r,K));}
if(op==3){int pos=read(),K=read();update(1,1,n,pos,K);a[pos]=K;}
if(op==4){int l=read(),r=read(),K=read();printf("%d\n",askpre(1,1,n,l,r,K));}
if(op==5){int l=read(),r=read(),K=read();printf("%d\n",asknxt(1,1,n,l,r,K));}
}
return 0;
}
兴许青竹早凋,碧梧已僵,人事本难防。

浙公网安备 33010602011771号