平衡树Treap

普通平衡树 Treap

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

1.插入 x 数
2.删除 x 数(若有多个相同的数,因只删除一个)
3.查询 x 数的排名(排名定义为比当前数小的数的个数 +1 。若有多个相同的数,因输出最小的排名)
4.查询排名为 x 的数
5.求 x 的前驱(前驱定义为小于 x ,且最大的数)
6.求 x 的后继(后继定义为大于 x ,且最小的数)

输入输出格式

输入格式:

第一行为 n ,表示操作的个数,下面 n 行每行有两个数 opt 和 x , opt 表示操作的序号( 1≤opt≤6 ) 输出格式:

对于操作 3,4,5,6 每行输出一个数,表示对应答案

分析:

平衡树板子题。 注意的是: 1.增加相同的数累计次数。(pushup等)(1——>t[o].sum)(WA无数) 2.注意哨兵不要被赋值(pushup) 3.查询排名,是最小的。 4.查询排名为x的数,就是第k大,注意是k-t[o].sum-t[t[o].ch[0]].size

1.初始化:

其中sum记录该节点val值有几个

const int N=100000+10;
const int inf=0x3f3f3f3f;
struct node{
    int ch[2];
    int val,size,prio;
    int sum;
};
void init(){
    root=0;//认为root开始为0
}

2.发放节点/回收节点

int poolcur;
int delpool[N],delcur;
int newnode(){
    int r=delcur?delpool[delcur--]:++poolcur;
    memset(t+r,0,sizeof(node));
    t[r].size=1;
    t[r].prio=rand();
    t[r].sum=1;
    return r;
}
void delnode(int o){
   delpool[++delcur]=o;
}

3.pushup

注意:

①.t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum;

②.if(!o) return; 保护哨兵(这里其实不需要了)

void pushup(int o){
    if(!o) return;
    t[o].size=t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum;
}

4.rotate

注意o=u位置放最后,注意pushup()位置。

void rotate(int &o,int d){
    if(!o) return;
    int u=t[o].ch[d];
    t[o].ch[d]=t[u].ch[d^1];
    t[u].ch[d^1]=o;
    t[u].size=t[o].size;
    pushup(o);
    o=u;
}

5.insert

注意:

①值相等时,++sum与size(调半天)

②注意pushup

③根据优先级rotate

④发现当o为0时,会建边,所以开始root必须为0,而哨兵不能有问题。并且发放节点时,++poolcur而不是poolcur++,保证取到整数节点号。(调半天)

void insert(int &o,int v){
    if(!o){
        o=newnode();
        t[o].val=v;
        return;
    }
    if(t[o].val==v){
        t[o].sum++;
        t[o].size++;
        return;
    }
    int d=t[o].val<v;
    insert(t[o].ch[d],v);
    pushup(o);
    if(t[t[o].ch[d]].prio<t[o].prio) rotate(o,d);
}

6.remove

注意:

①先判断是否找到了删除点。保证找到了,再删除(WA一次)

②先删除一下sum,若不为0,跳过这一步。否则继续操作。注意每次判断sum是否可以再删,否则可能会在后面的递归中将sum删成负数。

③无论如何必须要记得pushup(WA一次)

void remove(int &o,int v){
    if(!o) return;
    if(t[o].val==v){ 
        if(t[o].sum>0) t[o].sum--;
        if(t[o].sum<=0)
        {
        int u=o;
        if(!t[o].ch[0]){
            o=t[o].ch[1];
            delnode(u);
        }
        else if(!t[o].ch[1]){
            o=t[o].ch[0];
            delnode(u);
        }
        else{
            int d=t[t[o].ch[0]].prio<t[t[o].ch[0]].prio;
            rotate(o , d^1);
            remove(t[o].ch[d],v);
        }
        }
    }
    else{
        int d=t[o].val<v;
        remove(t[o].ch[d] , v);
    }
    pushup(o);
}

7.对于操作3找v的排名

注意:

①找到最小的位置,所以1+t[t[o].ch[0]].size 注意左子树都比它小(调半天)

②因为可能不存在v 所以(!o)return 0

③往下找 +t[o].sum而不是1

int rank(int o, int v)//找到v的排名
{
    if(!o) return 0;
    if(t[o].val>v) return rank(t[o].ch[0], v);
    else if(t[o].val==v) return 1+t[t[o].ch[0]].size;
    else return t[t[o].ch[0]].size+t[o].sum+rank(t[o].ch[1],v);
}

8.对于操作4找到第k大

①想往下找必须剩下的d>t[o].sum(WA一次)

