bzoj 3224(平衡树)

传送门

题意:

请你完成66种最基础的平衡树的操作。

题目分析:

什么也别说了,平衡树的板子题。收获到了一个写起来比较舒服的SplaySplay的板子。

代码:

#include <bits/stdc++.h>
using namespace std;
const int maxn=200005;

int ch[maxn][2],par[maxn],val[maxn],cnt[maxn],size[maxn],ncnt,root;
bool chk(int x){
    return ch[par[x]][1]==x;
}
void push_up(int x){
    size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}
void rotate(int x){
    int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
    ch[y][k]=w; par[w]=y;
    ch[z][chk(y)]=x;par[x]=z;
    ch[x][k^1]=y; par[y]=x;
    push_up(y);
    push_up(x);
}

void splay(int x,int goal=0){
    while(par[x]!=goal){
        int y=par[x],z=par[y];
        if(z!=goal){
            if(chk(x)==chk(y)) rotate(y);
            else rotate(x);

        }
        rotate(x);
    }
    if(!goal) root=x;
}

void insert(int x){
    int cur=root,p=0;
    while(cur&&val[cur]!=x){
        p=cur;
        cur=ch[cur][x>val[cur]];
    }
    if(cur){
        cnt[cur]++;
    }
    else{
        cur=++ncnt;
        if(p) ch[p][x>val[p]]=cur;
        ch[cur][0]=ch[cur][1]=0;
        par[cur]=p;val[cur]=x;
        cnt[cur]=size[cur]=1;
    }
    splay(cur);
}


void find(int x){
    int cur=root;
    while(ch[cur][x>val[cur]]&&x!=val[cur]){
        cur=ch[cur][x>val[cur]];
    }
    splay(cur);
}

int kth(int k){
    int cur=root;
    while(1){
        if(ch[cur][0]&&k<=size[ch[cur][0]]){
            cur=ch[cur][0];
        }else if(k>size[ch[cur][0]]+cnt[cur]){
            k-=size[ch[cur][0]]+cnt[cur];
            cur=ch[cur][1];
        }else return cur;
    }
}
int pre(int x){
    find(x);
    if(val[root]<x) return root;
    int cur=ch[root][0];
    while(ch[cur][1]) cur=ch[cur][1];
    return cur;
}

int succ(int x){
    find(x);
    if(val[root]>x) return root;
    int cur=ch[root][1];
    while(ch[cur][0]) cur=ch[cur][0];
    return cur;
}

void remove(int x){
    int last=pre(x),next=succ(x);
    splay(last);
    splay(next,last);
    int del=ch[next][0];
    if(cnt[del]>1){
        cnt[del]--;
        splay(del);
    }
    else ch[next][0]=0,push_up(next),push_up(root);
}
int n,op,x;

int main()
{
    scanf("%d",&n);
    insert(0x3f3f3f3f);
    insert(0xcfcfcfcf);
    while(n--){
        scanf("%d%d",&op,&x);
        if(op==1) insert(x);
        if(op==2) remove(x);
        if(op==3){
            find(x);
            printf("%d\n",size[ch[root][0]]);
        }
        if(op==4){
            printf("%d\n",val[kth(x+1)]);
        }
        if(op==5){
            printf("%d\n",val[pre(x)]);
        }
        if(op==6){
            printf("%d\n",val[succ(x)]);
        }
    }
}

posted @ 2019-03-21 22:37  ChenJr  阅读(145)  评论(0编辑  收藏  举报