Luogu 3369 / BZOJ 3224 - 普通平衡树 - [替罪羊树]

题目链接:

https://www.lydsy.com/JudgeOnline/problem.php?id=3224

https://www.luogu.org/problemnew/show/P3369

Description

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

Input

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)

Output

对于操作3,4,5,6每行输出一个数,表示对应答案

Sample Input

10

1 106465

4 1

1 317721

1 460929

1 644985

1 84185

1 89851

6 81968

1 492737

5 493598

Sample Output

106465

84185

492737

HINT

1.n的数据范围:n<=100000

2.每个数的数据范围:[-2e9,2e9]

 

关于替罪羊树:

替罪羊树的主要思想就是将不平衡的树压成一个序列,然后暴力重构成一颗平衡的树。

这里的平衡指的是:对于某个 $0.5 \le \alpha \le 1$ 满足 $size( ls(x) ) \le \alpha \cdot size(x)$ 并且 $size( rs(x) ) \le \alpha \cdot size(x)$。一般 $\alpha$ 取 $0.7 \sim 0.8$。

更加详细的解释和模板请参考替罪羊树(重量平衡树)入门

 

AC代码:

#include<bits/stdc++.h>
using namespace std;

const int maxn=1e5+10;
const double alpha=0.8;
struct Node
{
    Node* ch[2]; //左右子节点
    int key,siz,cov; //key是值,siz是以该节点为根的树的存在的节点数,cover是所有节点数量
    bool ext;
    void pushup() { //更新函数
        siz = ch[0]->siz + ch[1]->siz + ext;
        cov = ch[0]->cov + ch[1]->cov + 1;
    }
    inline bool isbad() { //判断是否要重构
        return alpha*cov+5 < max(ch[0]->cov,ch[1]->cov);
    }
};
struct ScapegoatTree
{
protected:
    Node mem[maxn]; //内存池
    Node *tail,*null,*root; //tail为指向内存池元素的指针
    Node *bak[maxn]; int baksz; //内存回收池

    Node* newnode(int key)
    {
        Node* p=baksz?bak[--baksz]:tail++;
        p->ch[0] = p->ch[1] = null;
        p->siz = p->cov= p->ext = 1;
        p->key = key;
        return p;
    }

    void travel(vector<Node*>& v,Node* p) //中序遍历将一棵树转化成序列
    {
        if(p==null) return;
        travel(v,p->ch[0]);
        if(p->ext) v.push_back(p);
        else bak[baksz++]=p;
        travel(v,p->ch[1]);
    }

    Node* build(vector<Node*>& v,int l,int r)
    {
        if(l>=r) return null;
        int mid=(l+r)>>1;
        Node *p=v[mid];
        p->ch[0] = build(v,l,mid);
        p->ch[1] = build(v,mid+1,r);
        p->pushup();
        return p;
    }

    vector<Node*> cur;
    void rebuild(Node*& p)
    {
        cur.clear();
        travel(cur,p);
        p=build(cur,0,cur.size());
    }

    Node** insert(Node*& p,int val)
    {
        if(p==null)
        {
            p=newnode(val);
            return &null;
        }
        p->siz++, p->cov++;
        Node** res=insert(p->ch[val>=p->key],val);
        if(p->isbad()) res=&p;
        return res;
    }

    void erase(Node*& p,int k)
    {
        p->siz--; //维护siz
        int offset = p->ch[0]->siz + p->ext; //计算左子树的存在的节点总数
        if(p->ext && k==offset) p->ext=0;
        else
        {
            if(k<=offset) erase(p->ch[0],k);
            else erase(p->ch[1],k-offset);
        }
    }

public:
    void init()
    {
        tail=mem;
        null=tail++;
        null->ch[0] = null->ch[1] = null;
        null->key = 0;
        null->siz = null->cov = null->ext = 0;
        root=null; //初始化根节点
        baksz=0; //清空栈
    }
    ScapegoatTree() {
        init();
    }

    void insert(int val)
    {
        Node** res=insert(root,val);
        if(*res!=null) rebuild(*res);
    }

    int getrank(int val)
    {
        Node *p=root;
        int res=1;
        while(p!=null)
        {
            if(val <= p->key) p=p->ch[0];
            else
            {
                res += p->ch[0]->siz + p->ext;
                p = p->ch[1];
            }
        }
        return res;
    }

    int getkth(int k)
    {
        Node *p=root;
        while(p!=null)
        {
            if(p->ch[0]->siz+1==k && p->ext) return p->key;
            if(k <= p->ch[0]->siz) p=p->ch[0];
            else k-=p->ch[0]->siz + p->ext, p=p->ch[1];
        }
    }

    void delval(int val)
    {
        erase(root,getrank(val));
        if(root->siz < alpha * root->cov) rebuild(root);
    }

    void delkth(int k)
    {
        erase(root,k);
        if(root->siz < alpha * root->cov) rebuild(root);
    }
}st;

int main()
{
    int n,opt,x;
    scanf("%d",&n);
    while(n--)
    {
        scanf("%d%d",&opt,&x);
        if(opt==1) st.insert(x);
        if(opt==2) st.delval(x);
        if(opt==3) printf("%d\n",st.getrank(x));
        if(opt==4) printf("%d\n",st.getkth(x));
        if(opt==5) printf("%d\n",st.getkth(st.getrank(x)-1));
        if(opt==6) printf("%d\n",st.getkth(st.getrank(x+1)));
    }
}

 

posted @ 2018-12-02 23:40  Dilthey  阅读(226)  评论(0编辑  收藏  举报