int find(int o,int k){// 找第K小的数
    if(o==0||k==0) return 0;
    int d=k-t[t[o].ch[0]].size;
    if(d<= 0) return find(t[o].ch[0],k);
    else if(d>=1&&d<=t[o].sum) return t[o].val;
    else return find(t[o].ch[1],d-t[o].sum);
}

9.对于5/6找前去后继

①想清楚min/max以及inf/-inf就好

int find_front(int o,int v)//找到小于v的最大值
{
    if(!o) return -inf;
    int d= t[o].val>=v;
    if(!d) return max(t[o].val,find_front(t[o].ch[1],v));
    else return find_front(t[o].ch[0],v);
}

代码纯享:

#include<bits/stdc++.h>
using namespace std;
const int N=100000+10;
const int inf=0x3f3f3f3f;
struct node{
    int ch[2];
    int val,size,prio;
    int sum;
};
node t[N];
int n,root;
int poolcur;
int delpool[N],delcur;
void init(){
    root=0;
}
int newnode(){
    int r=delcur?delpool[delcur--]:++poolcur;
    memset(t+r,0,sizeof(node));
    t[r].size=1;
    t[r].prio=rand();
    t[r].sum=1;
    return r;
}
void delnode(int o){
   delpool[++delcur]=o;
}
void pushup(int o){
    if(!o) return;
    t[o].size=t[t[o].ch[1]].size+t[t[o].ch[0]].size+t[o].sum;
}
void rotate(int &o,int d){
    if(!o) return;
    int u=t[o].ch[d];
    t[o].ch[d]=t[u].ch[d^1];
    t[u].ch[d^1]=o;
    t[u].size=t[o].size;
    pushup(o);
    o=u;
}
void insert(int &o,int v){
    if(!o){
        o=newnode();
        t[o].val=v;
        return;
    }
    if(t[o].val==v){
        t[o].sum++;
        t[o].size++;
        return;
    }
    int d=t[o].val<v;
    insert(t[o].ch[d],v);
    pushup(o);
    if(t[t[o].ch[d]].prio<t[o].prio) rotate(o,d);
}
void remove(int &o,int v){
    if(!o) return;
    if(t[o].val==v){ 
        if(t[o].sum>0) t[o].sum--;
        if(t[o].sum<=0)
        {
        int u=o;
        if(!t[o].ch[0]){
            o=t[o].ch[1];
            delnode(u);
        }
        else if(!t[o].ch[1]){
            o=t[o].ch[0];
            delnode(u);
        }
        else{
            int d=t[t[o].ch[0]].prio<t[t[o].ch[0]].prio;
            rotate(o , d^1);
            remove(t[o].ch[d],v);
        }
        }
    }
    else{
        int d=t[o].val<v;
        remove(t[o].ch[d] , v);
    }
    pushup(o);
}
int rank(int o, int v)//找到v的排名
{
    if(!o) return 0;
    if(t[o].val>v) return rank(t[o].ch[0], v);
    else if(t[o].val==v) return 1+t[t[o].ch[0]].size;
    else return t[t[o].ch[0]].size+t[o].sum+rank(t[o].ch[1],v);
}
int find(int o,int k){// 找第K小的数
    if(o==0||k==0) return 0;
    int d=k-t[t[o].ch[0]].size;
    if(d<= 0) return find(t[o].ch[0],k);
    else if(d>=1&&d<=t[o].sum) return t[o].val;
    else return find(t[o].ch[1],d-t[o].sum);
}
int find_back(int o,int v)//找到大于v的最小值
{
    if(!o) return inf;
    int d= t[o].val<=v; 
    if(!d) return min(t[o].val,find_back(t[o].ch[0],v));
    else return find_back(t[o].ch[1],v);
}
int find_front(int o,int v)//找到小于v的最大值
{
    if(!o) return -inf;
    int d= t[o].val>=v;
    if(!d) return max(t[o].val,find_front(t[o].ch[1],v));
    else return find_front(t[o].ch[0],v);
}
int ans[N];
int cnt;
int main()
{
    scanf("%d",&n);
    init();
    int p,x;
    int has=0;
    srand((unsigned)time(NULL));
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&p,&x);
        if(p==1) insert(root,x),has++;
        if(p==2) {remove(root,x);has--;}
        if(p==3) ans[++cnt]=rank(root,x);
        if(p==4) ans[++cnt]=find(root,x);
        if(p==5) ans[++cnt]=find_front(root,x);
        if(p==6) ans[++cnt]=find_back(root,x);
    }
    for(int i=1;i<=cnt;i++)
     printf("%d\n",ans[i]);
    return 0;
}

注意细节,打多了就好

posted @ 2018-05-13 11:59  *Miracle*  阅读(196)  评论(0编辑  收藏  举报