【最简单的平衡树】Treap

Treap

树是很有用的一种结构,加上不同的约束规则之后,就形成了各种特性鲜明的结构。
最基本的二叉树,加以约束之后,可以形成 BST、AVL、Heap、Splay、R-B Tree …… 等,适用于各种场景。
对于平衡树,种类有很多,有的严格平衡,每次某个子树上任意两个子树的高度差超过1就会进行调整;也有的弱平衡,两子树高度差不超过一倍就不会调整。平衡树很重要,常用于 map、set 等数据结构和数据库等基础设施。但是平衡树大多并不好写,因此一般使用标准库提供的套件来完成工作。这就导致了有些时候,我们想要定制一些操作的时候,难以在封装好的数据结构上进行操作。

本文介绍一种简单好写的平衡树,并给出模板代码,可用于实际解决查元素排名、查排名元素、查前驱、查后继,以及区间操作等。

Treap= Tree + heap

Treap 使用 BST 来提供二分检索特性,使用 Heap 来管理旋转操作,维护二叉树平衡。相比其他平衡树,多用了一个字段来存储堆的权重。堆权随机生成,由随机性来保证二叉树“大概率是平衡的”。

Treap 有多种实现,可以用指针,也可以用数组;可以使用旋转来保证平衡,也可以用分裂合并来做。OI Wiki 解释说无旋 Treap 具有支持序列化的优势,我们在写题的时候用不到这一点,可以直接上数组,减少指针操作。

参考代码1:数组版

这里使用了 OI Wiki 提供的模板代码,并修正了删除不存在元素时 size-- 的错误,另外还加入了 find / count / delall 操作。
注意一点,查询操作只需要传值,插入删除旋转操作传递的是引用,以便旋转操作修改根节点索引。
Treap 本身并不需要 size 字段,这里加入 size 字段是为了提供 查元素排名 和 排名元素 的操作。
为了清晰展示思路,函数用的都是递归操作。代码如下:

#include <cstdio>
#include <algorithm>

#define maxn 100005
#define INF (1 << 30)

struct treap {
    // cnt 是元素重复次数,size 是子树结点数目(计入重复元素),rnd 是堆权
    // l, r 自动初始化为0,作为空值;0 号结点是空结点
    // cnt 默认0个元素,size 默认子树元素为0
    int l[maxn], r[maxn], val[maxn], cnt[maxn], rnd[maxn], size[maxn];
    int sz; // array size, used for insert
    int rt; // tree root index
    int ans;

    void lrotate(int& k)
    {
        int t = r[k];
        r[k] = l[t];
        l[t] = k;
        size[t] = size[k];
        size[k] = size[l[k]] + size[r[k]] + cnt[k];
        k = t;
    }
    void rrotate(int& k)
    {
        int t = l[k];
        l[k] = r[t];
        r[t] = k;
        size[t] = size[k];
        size[k] = size[l[k]] + size[r[k]] + cnt[k];
        k = t;
    }
    void insert(int& k, int x)
    {
        if (!k) { // append to the end
            sz++;
            k = sz;
            val[k] = x;
            cnt[k] = 1;
            size[k] = 1;
            rnd[k] = rand();
            return;
        }
        size[k]++;
        if (val[k] == x) {
            cnt[k]++;
        } else if (val[k] < x) {
            insert(r[k], x);
            if (rnd[r[k]] < rnd[k])
                lrotate(k);
        } else {
            insert(l[k], x);
            if (rnd[l[k]] < rnd[k])
                rrotate(k);
        }
    }

    bool del(int& k, int x)
    {
        if (!k)
            return false;
        if (val[k] == x) {
            if (cnt[k] > 1) {
                cnt[k]--;
                size[k]--;
                return true;
            }
            if (l[k] == 0 || r[k] == 0) { // 元素已调整到链条结点或叶节点
                k = l[k] + r[k];
                return true;
            } else if (rnd[l[k]] < rnd[r[k]]) {
                rrotate(k);
                return del(k, x);
            } else {
                lrotate(k);
                return del(k, x);
            }
        } else if (val[k] < x) {
            bool succ = del(r[k], x);
            if (succ)
                size[k]--;
            return succ;
        } else {
            bool succ = del(l[k], x);
            if (succ)
                size[k]--;
            return succ;
        } // 先把结点转到能删除的位置再删除
    }

    int delall(int& k, int x)
    {
        if (!k)
            return 0;
        if (val[k] == x) {
            if (l[k] == 0 || r[k] == 0) {  // 元素已调整到链条结点或叶节点
                int diff = cnt[k];
                k = l[k] + r[k];
                return diff;
            } else if (rnd[l[k]] < rnd[r[k]]) {
                rrotate(k);
                return delall(k, x);
            } else {
                lrotate(k);
                return delall(k, x);
            }
        } else if (val[k] < x) {
            int diff = delall(r[k], x);
            size[k] -= diff;
            return diff;
        } else {
            int diff = delall(l[k], x);
            size[k] -= diff;
            return diff;
        } // 先把结点转到能删除的位置再删除
    }

