平衡树-有/无旋Treap

有旋Treap

在普通二叉搜索树的基础上，为每个节点额外赋一个随机优先级，并让优先级满足堆性质，从而使树高期望为 O(log n)，插入、删除、查询等操作的期望复杂度也都是 O(log n)。

代码里有较为详细的注释

所以，话不多说，直接放代码 qwq。

#include<iostream>
using namespace std;
// Shorthands for the left / right child index of node `id`; they expand in
// any scope where a variable named `id` is visible.
#define lid d[id].l
#define rid d[id].r
int cnt_tree,ans; // cnt_tree: nodes allocated so far; ans: result slot written by Query_Pre / Query_Sub

// One node of the rotating Treap. Nodes are allocated from index 1 upward
// (insert uses ++cnt_tree), so index 0 acts as the all-zero "null" node and
// d[0].siz reads as 0 without a separate null check.
struct Treap{
    int l,r;  // child indices, 0 = no child
    int siz;  // number of values stored in this subtree, duplicates included
    int cnt;  // multiplicity of `val` in this node
    int rank; // random priority; insert keeps parents' rank below children's
    int val;  // the key
}d[100005];

// Recompute d[id].siz from the node's own multiplicity plus both child subtrees.
void Update_size(int id){
    const int left_size=d[d[id].l].siz;
    const int right_size=d[d[id].r].siz;
    d[id].siz=left_size+right_size+d[id].cnt;
}

void lrotate(int &id){
  /* 左旋:也就是让右子节点变成根节点
   *         A                 C
   *        / \               / \
   *       B  C    ---->     A   E
   *         / \            / \
   *        D   E          B   D
   */
    int t=d[id].r; //记录右孩子
    d[id].r=d[t].l; //A的右孩子改为C的左孩子
    d[t].l=id; //C的左孩子改为A
    d[t].siz=d[id].siz;//传递size
    Update_size(id);//更新A的size
    id=t;//换根
}

void rrotate(int &id){
   /* 右旋:也就是让左子节点变成根节点
    *         A                 C
    *        / \               / \
    *       B  C    <----     A   E
    *         / \            / \
    *        D   E          B   D
    */
    int t=d[id].l;//同上
    d[id].l=d[t].r;
    d[t].r=id;
    d[t].siz=d[id].siz;
    Update_size(id);
    id=t;
}

// Insert one copy of `val` into the subtree rooted at `id` (passed by
// reference because a new node or a rotation may replace the root).
// Equal values share one node via `cnt`. After descending, a child with a
// smaller rank than its parent is rotated up, keeping rank a min-heap.
void insert(int &id,int val){
    if(id==0){ // empty slot: allocate a fresh node here
        id=++cnt_tree;
        d[id].val=val;
        d[id].rank=rand();
        d[id].cnt=d[id].siz=1;
        return;
    }
    ++d[id].siz; // the new copy lands somewhere in this subtree
    if(val==d[id].val){
        ++d[id].cnt;
        return;
    }
    if(val<d[id].val){
        insert(lid,val);
        if(d[lid].rank<d[id].rank) rrotate(id);
        return;
    }
    insert(rid,val);
    if(d[rid].rank<d[id].rank) lrotate(id);
}
// Remove one copy of `val` from the subtree rooted at `id`.
// Returns true if a copy was found and removed, false if `val` is absent.
// `id` is a reference because unlinking or rotating may replace the root.
bool del(int &id,int val){
    // Reached an empty subtree: val is not stored. (This must test the node,
    // not the value: testing `val` recursed forever through null node 0 for
    // absent values and refused to delete an existing 0.)
    if(!id) return false;
    if(val==d[id].val){
        if(d[id].cnt>1){ // duplicates remain: just drop one copy
            d[id].cnt--;
            d[id].siz--;
            return true;
        }
        if(lid==0 || rid==0){ // at most one child: splice it into this slot
            id=lid+rid;
            return true;
        }
        // Two children: lift the child with the smaller rank so the min-heap
        // order on rank survives, then continue deleting below the new root.
        if(d[lid].rank<d[rid].rank){
            rrotate(id);
        }
        else{
            lrotate(id);
        }
        return del(id,val);
    }
    // Descend toward val; shrink this subtree only if a copy was really removed.
    bool removed=(val<d[id].val) ? del(lid,val) : del(rid,val);
    if(removed) d[id].siz--;
    return removed;
}

