BZOJ3224/洛谷P3391 - 普通平衡树(Splay)

BZOJ链接 洛谷链接

题意简述

模板题啦~

代码

//普通平衡树(Splay)
#include <cstdio>
int const N=1e5+10;
int rt,ndCnt;
int ch[N][2],fa[N],val[N],siz[N],cnt[N];
int wh(int p) {return p==ch[fa[p]][1];}
void create(int p,int v) {ch[p][0]=ch[p][1]=fa[p]=0; val[p]=v; cnt[p]=1,siz[p]=1;}
void update(int p) {if(p) siz[p]=cnt[p]+siz[ch[p][0]]+siz[ch[p][1]]; else siz[p]=0;}
void rotate(int p)
{
    int q=fa[p],r=fa[q],w=wh(p);
    fa[ch[q][w]=ch[p][w^1]]=q;
    fa[ch[p][w^1]=q]=p;
    fa[p]=r; if(r) ch[r][q==ch[r][1]]=p;
    update(p),update(q);
}
void splay(int p)
{
    for(int q;q=fa[p];rotate(p)) if(fa[q]) rotate(wh(p)==wh(q)?q:p);  
    update(rt=p);
}
int find(int v)
{
    int p=rt,res=0;
    while(true)
        if(val[p]>v) p=ch[p][0];
        else
        {
            res+=siz[ch[p][0]];
            if(val[p]==v) {splay(p); return res+1;}
            res+=cnt[p],p=ch[p][1];
        }
}
int rank(int x)
{
    int p=rt;
    while(true)
        if(siz[ch[p][0]]>=x) p=ch[p][0];
        else
        {
            int sizL=siz[ch[p][0]]+cnt[p];
            if(sizL>=x) return val[p];
            x-=sizL,p=ch[p][1];
        }
}
int pre(int v)
{
    int p=rt,res;
    while(p)
        if(val[p]<v) res=val[p],p=ch[p][1];
        else p=ch[p][0];
    return res;
}
int nxt(int v)
{
    int p=rt,res;
    while(p)
        if(val[p]>v) res=val[p],p=ch[p][0];
        else p=ch[p][1];
    return res;
}
void ins(int v)
{  
    if(rt==0) {create(rt=++ndCnt,v); return;}  
    int p=rt,q=0;  
    while(true)
    {
         if(val[p]==v) {cnt[p]++; update(p),update(q); splay(p); break;}
         q=p; p=ch[p][val[p]<v];  
         if(p==0)
         {  
              create(++ndCnt,v);
              fa[ch[q][val[q]<v]=ndCnt]=q;
              update(q); splay(ndCnt);
              break;  
         }  
    }  
}
void del(int v)
{
    int p=rt;
    while(true)
        if(val[p]==v) break;
        else p=ch[p][val[p]<v];
    splay(p);
    if(cnt[p]>1) {cnt[p]--; return;}
    int sCnt=(ch[p][0]>0)+(ch[p][1]>0);
    if(sCnt==0) rt=0;
    if(sCnt==1) fa[rt=ch[p][ch[p][0]==0]]=0; 
    if(sCnt==2)
    {
        int q=ch[p][0];
        while(ch[q][1]) q=ch[q][1];
        splay(q);
        fa[ch[q][1]=ch[p][1]]=q;
        update(q);
    }
    ch[p][0]=ch[p][1]=0; fa[p]=-1;
}
int main()
{
    int q; scanf("%d",&q);
    while(q--)
    {
        int opt,x;
        scanf("%d%d",&opt,&x);
        //printf("\n-------- %d %d Begin.\n",opt,x);
        if(opt==1) ins(x);
        if(opt==2) del(x);
        if(opt==3) printf("%d\n",find(x));
        if(opt==4) printf("%d\n",rank(x));
        if(opt==5) printf("%d\n",pre(x));
        if(opt==6) printf("%d\n",nxt(x));
        //for(int i=1;i<=ndCnt;i++) printf("%d fa=%d son=%d,%d val=%d siz=%d cnt=%d\n",i,fa[i],ch[i][0],ch[i][1],val[i],siz[i],cnt[i]);
    }
    return 0;
}

