SBT平衡树

SizeBalancedTree是通过维护节点大小来保持平衡的二叉搜索树,平衡的条件是任意节点的size不小于其兄弟节点的所有子节点的size。

重点在于时间复杂度均摊O(1)的maintain函数。

对于每个节点,若左子树的左子树的节点数大于右子树的节点,需要右旋根来平衡。

若左子树的右子树的节点数大于右子树的节点数,需要先将左子树平衡,再右旋根来平衡。

若右子树的右子树的节点数大于左子树的节点数,需要左旋根来平衡。

若右子树的左子树的节点数大于左子树的节点数,需要先平衡右子树,再左旋根来平衡。

由于SBT以来size进行平衡,所以一半不维护自身副本数。

struct SizeBalancedTree{
    struct tree{
        int ch[2],v,w;
    }t[N];
    int root,tot;
    unsigned long long inf=1e10;
    #define l(p) (t[p].ch[0])
    #define r(p) (t[p].ch[1])
    #define ch(p,x) (t[p].ch[x])
    #define v(p) (t[p].v)
    #define w(p) (t[p].w)
    inline int create(int x){
        t[++tot]={0,0,x,1};
        return tot;
    }
    inline void rotate(int&p,int d){
        int x=ch(p,d^1);
        ch(p,d^1)=ch(x,d);
        ch(x,d)=p;
        w(x)=w(p);/*新根替换原根,新根的size为原根的size*/
        w(p)=w(l(p))+w(r(p))+1;/*更新原根的size*/
        p=x;
    }
    void maintain(int&p,int d){/*d=0左旋,d=1右旋*/
        if(w(ch(ch(p,d),d))>w(ch(p,d^1)))rotate(p,d^1);/*假设d=0,左子树的左子树的size>右子树的size,需要右旋根节点*/
        else if(w(ch(ch(p,d),d^1))>w(ch(p,d^1))){/*假设d=0,左子树的右子树的size>右子树的size*/
            rotate(ch(p,d),d);/*先左旋左儿子平衡左子树*/
            rotate(p,d^1);/*再右旋根节点*/
        }
        else return;
        maintain(l(p),0);/*先递归平衡子树*/
        maintain(r(p),1);
        maintain(p,0);/*再递归平衡自身*/
        maintain(p,1);
    }
    void insert(int&p,int x){
        if(!p){/*空节点则创建*/
            p=create(x);
            return;
        }
        w(p)++;/*新建时一路增加根节点size不再pushup*/
        int d=v(p)<x;
        insert(ch(p,d),x);
        maintain(p,d);
    }
    int del(int&p,int x){
        w(p)--;/*删除时一路减少根节点size不再pushup*/
        int d=v(p)<x;
        if(x==v(p)||!ch(p,d)){/*找到一个权值x的点或x应该在当前点的子树方向但是该方向为空*/
            int v=v(p);
            if(l(p)&&r(p))v(p)=del(l(p),v(p)+1);/*用左子树中的前驱替代*/
            else p=l(p)|r(p);/*将一个非空的子树拉上来替代*/
            return v;
        }
        return del(ch(p,d),x);
    }
/*
递归版本
    int rank(int p,int x){
        if(!p)return 1;
        if(x<=v(p))return rank(l(p),x);
        else return rank(r(p),x)+w(l(p))+1;
    }
    int kth(int p,int x){
        if(x==w(l(p))+1)return v(p);
        else if(x<=w(l(p)))return kth(l(p),x);
        else return kth(r(p),x-w(l(p))-1);
    }
*/
    inline int rank(int x){
        int p=root,re=1;
        while(p){
            if(x<=v(p))p=l(p);
            else re+=w(l(p))+1,p=r(p);
        }
        return re;
    }
    inline int kth(int x){
        int p=root;
        while(w(l(p))+1!=x){
            if(x<=w(l(p)))p=l(p);
            else x-=w(l(p))+1,p=r(p);
        }
        return v(p);
    }
    int pre(int p,int x){
        if(!p)return -inf;
        if(v(p)<x)return max(v(p),pre(r(p),x));
        else return pre(l(p),x);
    }
    int suc(int p,int x){
        if(!p)return inf;
        if(v(p)>x)return min(v(p),suc(l(p),x));
        else return suc(r(p),x);
    }
    inline void insert(int x){
        insert(root,x);
    }
    inline void del(int x){
        del(root,x);
    }
    inline int pre(int x){
        return pre(root,x);
    }
    inline int suc(int x){
        return suc(root,x);
    }
};
posted @ 2022-11-14 17:52  半步蒟蒻  阅读(75)  评论(0)    收藏  举报