// Rank of `val` in the subtree rooted at `id`, defined as
// (number of stored values strictly less than val) + 1, counting duplicates.
// If val is absent this is the rank it would have after insertion, which
// matches the FHQ Find_Val (siz of the "< k" part plus 1).
int Query_Rank(int id,int val){
    if(!id) return 1; // val is absent: the "+1" of the rank is contributed here
    if(val==d[id].val) return d[lid].siz+1; // larger than everything on the left
    else if(val<d[id].val) return Query_Rank(lid,val);
    else return d[lid].siz+d[id].cnt+Query_Rank(rid,val); // whole left side + every copy at this node + rank inside the right side
}

int Query_Num(int id,int val){ //查询排名为val的节点值
    if(!id) return 0;
    if(val<=d[lid].siz) return Query_Num(lid,val); //节点在左孩子
    else if(val>d[lid].siz+d[id].cnt) return Query_Num(rid,val-d[lid].siz-d[id].cnt); //!节点在右孩子,但排名在右孩子里应减小
    else return d[id].val; //不在左,不在右,就只能在自己里了呗awa
}

// Predecessor search: store in the global `ans` the index of the node holding
// the largest value strictly less than `val`. If no such node exists, `ans`
// keeps whatever value the caller set beforehand.
void Query_Pre(int id,int val){
    while(id){
        if(d[id].val<val){ // a candidate; any closer one can only be to its right
            ans=id;
            id=rid;
        }else{
            id=lid;
        }
    }
}

// Successor search: store in the global `ans` the index of the node holding
// the smallest value strictly greater than `val`. If no such node exists,
// `ans` keeps whatever value the caller set beforehand.
void Query_Sub(int id,int val){
    while(id){
        if(d[id].val>val){ // a candidate; any closer one can only be to its left
            ans=id;
            id=lid;
        }else{
            id=rid;
        }
    }
}

// Reads q operations "op x" and applies them to the rotating Treap:
// 1 insert, 2 delete, 3 rank of x, 4 value of rank x, 5 predecessor, 6 successor.
// For 5/6, a missing answer prints d[0].val (0).
int main(){
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    int q,root=0;
    cin>>q;
    for(int step=0;step<q;++step){
        int op,x;
        cin>>op>>x;
        switch(op){
        case 1:
            insert(root,x);
            break;
        case 2:
            del(root,x);
            break;
        case 3:
            cout<<Query_Rank(root,x)<<"\n";
            break;
        case 4:
            cout<<Query_Num(root,x)<<"\n";
            break;
        case 5:
            ans=0;
            Query_Pre(root,x);
            cout<<d[ans].val<<"\n";
            break;
        case 6:
            ans=0;
            Query_Sub(root,x);
            cout<<d[ans].val<<"\n";
            break;
        default:
            break;
        }
    }
}

无旋Treap(FHQ)

#include <iostream>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <vector>
#include <set>
#include <queue>
#include <cmath>
#include <map>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <climits>
#include <iomanip>

using namespace std;

// Contest-template shorthands. lid / rid expand to the children of node `id`
// and are used inside the FHQ member functions below.
#define lid ch[id][0]
#define rid ch[id][1]
#define emp emplace_back
#define pb push_back
#define fi first
#define se second
#define IL inline // (previously defined twice; the duplicate is removed)
#define reg register
#define endl '\n'
#define LF 1
#define RF 0

// N sizes the FHQ node arrays (one node per inserted value); M, mod and B are
// template leftovers not referenced by this program.
const int M = 20000 + 7, N = 100000 + 7, mod = 10007, B = 300;

// #define int long long
using pii = pair <int, int>;
using ll = long long;
using ld = long double;
using ull = unsigned long long;
using I = __int128;
using uIt = __uint128_t;

int root; // root of the global treap; 0 means empty

class FHQ
{
public :
    // Node pool. Index 0 is the null node and is never allocated, so its
    // siz and ch entries stay 0. rnk: random heap priority, val: key,
    // siz: subtree size, ch: children, tot: nodes allocated so far.
    // Duplicate keys are stored as separate nodes.
    int rnk[N], val[N], siz[N], ch[N][2], tot;

    // Allocate a single-node treap holding key k and return its index.
    int New(int k)
    {
        ++tot;
        rnk[tot] = rand();
        val[tot] = k;
        siz[tot] = 1;
        return tot;
    }

    // Recompute siz[id] from its two children.
    void PushUp(int id)
    {
        siz[id] = siz[lid] + siz[rid] + 1;
    }

    // Split the treap rooted at id by key: x gets every key <= k,
    // y every key > k.
    void Split(int id, int k, int &x, int &y)
    {
        if (!id) return (x = 0, y = 0), void();
        if (k < val[id])
        {
            // id and its right subtree belong to y; keep splitting the left side.
            y = id;
            Split(ch[y][0], k, x, ch[y][0]);
        }
        else
        {
            // id and its left subtree belong to x; keep splitting the right side.
            x = id;
            Split(ch[x][1], k, ch[x][1], y);
        }
        PushUp(id);
    }

