BZOJ 3224 平衡树模板题

Treap:

//By SiriusRen
#include <cstdio>
#include <algorithm>
using namespace std;
int n,op,xx,ans,size,root;
struct Treap{int ch[2],v,cnt,rnd,sz;}tr[300000];
void Upd(int k){tr[k].sz=tr[k].cnt+tr[tr[k].ch[0]].sz+tr[tr[k].ch[1]].sz;}
void rot(int &k,bool f){int t=tr[k].ch[f];tr[k].ch[f]=tr[t].ch[!f],tr[t].ch[!f]=k,Upd(k),Upd(t),k=t;}
void ins(int &k){
    if(!k){k=++size;tr[k].cnt=tr[k].sz=1;tr[k].rnd=rand(),tr[k].v=xx;return;}
    tr[k].sz++;
    if(tr[k].v==xx){tr[k].cnt++;return;}
    bool f=xx>tr[k].v;ins(tr[k].ch[f]);
    if(tr[tr[k].ch[f]].rnd<tr[k].rnd)rot(k,f);
}
void del(int &k){
    if(tr[k].v==xx){
        if(tr[k].cnt>1)tr[k].sz--,tr[k].cnt--;
        else if(!(tr[k].ch[0]*tr[k].ch[1]))k=max(tr[k].ch[0],tr[k].ch[1]);
        else rot(k,tr[tr[k].ch[0]].rnd>tr[tr[k].ch[1]].rnd),del(k);
    }
    else tr[k].sz--,del(tr[k].ch[xx>tr[k].v]);
}
int get_rank(int k){
    if(tr[k].v==xx)return tr[tr[k].ch[0]].sz+1;
    else if(tr[k].v>xx)return get_rank(tr[k].ch[0]);
    else return get_rank(tr[k].ch[1])+tr[tr[k].ch[0]].sz+tr[k].cnt;
}
int get_kth(int k,int x){
    if(tr[tr[k].ch[0]].sz>=x)return get_kth(tr[k].ch[0],x);
    else if(tr[tr[k].ch[0]].sz+tr[k].cnt<x)return get_kth(tr[k].ch[1],x-tr[tr[k].ch[0]].sz-tr[k].cnt);
    else return tr[k].v;
}
void get(int k){
    if(!k)return;
    if(op==5&&tr[k].v<xx)ans=tr[k].v,get(tr[k].ch[1]);
    else if(op==5&&tr[k].v>=xx)get(tr[k].ch[0]);
    else if(op==6&&tr[k].v>xx)ans=tr[k].v,get(tr[k].ch[0]);
    else get(tr[k].ch[1]);
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&op,&xx);
        if(op==1)ins(root);
        else if(op==2)del(root);
        else if(op==3)printf("%d\n",get_rank(root));
        else if(op==4)printf("%d\n",get_kth(root,xx));
        else get(root),printf("%d\n",ans);
    }
}

Splay:

//By SiriusRen
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
int op,xx,root,n,size;
struct Splay{int ch[2],fa,v,sz,cnt;}tr[300500];
void Upd(int x){tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt;}
void rot(int x){
    int y=tr[x].fa,z=tr[y].fa;
    bool f=(tr[y].ch[1]==x);
    tr[y].ch[f]=tr[x].ch[!f];
    if(tr[y].ch[f])tr[tr[y].ch[f]].fa=y;
    tr[x].ch[!f]=y;tr[y].fa=x;tr[x].fa=z;
    if(z)tr[z].ch[tr[z].ch[1]==y]=x;
    Upd(y);
}
void splay(int x,int tp){
    for(int y,z;(y=tr[x].fa)!=tp;rot(x)){
        z=tr[y].fa;
        if(z==tp)continue;
        if((tr[y].ch[0]==x)==(tr[z].ch[0]==y))rot(y);
        else rot(x);
    }
    if(!tp)root=x;
    Upd(x);
}
void insert(int x,int num){
    int y=0;
    while(x&&tr[x].v!=num)y=x,x=tr[x].ch[num>tr[x].v];
    if(x)++tr[x].cnt;
    else{
        x=++size,tr[x].sz=tr[x].cnt=1,tr[x].fa=y,tr[x].v=num;
        if(y)tr[y].ch[num>tr[y].v]=x;
    }
    splay(x,0);
}
void find(int v){
    int x=root;
    while(tr[x].ch[v>tr[x].v]&&tr[x].v!=v)x=tr[x].ch[v>tr[x].v];
    splay(x,0);
}
int next(int v,bool f){
    find(v);
    if((tr[root].v>v&&f)||(tr[root].v<v&&!f))return root;
    int p=tr[root].ch[f];
    while(tr[p].ch[!f])p=tr[p].ch[!f];
    return p;
}
void del(int v){
    int p=next(v,0),s=next(v,1);
    splay(p,0),splay(s,p);
    p=tr[s].ch[0];
    if(tr[p].cnt>1)tr[p].cnt--,splay(p,0);
    else tr[s].ch[0]=0;
}
int kth(int x){
    int y=root,p;
    if(x>tr[root].sz)return 0;
    while(1){
        p=tr[y].ch[0];
        if(tr[p].sz+tr[y].cnt<x)x=x-tr[p].sz-tr[y].cnt,y=tr[y].ch[1];
        else if(tr[p].sz>=x)y=p;
        else return tr[y].v;
    }
}
int main(){
    insert(root,0x3fffffff),insert(root,-0x3fffffff);
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&op,&xx);
        if(op==1)insert(root,xx);
        else if(op==2)del(xx);
        else if(op==3)find(xx),printf("%d\n",tr[tr[root].ch[0]].sz);
        else if(op==4)printf("%d\n",kth(xx+1));
        else printf("%d\n",tr[next(xx,op==6)].v);
    }
}

posted @ 2016-12-04 21:36  SiriusRen  阅读(220)  评论(0编辑  收藏  举报