平衡树

普通平衡树

  1. 插入数值x.
  2. 删除数值x(若有多个相同的数,则只删除一个)
  3. 查询数值x的排名(若有多个相同的数,应输出最小排名)
  4. 查询排名为x的的数值
  5. 求数值x的前驱
  6. 求数值x的后继
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10,INF=1e8;
int n,root,idx;
struct Node{
    int l,r;
    int key,val;
    int cnt,size;//当前节点key重复个数和子孙结点个数
}tr[N];
//更新父节点信息
void pushup(int p){
    tr[p].size=tr[tr[p].l].size+tr[tr[p].r].size+tr[p].cnt;
}
//创造叶子结点
int get_node(int key){
    tr[++idx].key=key;
    tr[idx].val=rand();
    tr[idx].cnt=tr[idx].size=1;
    return idx;
}
//初始化平衡树
void build(){
    get_node(-INF);get_node(INF);
    root=1,tr[1].r=2;//INF在-INF右边
    pushup(root);
}
//右旋
void zig(int &p){
    int q=tr[p].l;
    tr[p].l=tr[q].r;
    tr[q].r=p;
    p=q;
    pushup(tr[p].r);
    pushup(p);
}
//左旋
void zag(int &p){
    int q=tr[p].r;
    tr[p].r=tr[q].l;
    tr[q].l=p;
    p=q;
    pushup(tr[p].l);
    pushup(p);
}

void inserts(int &p,int key){
    if(!p)p=get_node(key);
    else if(tr[p].key==key)tr[p].cnt++;
    else if(tr[p].key>key){
        inserts(tr[p].l,key);
        if(tr[tr[p].l].val>tr[p].val)zig(p);
    }
    else {
        inserts(tr[p].r,key);
        if(tr[tr[p].r].val>tr[p].val)zag(p);
    }
    pushup(p);
}

void remove(int &p,int key){
    if(!p)return ;
    if(tr[p].key==key){
        if(tr[p].cnt>1)tr[p].cnt--;
        else if(tr[p].l||tr[p].r){
            if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val){
                //只存在左儿子或者左val大于右val
                zig(p);
                remove(tr[p].r,key);
            }
            else {
                //存在右儿子且左val小于右val
                zag(p);
                remove(tr[p].l,key);
            }
        }
        else p=0;
    }
    else if(tr[p].key>key)remove(tr[p].l,key);
    else remove(tr[p].r,key);
    pushup(p);
}

int get_rank_bykey(int p,int key){
    if(!p)return 0;
    if(tr[p].key==key)return tr[tr[p].l].size+1;//左子树size一
    if(tr[p].key>key)return get_rank_bykey(tr[p].l,key);//大了,去左子树找
    return tr[tr[p].l].size+tr[p].cnt+get_rank_bykey(tr[p].r,key);
}

int get_key_byrank(int p,int rank){
    if(!p)return INF;
    if(tr[tr[p].l].size>=rank)return get_key_byrank(tr[p].l,rank);
    if(tr[tr[p].l].size+tr[p].cnt>=rank)return tr[p].key;
    return get_key_byrank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
}
//找到严格小于key的最大数
int get_pre(int p,int key){
    if(!p)return -INF;
    if(tr[p].key>=key)return get_pre(tr[p].l,key);
    return max(tr[p].key,get_pre(tr[p].r,key));
}
//找到严格大于key的最小数
int get_next(int p,int key){
    if(!p)return INF;
    if(tr[p].key<=key)return get_next(tr[p].r,key);
    return min(tr[p].key,get_next(tr[p].l,key));
}

int main(){
    build();
    int n;cin>>n;
    while(n--){
        int op,x;
        cin>>op>>x;
        if(op==1)inserts(root,x);
        if(op==2)remove(root,x);
        if(op==3)cout<<get_rank_bykey(root,x)-1<<endl;
        if(op==4)cout<<get_key_byrank(root,x+1)<<endl;
        if(op==5)cout<<get_pre(root,x)<<endl;
        if(op==6)cout<<get_next(root,x)<<endl;
    }
    return 0;
}

 

posted @ 2022-11-18 16:55  Dengpc  阅读(51)  评论(0)    收藏  举报