    // Merge treaps x and y, where every key in x is <= every key in y.
    // The root with the smaller rnk stays on top (min-heap on rnk).
    int Merge(int x, int y)
    {
        if (!x || !y) return x | y;
        if (rnk[x] < rnk[y])
        {
            ch[x][1] = Merge(ch[x][1], y);
            PushUp(x);
            return x;
        }
        else
        {
            ch[y][0] = Merge(x, ch[y][0]);
            PushUp(y);
            return y;
        }
    }

    // Insert one copy of k.
    void Insert(int k)
    {
        int x, y, z;
        Split(root, k, x, y);
        z = New(k);
        root = Merge(Merge(x, z), y);
    }

    // Remove one copy of k if present; otherwise leave the tree unchanged.
    void Del(int k)
    {
        int x, y, z;
        Split(root, k, x, z);          // x: keys <= k, z: keys > k
        Split(x, k - 1, x, y);         // x: keys < k,  y: keys == k
        y = Merge(ch[y][0], ch[y][1]); // drop y's root = one copy of k (y == 0 when k is absent)
        root = Merge(Merge(x, y), z);
    }

    // Index of the k-th smallest node (1-based) in the subtree rooted at id,
    // or 0 when k is outside [1, siz[id]].
    int Find_K(int id, int k)
    {
        // Without this guard Find_K(0, k <= 0) recursed forever on the null
        // node (e.g. Find_Pre when no key is smaller than the query).
        if (!id) return 0;
        if (k <= siz[lid]) return Find_K(lid, k);
        if (k == siz[lid] + 1) return id;
        return Find_K(rid, k - siz[lid] - 1);
    }

    // Rank of k: (number of keys < k) + 1. k need not be present.
    int Find_Val(int k)
    {
        int x, y;
        Split(root, k - 1, x, y);
        int p = siz[x] + 1;
        root = Merge(x, y);
        return p;
    }

    // Index of the node with the largest key < k, or 0 if there is none.
    int Find_Pre(int k)
    {
        int x, y;
        Split(root, k - 1, x, y);
        int p = Find_K(x, siz[x]);
        root = Merge(x, y);
        return p;
    }

    // Index of the node with the smallest key > k, or 0 if there is none.
    int Find_Sub(int k)
    {
        int x, y;
        Split(root, k, x, y);
        int p = Find_K(y, 1);
        root = Merge(x, y);
        return p;
    }
}T;

// Reads n operations "op x" and applies them to the FHQ treap:
// 1 insert, 2 delete, 3 rank of x, 4 value of rank x, 5 predecessor, 6 successor.
// Missing answers print val[0] (0).
int main()
{
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);

    int n; cin >> n;
    for (int rest = n; rest--; )
    {
        int op, x; cin >> op >> x;
        switch (op)
        {
            case 1: T.Insert(x); break;
            case 2: T.Del(x); break;
            case 3: cout << T.Find_Val(x) << '\n'; break;
            case 4: cout << T.val[T.Find_K(root, x)] << '\n'; break;
            case 5: cout << T.val[T.Find_Pre(x)] << '\n'; break;
            case 6: cout << T.val[T.Find_Sub(x)] << '\n'; break;
            default: break;
        }
    }

    return 0;
}

无旋Treap维护区间信息（以区间翻转为例）

#include <iostream>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <vector>
#include <set>
#include <queue>
#include <cmath>
#include <map>
#include <cassert>
#include <bitset>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <climits>
#include <iomanip>

using namespace std;

// Contest-template shorthands. lid / rid expand to the children of node `id`
// and are used inside the FHQ member functions below.
#define lid ch[id][0]
#define rid ch[id][1]
#define emp emplace_back
#define pb push_back
#define fi first
#define se second
#define IL inline // (previously defined twice; the duplicate is removed)
#define reg register
#define endl '\n'
#define LF 1
#define RF 0

// N sizes the FHQ node arrays (one node per sequence position); M, mod and B
// are template leftovers not referenced by this program.
const int M = 20000 + 7, N = 100000 + 7, mod = 10007, B = 300;

// #define int long long
using pii = pair <int, int>;
using ll = long long;
using ld = long double;
using ull = unsigned long long;
using I = __int128;
using uIt = __uint128_t;

int root; // root of the global sequence treap; 0 means empty

