032.有序表之AVL树

模板

luogu P3369

luogu P5076

P5076 中前驱不存在时要求输出 -2147483647,因此需把INT_MIN换成-2147483647(后继不存在时要求输出的 2147483647 就是 INT_MAX,无需修改)

支持API :

  • add(num) 添加一个 key 为 num 的元素(若已存在则该节点计数加一),自动维护平衡

  • remove(num) 删除一个 key 为 num 的元素(计数减一,减到零时删除该节点),自动维护平衡,若不存在则无事发生

  • rank(num) 获取 num 的排名(num 不必在集合中)

定义 num 的排名为集合中 key 小于 num 的元素个数(重复元素按出现次数计) +1

  • index(x) 获取排名为x的节点的key

  • pre(num) 获取 num 的前驱,即集合中小于 num 的最大 key,若不存在返回INT_MIN

  • post(num) 获取 num 的后继,即集合中大于 num 的最小 key,若不存在返回INT_MAX

  • clean() 初始化AVL

#include<algorithm>
#include<climits>
#include<iostream>
using namespace std;
const int N=1e5+5; // node-pool capacity; index 0 is reserved as the null node

// Counted AVL tree over int keys, i.e. an ordered multiset.
// Each distinct key occupies one node: count[i] is its multiplicity and
// size[i] the total multiplicity stored in the subtree rooted at i.
// Nodes live in parallel fixed-size arrays. Index 0 is the null node: it is
// never written, so height[0]==size[0]==0 and left[0]==right[0]==0 act as
// the values of an empty subtree.
// Slots are never recycled: every add() of a key not currently present takes
// a fresh slot, so such insertions between clean() calls must number fewer
// than N (there is no bounds check).
struct AVL{
private:
    int cnt=0;         // slots handed out so far (live nodes are within 1..cnt)
    int head=0;        // root of the tree, 0 when empty
    // Value-initialized so the null node 0 reads as all zeros even when an
    // AVL is not a zero-initialized static object.
    int key[N]{};      // key of node i
    int height[N]{};   // height of subtree i (a single node has height 1)
    int left[N]{};     // left child of i, 0 if none
    int right[N]{};    // right child of i, 0 if none
    int size[N]{};     // total multiplicity stored in subtree i
    int count[N]{};    // multiplicity of key[i]

