现代C++实现AVL树

现代C++实现AVL树

性质

  1. 空树是一个AVL树。
  2. 如果T是一个AVL树,那么其左右子树也是一个AVL树,且\(|h(ls)-h(rs)|\le1\),h是其左右子树的高度
  3. 树高为\(O(\log {n})\)

平衡因子:左子树高度 - 右子树高度

原理

插入操作

与 BST(二叉搜索树)中类似,先进行一次失败的查找来确定插入的位置,插入节点后根据平衡因子来决定是否需要调整.

删除操作

删除和 BST 类似,将结点与后继交换后再删除.

删除会导致树高以及平衡因子变化,这时需要沿着被删除结点到根的路径来调整这种变化.

维护

即保证平衡,也是AVL的难点。

注意我们的操作必须保证中序遍历顺序不变,以维持BST的性质。

旋转

懒得画图,其实这个操作非常简单,难点在于分类讨论。简单说就是将高度高了的儿子放到上面,根据保证中序遍历顺序不变的原则,考虑一下应该怎么移动。

右旋

    A         B
   /         / \
  B         D   A
 / \           /
D   E         E

显然我们可以发现一个问题,就是万一没有D呢?

    A         B             
   /           \
  B             A
   \           /
    E         E

先左旋后右旋

我们希望的是它得有左儿子,注意到可以通过一次左旋做到。

    A         A      E            
   /         /      / \
  B         E      B   A
   \       /
    E     B   

还有两种对称的情况,不做过多解说。

总结

失衡节点的平衡因子 相应子节点的平衡因子 应采用的旋转方法
+ 非负(说明有左儿子) 右旋
+ 先左旋后右旋
- 非正 左旋
- 先右旋后左旋

代码实现

比较困难的地方有多重集的删除处理,我们需要先找到真正删除的为位置,再rebbalance。

写得很构思,但意外得跑得不算很慢,

#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
// #include <stdexcept>
#include <vector>

template <typename Key, typename Compare = std::less<Key>>
class AVL {
   private:
    struct Node {
        Key data_;
        std::unique_ptr<Node> left_, right_;
        Node* parent_;
        size_t count_;
        size_t height_;
        size_t size_;

        Node(const Key& data, Node* p = nullptr)
            : data_(data),
              left_(nullptr),
              right_(nullptr),
              parent_(p),
              count_(1),
              height_(1),
              size_(1) {}
    };

    std::unique_ptr<Node> root_;
    Compare cmp_;

    // 中序遍历
    static void inorder_impl(const std::unique_ptr<Node>& n,
                             const std::function<void(const Key&)>& f) {
        if (!n) return;
        inorder_impl(n->left_, f);
        f(n->data_);
        inorder_impl(n->right_, f);
    }

    static size_t height(Node* n) {
        if (n) {
            return n->height_;
        } else {
            return 0;
        }
    }
    static void update_height(Node* n) {
        if (n) n->height_ = 1 + std::max(height(n->left_.get()), height(n->right_.get()));
    }

    static void update_Info(Node* n) {
        if (!n) return;
        n->height_ = 1 + std::max(height(n->left_.get()), height(n->right_.get()));
        n->size_ =
            n->count_ + (n->left_ ? n->left_->size_ : 0) + (n->right_ ? n->right_->size_ : 0);
    }

    static int balance_factor(Node* n) {
        if (n) {
            return (int)height(n->left_.get()) - (int)height(n->right_.get());
        } else {
            return 0;
        }
    }
    // 左旋
    void rotate_left(std::unique_ptr<Node>& rootRef) {
        auto r = std::move(rootRef->right_);
        r->parent_ = rootRef->parent_;
        rootRef->right_ = std::move(r->left_);
        if (rootRef->right_) rootRef->right_->parent_ = rootRef.get();
        rootRef->parent_ = r.get();
        r->left_ = std::move(rootRef);
        rootRef = std::move(r);

        update_Info(rootRef->left_.get());
        update_Info(rootRef.get());
    }

    // 右旋
    void rotate_right(std::unique_ptr<Node>& rootRef) {
        auto r = std::move(rootRef->left_);
        r->parent_ = rootRef->parent_;
        rootRef->left_ = std::move(r->right_);
        if (rootRef->left_) rootRef->left_->parent_ = rootRef.get();
        rootRef->parent_ = r.get();
        r->right_ = std::move(rootRef);
        rootRef = std::move(r);

        update_Info(rootRef->right_.get());
        update_Info(rootRef.get());
    }

    // 维护平衡树
    void rebalanceNode(std::unique_ptr<Node>& nodeRef) {
        if (!nodeRef) return;
        update_Info(nodeRef.get());
        int bf = balance_factor(nodeRef.get());

        if (bf > 1) {                             // 左偏
            update_height(nodeRef->left_.get());  // 更新高度
            int leftBf = balance_factor(nodeRef->left_.get());
            if (leftBf < 0) {  // 左儿子没有左儿子,左儿子左旋
                rotate_left(nodeRef->left_);
            }
            rotate_right(nodeRef);
        } else if (bf < -1) {                      // 右偏
            update_height(nodeRef->right_.get());  // 更新高度
            int rightBf = balance_factor(nodeRef->right_.get());
            if (rightBf > 0) {  // 右儿子没有右儿子,右儿子右旋
                rotate_right(nodeRef->right_);
            }
            rotate_left(nodeRef);
        }
    }