class FHQ
{
public :
    // Implicit-key treap over sequence positions (the in-order traversal is
    // the sequence). Per node: tag = pending "reverse this subtree" flag,
    // siz = subtree size, rnk = random priority (larger rnk sits higher, see
    // Merge), ch = children, val = stored element. Index 0 acts as null.
    // tot is not referenced in this program.
    int tag[N], siz[N], rnk[N], tot, ch[N][2], val[N];

    // Recompute siz[id] from its children.
    void PushUp(int id) {siz[id] = siz[lid] + siz[rid] + 1;}

    // Apply id's pending reversal: swap its two children, hand the flag down
    // to each existing child (toggling, so two reversals cancel), then clear it.
    void PushDown(int id)
    {
        if (tag[id])
        {
            swap(lid, rid);
            if (lid) tag[lid] ^= 1;
            if (rid) tag[rid] ^= 1;
            tag[id] = 0;
        }
    }

    // Split the sequence rooted at id by position: the first k elements go
    // to x, the remainder to y. The tag is pushed before the children are
    // read, so the left/right decision uses the up-to-date child order.
    void Split(int id, int k, int &x, int &y)
    {
        if (!id) return (x = y = 0), void();
        PushDown(id);
        if (k <= siz[ch[id][0]])
        {
            // All k elements are in the left subtree; id goes to y.
            y = id;
            Split(ch[y][0], k, x, ch[y][0]);
            PushUp(y);
        }
        else
        {
            // id and its left subtree go to x; take the rest from the right.
            x = id;
            Split(ch[x][1], k - siz[ch[id][0]] - 1, ch[x][1], y);
            PushUp(x);
        }
    }

    // Concatenate sequences x (left) and y (right). The root with the larger
    // rnk stays on top; both roots are pushed down before their children
    // are touched.
    int Merge(int x, int y)
    {
        if (!x || !y) return x | y;
        PushDown(x), PushDown(y);
        if (rnk[x] > rnk[y])
        {
            ch[x][1] = Merge(ch[x][1], y);
            PushUp(x);
            return x;
        }
        else
        {
            ch[y][0] = Merge(x, ch[y][0]);
            PushUp(y);
            return y;
        }
    }

    // Reverse positions l..r (1-based) of the sequence under root: cut out
    // the middle piece, toggle its reversal flag, and glue the pieces back.
    void Reverse(int l, int r)
    {
        int x, y, z;
        Split(root, l - 1, x, y);
        Split(y, r - l + 1, y, z);
        tag[y] ^= 1;
        root = Merge(Merge(x, y), z);
    }

    // In-order traversal printing each element followed by a space; pending
    // reversals are applied on the way down.
    void print(int id)
    {
        PushDown(id);
        if (lid) print(lid);
        cout << val[id] << ' ';
        if (rid) print(rid);
    }
}T;


int sta[N], top; // stack holding the current right spine during the O(n) build in main

// Reads n, builds the sequence 1..n, applies m range reversals "l r" and
// prints the final sequence (space separated).
int main()
{
    srand(time(NULL));
    // freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
    ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);

    int n; cin >> n;
    // Node i holds value i, so the initial sequence is 1, 2, ..., n.
    for (int i = 1; i <= n; i++) T.rnk[i] = rand(), T.val[i] = i, T.siz[i] = 1;

    // Build the treap in O(n) as a Cartesian tree over positions, with larger
    // rnk on top (the order Merge maintains). sta[1..top] is the right spine.
    for (int i = 1; i <= n; i++)
    {
        int k = top;
        // Nodes popped here are final: no later node can join their subtrees.
        // Their sizes must be computed now, otherwise every siz stays 1 and
        // the size-based Split cuts at the wrong positions. Pops go
        // deepest-first, so each popped node's right child (the previous pop)
        // and its left child (sized when it was attached) are already correct.
        while (k && T.rnk[sta[k]] < T.rnk[i]) T.PushUp(sta[k--]);
        if (k) T.ch[sta[k]][1] = i;
        if (k < top) T.ch[i][0] = sta[k + 1];
        sta[++k] = i;
        top = k;
    }
    // Nodes still on the right spine get their sizes last, deepest first.
    for (int k = top; k >= 1; k--) T.PushUp(sta[k]);
    root = sta[1];

    int m; cin >> m;
    for (int i = 1; i <= m; i++)
    {
        int l, r; cin >> l >> r;
        T.Reverse(l, r);
    }
    if (root) T.print(root); // an empty sequence has nothing to print

    return 0;
}
posted @ 2024-08-23 18:07  QEDQEDQED  阅读(47)  评论(1)    收藏  举报