    int find(int k, int x) {
        if (!k) return 0;
        if (val[k] == x) {
            return k;
        } else if (x < val[k]) {
            return find(l[k], x);
        } else {
            return find(r[k], x);
        }
    }
    int count(int k, int x) {
        k = find(k, x);
        if (!k) return 0;
        return cnt[k];
    }

    // 查元素排名:x是第几小的数
    int queryrank(int k, int x)
    {
        if (!k)
            return 0;
        if (val[k] == x)
            return size[l[k]] + 1;
        else if (x > val[k]) {
            return size[l[k]] + cnt[k] + queryrank(r[k], x);
        } else
            return queryrank(l[k], x);
    }
    // 查排名元素:第x小数
    int querynum(int k, int x)
    {
        if (!k)
            return 0; // 返回空
        if (x <= size[l[k]])
            return querynum(l[k], x);
        else if (x > size[l[k]] + cnt[k])
            return querynum(r[k], x - size[l[k]] - cnt[k]);
        else
            return val[k];
    }
    // 查前驱:刚好比x小的元素
    void querypre(int k, int x)
    {
        if (!k)
            return;
        if (val[k] < x)
            ans = k, querypre(r[k], x);
        else
            querypre(l[k], x);
    }
    // 查后继:刚好比x大的元素
    void querysub(int k, int x)
    {
        if (!k)
            return;
        if (val[k] > x)
            ans = k, querysub(l[k], x);
        else
            querysub(r[k], x);
    }
} T;

int main()
{
    srand(123);
    int n;
    scanf("%d", &n);
    int opt, x;
    for (int i = 1; i <= n; i++) {
        scanf("%d%d", &opt, &x);
        switch (opt) {
        case 1:
            T.insert(T.rt, x);
            break;
        case 2:
            printf("del %d is %d\n", x, T.del(T.rt, x));
            break;
        case 3:
            printf("delall %d count %d\n", x, T.delall(T.rt, x));
            break;
        case 4:
            printf("rank of %d is %d\n", x, T.queryrank(T.rt, x));
            break;
        case 5:
            printf("value of rank %d is %d\n", x, T.querynum(T.rt, x));
            break;
        case 6:
            T.ans = 0;
            T.querypre(T.rt, x);
            printf("previous value of %d is %d\n", x, T.val[T.ans]);
            break;
        case 7:
            T.ans = 0;
            T.querysub(T.rt, x);
            printf("successor value of %d is %d\n", x, T.val[T.ans]);
            break;
        default:
            printf("invalid opt %d\n", opt);
        }
    }
    return 0;
}

注意,这份代码也并不完美,存在一个问题:
添加元素的时候在数组末尾添加,删除元素并不会回收所占数组位置,形成一个个“空洞”,造成空间浪费。
每次删除的时候,都主动用最后一个元素来填补空洞是不可接受的,这会让 delete 操作的时间复杂度上升到 O(N);可以接收的解决办法是记录空洞个数,当“”空洞率”达到一定阈值之后启动整理,一次性压缩所有空洞。
当然,写题的时候是不需要考虑这么多的。

参考代码2:指针版

指针版解决了内存泄漏的问题,经过设计也并不复杂。这里首先参考了 程序员小灰的文章,据说投稿人只有13岁,值得鼓励。当然,也许小灰并没有对投稿进行过审核就在各大平台发了出来,代码中问题多多,包括但不限于:

  • left_rotate 给的是 right_rotate 的代码;
  • lson 和 rson 写成数组形式时,仍然出现了 rson 的字眼;
  • query_rank 和 query_value 都没有针对非法值的处理,运行时会崩溃;
  • query_rank 和 query_value 的结果竟然对应不上;
  • query_value 写错了左右子孩子;
  • query_value 非递归版进左子树的条件少了一个等于号;
  • ...

可以看到,一些错误编译的时候就会出现,一些错误运行的时候会崩掉,非常明显。可以下结论,小灰根本没有运行过投稿人给出的代码,这样搞着实不大行啊……

指针实现的关键点,同样是引用的使用。具体来讲就是 typedef Node* Tree;,这让代码可读性明显提升。

以下是修正过的代码:

#include <cstdio>
#include <algorithm>

using namespace std;

#define Inf 0x3f3f3f3f

typedef struct Node {
    Node(int v) {
        val = v, cnt = size = 1, fac = rand();
        lson = rson = nullptr;
    }
    //  值  个数  子树大小 优先级
    int val, cnt, size, fac;
    Node *lson, *rson;
    // 更新当前子树大小
    inline void push_up() {
        size = cnt;
        if (lson != nullptr) size += lson->size;
        if (rson != nullptr) size += rson->size;
    }
}* Tree;

inline int size(Tree t) { return t == nullptr ? 0 : t->size; }

inline void right_rotate(Tree &a) {
    Tree b = a->lson;
    a->lson = b->rson, b->rson = a, a = b;
    a->rson->push_up(), a->push_up();
}

inline void left_rotate(Tree &a) {
    Tree b = a->rson;
    a->rson = b->lson, b->lson = a, a = b;
    a->lson->push_up(), a->push_up();
}