   public:
    AVL(Compare cmp = Compare()) : root_(nullptr), cmp_(cmp) {}
    // 插入
    void insert(const Key& data) {
        std::unique_ptr<Node>* p = &root_;
        std::vector<std::unique_ptr<Node>*> path;
        path.reserve(64);
        Node* parentNode = nullptr;

        while (*p) {
            path.push_back(p);
            if (cmp_(data, (*p)->data_)) {
                parentNode = (*p).get();
                p = &((*p)->left_);
            } else if (cmp_((*p)->data_, data)) {
                parentNode = (*p).get();
                p = &((*p)->right_);
            } else {
                (*p)->count_++;
                for (auto it = path.rbegin(); it != path.rend(); ++it) rebalanceNode(*(*it));
                return;
            }
        }

        *p = std::make_unique<Node>(data, parentNode);

        for (auto it = path.rbegin(); it != path.rend(); ++it) rebalanceNode(*(*it));
    }

    // 删除
    void erase(const Key& data) {
        std::unique_ptr<Node>* p = &root_;
        std::vector<std::unique_ptr<Node>*> path;
        path.reserve(64);
        while (*p) {
            path.push_back(p);
            if (cmp_(data, (*p)->data_)) {
                p = &((*p)->left_);
            } else if (cmp_((*p)->data_, data)) {
                p = &((*p)->right_);
            } else {
                break;  // 找到节点
            }
        }
        if (!*p) return;

        // 如果存在多个相同元素
        if ((*p)->count_ > 1) {
            (*p)->count_--;
            for (auto it = path.rbegin(); it != path.rend(); ++it) update_Info((*it)->get());
            return;
        }

        std::unique_ptr<Node>* target = p;
        if ((*target)->left_ && (*target)->right_) {
            std::unique_ptr<Node>* cur = &((*target)->right_);
            while ((*cur)->left_) {
                path.push_back(cur);
                cur = &((*cur)->left_);
            }

            // 移动
            (*target)->data_ = (*cur)->data_;
            (*target)->count_ = (*cur)->count_;

            // 真正要删除的目标
            target = cur;
        }

        if (!(*target)->left_ && !(*target)->right_) {
            target->reset();
        } else if (!(*target)->left_) {
            auto r = std::move((*target)->right_);
            r->parent_ = (*target)->parent_;
            *target = std::move(r);
        } else {
            auto l = std::move((*target)->left_);
            l->parent_ = (*target)->parent_;
            *target = std::move(l);
        }
        for (auto it = path.rbegin(); it != path.rend(); ++it) rebalanceNode(*(*it));
    }

    // 中序遍历接口
    void inorder(const std::function<void(const Key&)>& f) const { inorder_impl(root_, f); }

    void print_inorder() const {
        inorder([](const Key& k) { std::cout << k << ' '; });
        std::cout << '\n';
    }

    // 比data小的数
    int count_less(const Key& data) const {
        auto p = root_.get();
        size_t count = 0;

        while (p) {
            if (cmp_(data, p->data_)) {
                p = (p->left_).get();
            } else if (cmp_(p->data_, data)) {
                count += p->count_;
                if (p->left_) count += (p->left_)->size_;
                p = (p->right_).get();
            } else {
                if (p->left_) count += (p->left_)->size_;
                return count;
            }
        }
        return count;
    }

    size_t rank(const Key& data) const { return 1 + count_less(data); }

    std::optional<Key> kth(size_t k) const {
        if (k <= 0 || !root_ || k > root_->size_) return std::nullopt;
        auto p = root_.get();
        if (k > p->size_) return std::nullopt;
        while (p) {
            size_t left_size = p->left_ ? p->left_->size_ : 0;
            if (left_size >= k) {
                p = p->left_.get();
            } else if ((left_size + p->count_) >= k) {
                return p->data_;
            } else {
                k -= (left_size + p->count_);
                p = p->right_.get();
            }
        }
        return std::nullopt;
    }
    // 最大的小于x的数
    std::optional<Key> predecessor(const Key& x) const {
        auto p = root_.get();
        std::optional<Key> ans = std::nullopt;
        while (p) {
            if (cmp_(p->data_, x)) {
                ans = p->data_;
                p = (p->right_).get();
            } else {
                p = (p->left_).get();
            }
        }
        return ans;
    }

    // 最小的大于x的数
    std::optional<Key> successor(const Key& x) const {
        auto p = root_.get();
        std::optional<Key> ans = std::nullopt;
        while (p) {
            if (cmp_(x, p->data_)) {
                ans = p->data_;
                p = (p->left_).get();
            } else {
                p = (p->right_).get();
            }
        }
        return ans;
    }
};
posted @ 2026-02-13 22:30  _lull  阅读(1)  评论(0)    收藏  举报