    // Recompute height[i] and size[i] from i's children, which must already be up to date.
    void up(int i){
        height[i]=std::max(height[left[i]],height[right[i]])+1;
        size[i]=count[i]+size[left[i]]+size[right[i]];
    }
    // Rotate left around i; i's right child becomes the subtree root and is returned.
    int leftRotate(int i){
        int r=right[i];
        right[i]=left[r];
        left[r]=i;
        up(i);   // i now hangs below r, so it must be refreshed first
        up(r);
        return r;
    }
    // Mirror of leftRotate: i's left child becomes the subtree root and is returned.
    int rightRotate(int i){
        int l=left[i];
        left[i]=right[l];
        right[l]=i;
        up(i);
        up(l);
        return l;
    }
    // Rebalance node i when its children's heights differ by more than 1;
    // returns the (possibly new) root of the subtree.
    int maintain(int i){
        int lh=height[left[i]];
        int rh=height[right[i]];
        if(lh-rh>1){
            // Left-heavy. ">=" sends the equal-heights case (which can occur
            // after a deletion) to the single rotation.
            if(height[left[left[i]]]>=height[right[left[i]]]){
                i=rightRotate(i);                 // LL
            }
            else{
                left[i]=leftRotate(left[i]);      // LR
                i=rightRotate(i);
            }
        }
        else if(rh-lh>1){
            if(height[right[right[i]]]>=height[left[right[i]]]){
                i=leftRotate(i);                  // RR
            }
            else{
                right[i]=rightRotate(right[i]);   // RL
                i=leftRotate(i);
            }
        }
        return i;
    }
    // Insert one occurrence of num into subtree i; returns the new subtree root.
    int add(int i,int num){
        if(i==0){
            // New key: take a fresh slot (no capacity check, see struct comment).
            key[++cnt]=num;
            count[cnt]=size[cnt]=height[cnt]=1;
            return cnt;
        }
        if(key[i]==num){
            count[i]++;
        }
        else if(key[i]>num){
            left[i]=add(left[i],num);
        }
        else{
            right[i]=add(right[i],num);
        }
        up(i);
        return maintain(i);
    }
    // Delete one occurrence of num from subtree i; returns the new subtree
    // root. If num is absent the subtree is left unchanged.
    int remove(int i,int num){
        if(i==0)return 0;   // null link reached: num is not in this subtree
        if(key[i]<num){
            right[i]=remove(right[i],num);
        }
        else if(key[i]>num){
            left[i]=remove(left[i],num);
        }
        else if(count[i]>1){
            count[i]--;
        }
        else if(left[i]==0&&right[i]==0){
            return 0;
        }
        else if(right[i]==0){
            i=left[i];
        }
        else if(left[i]==0){
            i=right[i];
        }
        else{
            // Two children: splice the in-order successor (leftmost node of
            // the right subtree) into i's position.
            int mostLeft=right[i];
            while(left[mostLeft]){
                mostLeft=left[mostLeft];
            }
            right[i]=removeMostLeft(right[i],mostLeft);
            left[mostLeft]=left[i];
            right[mostLeft]=right[i];
            i=mostLeft;
        }
        up(i);
        return maintain(i);
    }
    // Unlink node mostLeft (the leftmost node of subtree i) and return the
    // rebalanced subtree root; mostLeft's own fields are not modified here.
    int removeMostLeft(int i,int mostLeft){
        if(i==mostLeft){
            return right[i];
        }
        left[i]=removeMostLeft(left[i],mostLeft);
        up(i);
        return maintain(i);
    }
    // True iff key num is currently stored (iterative BST search).
    bool contains(int num) const{
        int i=head;
        while(i!=0&&key[i]!=num){
            i=(num<key[i])?left[i]:right[i];
        }
        return i!=0;
    }
    // Number of elements (with multiplicity) in subtree i whose key is < num.
    int small(int i,int num) const{
        if(i==0)return 0;
        if(key[i]>=num){
            return small(left[i],num);
        }
        return size[left[i]]+count[i]+small(right[i],num);
    }
    // Key of the x-th smallest element (1-based, with multiplicity) of subtree i.
    // Returns INT_MIN when x is outside [1, size[i]] instead of recursing
    // through the null node forever.
    int index(int i,int x) const{
        if(i==0)return INT_MIN;
        if(size[left[i]]>=x){
            return index(left[i],x);
        }
        if(size[left[i]]+count[i]<x){
            return index(right[i],x-size[left[i]]-count[i]);
        }
        return key[i];
    }
    // Largest key < num in subtree i, or INT_MIN if there is none.
    int pre(int i,int num) const{
        if(i==0)return INT_MIN;
        if(key[i]>=num){
            return pre(left[i],num);
        }
        return std::max(key[i],pre(right[i],num));
    }
    // Smallest key > num in subtree i, or INT_MAX if there is none.
    int post(int i,int num) const{
        if(i==0)return INT_MAX;
        if(key[i]<=num){
            return post(right[i],num);
        }
        return std::min(key[i],post(left[i],num));
    }
public:
    // Insert one occurrence of num.
    void add(int num){
        head=add(head,num);
    }
    // Delete one occurrence of num; does nothing if num is absent.
    void remove(int num){
        // Direct lookup instead of rank(num)!=rank(num+1): num+1 overflows for
        // num==INT_MAX (UB); with wrap-around that check passed for an absent
        // INT_MAX in a non-empty tree and the private remove() then recursed
        // forever through the null node.
        if(contains(num)){
            head=remove(head,num);
        }
    }
    // 1 + number of stored elements (with multiplicity) less than num; num need not be present.
    int rank(int num) const{
        return small(head,num)+1;
    }
    // x-th smallest element (1-based, with multiplicity); INT_MIN if x is out of range.
    int index(int x) const{
        return index(head,x);
    }
    // Largest stored key strictly less than num, or INT_MIN if none.
    int pre(int num) const{
        return pre(head,num);
    }
    // Smallest stored key strictly greater than num, or INT_MAX if none.
    int post(int num) const{
        return post(head,num);
    }
    // Reset to an empty tree, zeroing every slot handed out so far.
    void clean(){
        for(int i=1;i<=cnt;++i){
            key[i]=0;
            height[i]=0;
            left[i]=0;
            right[i]=0;
            count[i]=0;
            size[i]=0;
        }
        cnt=0;
        head=0;
    }
}avl;
posted @ 2026-01-02 19:48  射杀百头  阅读(6)  评论(0)    收藏  举报