平衡树模板——splay

/*
在splay中
0不能算作是根节点,只能说是一个标记点
如果谁的父亲是0,那么谁就是根节点
*/

#include <bits/stdc++.h>
using namespace std;
const int M=1e5+5;
const int inf=1e9;

#define t tr
#define size siz

int cnt=0,root=0;

struct splay {
    int ch[2],siz,cnt,val,fa;
}tr[M];

int get(int x) {
    return tr[tr[x].fa].ch[1]==x;
}

void up(int x) {
    tr[x].siz=tr[tr[x].ch[0]].siz+tr[tr[x].ch[1]].siz+tr[x].cnt;
}

void rotate(int x) {
    int y=tr[x].fa,z=tr[y].fa;
    int d1=get(x),d2=get(y);
    int son=tr[x].ch[d1^1];
    tr[y].ch[d1]=son;tr[son].fa=y;
    tr[z].ch[d2]=x;tr[x].fa=z;
    tr[x].ch[d1^1]=y;tr[y].fa=x;
    up(y);up(x);
}

void splay(int x,int goal) {
    while(tr[x].fa!=goal) {
        int y=tr[x].fa,z=tr[y].fa;
        int d1=get(x),d2=get(y);
        if(z!=goal) {
            if(d1==d2)rotate(y);
            else rotate(x);
        }
        rotate(x);
    }
    if(goal==0)root=x;
}

int find(int val) {
    int node=root;
    while(tr[node].val!=val&&tr[node].ch[tr[node].val<val])node=tr[node].ch[tr[node].val<val];
    return node;
}


void insert(int val) {
    int node=root,fa=0;
    while(tr[node].val!=val&&node)
        fa=node,node=tr[node].ch[tr[node].val<val];
    if(node)tr[node].cnt++;
    else {
        node=++cnt;
        if(fa)tr[fa].ch[tr[fa].val<val]=node;
        tr[node].siz=tr[node].cnt=1;
        tr[node].fa=fa,tr[node].val=val;
    }
    splay(node,0);
}

int pre(int val,int k) {
    splay(find(val),0);
    int node=root;
    if(k==0&&tr[node].val<val)return node;
    if(k==1&&tr[node].val>val)return node;
    node = tr[node].ch[k];
    while(tr[node].ch[k^1])node=tr[node].ch[k^1];
    return node;
}

void del(int val){
    int last = pre(val,0), next = pre(val,1);
    splay(last , 0); splay(next , last);
    if(t[t[next].ch[0]].cnt > 1){
	    t[t[next].ch[0]].cnt--;
	    splay(t[next].ch[0] , 0);
    }
    else t[next].ch[0] = 0;
}

int kth(int k){
    int node = root;
    if(t[node].size < k) return inf;
    while(1){
	    int son = t[node].ch[0];
	    if(k <= t[son].size) node = son;
	    else if(k > t[son].size+t[node].cnt){
	       k -= t[son].size+t[node].cnt;
	        node = t[node].ch[1];
		}
	    else return t[node].val;
    }
}

int get_rank(int val){
    splay(find(val) , 0);
    return t[t[root].ch[0]].size;
}


int main() {
    insert(-inf);insert(inf);
    int q;cin>>q;
    while(q--) {
        int op,x;
        cin>>op>>x;
        if(op==1)insert(x);
        if(op==2)del(x);
        if(op==3)cout<<get_rank(x)<<endl;
        if(op==4)cout<<kth(x+1)<<endl;
        if(op==5)cout<<tr[pre(x,0)].val<<endl;
        if(op==6)cout<<tr[pre(x,1)].val<<endl;
    }
    return 0;
}
posted @ 2023-04-14 23:31  basicecho  阅读(28)  评论(0)    收藏  举报