splay

其实一直都不是很会splay

#include<bits/stdc++.h>
#define F(i0,i1,i2) for(int i0=i1;i0<=i2;++i0)
#define N 100005
using namespace std;
inline int rd(){
    int x=0,f=0;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=1;ch=getchar(); }
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-48;ch=getchar();}
    return f?-x:x;
}
struct Splay{int f,c[2],v,ct,num; }t[N];
int tot,rot;
inline void push_up(int p){t[p].ct=t[t[p].c[0]].ct+t[t[p].c[1]].ct+t[p].num;}
inline bool get(int x){return t[t[x].f].c[1]==x;}
inline void rotate(int x){
    int y=t[x].f,z=t[y].f,k=get(x);//x 是不是右节点 
    t[y].c[k]=t[x].c[k^1];
    if(t[y].c[k])t[t[y].c[k]].f=y;
    t[x].c[k^1]=y; 
    if(z)t[z].c[get(y)]=x;
    t[y].f=x;
    t[x].f=z;
    push_up(x);
    push_up(y);
}
inline void splay(int x,int goal){
    while(t[x].f!=goal){
        int y=t[x].f;
        if(t[y].f!=goal)
            rotate(get(x)==get(y)?y:x);
        rotate(x);
    }
    if(!goal)rot=x;
}
void ins(int val){
    int u=rot,f=0;
    while(u&&t[u].v!=val)
        f=u,u=t[u].c[t[u].v<val];
    if(u)t[u].num++,t[u].ct++;
    else {
        t[u=++tot]={f,{0,0},val,1,1};
        if(f)t[f].c[t[f].v<val]=u;
    }
    splay(u,0);
}
void fnd(int val){
    int x=rot;
    if(!x)return;
    while(t[x].c[t[x].v<val]&&t[x].v!=val)x=t[x].c[t[x].v<val];
    splay(x,0);
}
int kth(int k){
    if(k>t[rot].ct) return 0;
    int x=rot;
    while(1){
        if(t[t[x].c[0]].ct>=k)x=t[x].c[0];
        else {
            int num=t[t[x].c[0]].ct+t[x].num;
            if(k<=num)return t[x].v;
            else k-=num,x=t[x].c[1];
        }
    }
}
int nxt(int val,int op){ //0 前驱 1 后继 
    fnd(val);
    if((!op&&t[rot].v<val)||(op&&t[rot].v>val))return rot;
    int x=t[rot].c[op];
    while(t[x].c[op^1])x=t[x].c[op^1];
    return x;
}
void del(int val){
    int pre=nxt(val,0),suc=nxt(val,1);
    splay(pre,0);
    splay(suc,pre);
    int x=t[suc].c[0];
    if(t[x].num>1){
        t[x].num--,t[x].ct--;
        splay(x,0);
    }
    else {
        t[suc].c[0]=0;
        push_up(pre);
        push_up(suc);
    }

}
signed main(){
    int n=rd();
    ins(INT_MAX);
    ins(INT_MIN);
    while(n--){
        int op=rd(),x=rd();
        if(op==1)ins(x);
        else if(op==2)del(x);
        else if(op==3){
            fnd(x);
            cout<<t[t[rot].c[0]].ct<<endl;
        }
        else if(op==4)cout<<kth(x+1)<<endl;
        else if(op==5)cout<<t[nxt(x,0)].v<<endl;
        else cout<<t[nxt(x,1)].v<<endl;
    }
    return 0;
}
posted @ 2023-09-09 17:09  ussumer  阅读(39)  评论(0)    收藏  举报