// // 也可以将lson和rson写成一个数组son,然后将左旋和右旋写成一个函数:
// // 注意左旋、右旋中的代码传的参数a需要传引用,因为最后a也要更新
// inline void rotate(Tree &a, int f) {
//     Tree b = a->son[f^1];
//     a->son[f^1] = b->son[f], b->son[f] = a, a = b;
//     a->son[f]->push_up(), a->push_up();
// }

inline void insert(Tree &rt, int val) {
    if (rt == nullptr) {
        rt = new Node(val);
        return;
    }
    if (val == rt->val) {
        rt->cnt++; // 已经有这个点了
    } else if (val < rt->val) {
        insert(rt->lson, val);
        if (rt->fac < rt->lson->fac) right_rotate(rt);
    } else {
        insert(rt->rson, val);
        if (rt->fac < rt->rson->fac) left_rotate(rt);
    }
    rt->push_up();
}


inline void del(Tree &rt, int val) {
    if (rt == nullptr) return; // 没找到
    if (val == rt->val) {
        if (rt->cnt > 1) {
            rt->cnt--, rt->push_up();
            return;
        }
        if (rt->lson == nullptr && rt->rson == nullptr) {
            delete rt; rt = nullptr;
            return;
        } // 叶结点
        else {
            if (rt->rson == nullptr || (rt->lson != nullptr && rt->lson->fac > rt->rson->fac)) {
                right_rotate(rt), del(rt->rson, val);
            } // 右子树小,右旋删除
            else {
                left_rotate(rt), del(rt->lson, val);
            } // 左子树小,左旋删除
        }
    } 
    else if (val < rt->val) del(rt->lson, val);
    else del(rt->rson, val);
    rt->push_up();
}

// 询问有多少个数小于等于val(也就相当于查询排名)
inline int query_rank(Tree p, int val) {
    int rank = 1; // 有效值的排名,从1开始
    while (p != nullptr) {
        if (val == p->val) return rank + size(p->lson);
        else if (val < p->val) p = p->lson;
        else rank += size(p->lson) + p->cnt, p = p->rson;
    }
    // return rank;
    return 0; // 排名为0代表无效值
}

// // query_rank 的递归实现有一个缺点,那就是对于无效值,其返回值并不固定,无法据此判断当前值是否存在于这棵树之中
// inline int query_rank(Tree p, int val) {
//     if (p == nullptr) return 0; // 排名为0代表无效值
//     if (val == p->val) return 1 + size(p->lson); // 有效值的排名,从1开始
//     if (val < p->val) return query_rank(p->lson, val);
//     return size(p->lson) + p->cnt + query_rank(p->rson, val);
// }

#define INVALID_RANK 0x7f7f7f7f

// // query_value 递归和非递归实现,均能正常工作
// inline int query_value(Tree p, int rank) {
//     // if (rank < 0 || rank > size(p)) return INVALID_RANK;
//     while (p != nullptr && rank) {
//         if (rank <= size(p->lson)) p = p->lson;
//         else if (rank <= size(p->lson) + p->cnt) return p->val;
//         else rank -= size(p->lson) + p->cnt, p = p->rson;
//     }
//     return INVALID_RANK;
// }

inline int query_value(Tree p, int rank) {
    if (p == nullptr) return INVALID_RANK; // printf("rank %d not exist!\n", rank); 
    if (rank <= size(p->lson)) return query_value(p->lson, rank);
    if (rank <= size(p->lson) + p->cnt) return p->val;
    return query_value(p->rson, rank - size(p->lson) - p->cnt);
}

// 返回比val小的最大的有效值
inline int query_prev(Tree p, int val) {
    int pre = -Inf;
    while (p != nullptr) {
        if (p->val < val) pre = p->val, p = p->rson;
        else p = p->lson;
    }
    return pre;
}

// 返回比val大的最小的有效值
inline int query_next(Tree p, int val) {
    int suf = Inf;
    while (p != nullptr) {
        if (p->val > val) suf = p->val, p = p->lson;
        else p = p->rson;
    }
    return suf;
}


// ----------------------------------------------------------------------------

int main()
{
    srand(123);
    int n;
    scanf("%d", &n);
    Tree rt = nullptr; // 必须初始化,不能出现悬空指针
    int opt, x;
    for (int i = 1; i <= n; i++) {
        scanf("%d%d", &opt, &x);
        switch (opt) {
        case 1:
            insert(rt, x);
            break;
        case 2:
            del(rt, x);
            // printf("del %d is %d\n", x, del(rt, x));
            break;
        case 3:
            // printf("delall %d count %d\n", x, delall(rt, x));
            break;
        case 4:
            printf("rank of %d is %d\n", x, query_rank(rt, x));
            break;
        case 5:
            printf("value of rank %d is %d\n", x, query_value(rt, x));
            break;
        case 6:
            printf("previous value of %d is %d\n", x, query_prev(rt, x));
            break;
        case 7:
            printf("successor value of %d is %d\n", x, query_next(rt, x));
            break;
        default:
            printf("invalid opt %d\n", opt);
        }
    }
    return 0;
}

参考文献

posted @ 2021-05-08 21:52  与MPI做斗争  阅读(191)  评论(0编辑  收藏  举报