Version2

//普通平衡树
#include <cstdio>
#include <algorithm>
using namespace std;
inline char gc()
{
    static char now[1<<16],*S,*T;
    if(S==T) {T=(S=now)+fread(now,1,1<<16,stdin); if(S==T) return EOF;}
    return *S++;
}
inline int read()
{
    int x=0,f=1; char ch=gc();
    while(ch<'0'||'9'<ch) {if(ch=='-') f=-1; ch=gc();}
    while('0'<=ch&&ch<='9') x=x*10+ch-'0',ch=gc();
    return x*f;
}
int const N=1e5+10;
int const INF=0x7FFFFFFF;
int rt,ndCnt;
int fa[N],ch[N][2],val[N],cnt[N],siz[N];
int wh(int p) {return p==ch[fa[p]][1];}
void update(int p) {siz[p]=siz[ch[p][0]]+cnt[p]+siz[ch[p][1]];}
void create(int p,int x) {fa[p]=ch[p][0]=ch[p][1]=0; val[p]=x,cnt[p]=siz[p]=1;}
void rotate(int p)
{
    int q=fa[p],r=fa[q],w=wh(p);
    fa[ch[q][w]=ch[p][w^1]]=q;
    fa[p]=r; if(r) ch[r][wh(q)]=p;
    fa[ch[p][w^1]=q]=p;
    update(p),update(q);
}
void splay(int p,int &k)
{
    int r=fa[k];
    for(int q;(q=fa[p])!=r;rotate(p)) if(fa[q]!=r) rotate(wh(p)==wh(q)?q:p);
    update(k=p);
}
void ins(int x)
{
    int p=rt,q=0;
    while(p&&val[p]!=x) q=p,p=ch[p][val[p]<x];
    if(p) {splay(p,rt),cnt[rt]++,siz[rt]++; return;}
    create(p=++ndCnt,x);
    fa[p]=q; if(q) ch[q][val[q]<x]=p,update(q);
    splay(p,rt);
}
void del(int x)
{
    int p=rt; while(val[p]!=x) p=ch[p][val[p]<x];
    splay(p,rt); if(cnt[p]>1) {cnt[rt]--,siz[rt]--; return;}
    if(ch[p][0]*ch[p][1]==0) {fa[rt=ch[p][ch[p][1]>0]]=0; return;}
    int q=ch[p][0]; while(ch[q][1]) q=ch[q][1];
    splay(q,rt); fa[ch[rt][1]=ch[p][1]]=rt;
}
int find(int x)
{
    int p,res=0;
    for(p=rt;val[p]!=x;p=ch[p][val[p]<x])
        if(val[p]<x) res+=siz[ch[p][0]]+cnt[p];
    return res+siz[ch[p][0]]+1;
}
int rank(int x)
{
    int p=rt,res=0;
    while(true)
    {
        int sizL=siz[ch[p][0]]+cnt[p];
        if(x<=siz[ch[p][0]]) p=ch[p][0];
        else if(x<=sizL) return val[p];
        else x-=sizL,p=ch[p][1];
    }
}
int pre(int x)
{
    int res=-INF;
    for(int p=rt;p;p=ch[p][val[p]<x])
        if(val[p]<x) res=max(res,val[p]);
    return res;
}
int nxt(int x)
{
    int res=INF;
    for(int p=rt;p;p=ch[p][val[p]<=x])
        if(val[p]>x) res=min(res,val[p]);
    return res; 
}
int main()
{
    int T=read();
    while(T--)
    {
        int opt=read(),x=read();
        if(opt==1) ins(x);
        else if(opt==2) del(x);
        else if(opt==3) printf("%d\n",find(x));
        else if(opt==4) printf("%d\n",rank(x));
        else if(opt==5) printf("%d\n",pre(x));
        else if(opt==6) printf("%d\n",nxt(x));
    }
    return 0;
}

posted @ 2017-11-30 17:13  VisJiao  阅读(168)  评论(0编辑  收藏  举报