现代C++实现AVL树
现代C++实现AVL树
性质
- 空树是一个AVL树。
- 如果T是一个AVL树,那么其左右子树也是一个AVL树,且\(|h(ls)-h(rs)|\le1\),h是其左右子树的高度
- 树高为\(O(\log {n})\)
平衡因子:左子树高度 - 右子树高度
原理
插入操作
与 BST(二叉搜索树)中类似,先进行一次失败的查找来确定插入的位置,插入节点后根据平衡因子来决定是否需要调整.
删除操作
删除和 BST 类似,将结点与后继交换后再删除.
删除会导致树高以及平衡因子变化,这时需要沿着被删除结点到根的路径来调整这种变化.
维护
即保证平衡,也是AVL的难点。
注意我们的操作必须保证中序遍历顺序不变,以维持BST的性质。
旋转
懒得画图,其实这个操作非常简单,难点在于分类讨论。简单说就是将高度高了的儿子放到上面,根据保证中序遍历顺序不变的原则,考虑一下应该怎么移动。
右旋
A B
/ / \
B D A
/ \ /
D E E
显然我们可以发现一个问题,就是万一没有D呢?
A B
/ \
B A
\ /
E E
先左旋后右旋
我们希望的是它得有左儿子,注意到可以通过一次左旋做到。
A A E
/ / / \
B E B A
\ /
E B
还有两种对称的情况,不做过多解说。
总结
| 失衡节点的平衡因子 | 相应子节点的平衡因子 | 应采用的旋转方法 |
|---|---|---|
| + | 非负(说明有左儿子) | 右旋 |
| + | 负 | 先左旋后右旋 |
| - | 非正 | 左旋 |
| - | 正 | 先右旋后左旋 |
代码实现
比较困难的地方有多重集的删除处理,我们需要先找到真正删除的为位置,再rebbalance。
写得很构思,但意外得跑得不算很慢,
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
// #include <stdexcept>
#include <vector>
template <typename Key, typename Compare = std::less<Key>>
class AVL {
private:
struct Node {
Key data_;
std::unique_ptr<Node> left_, right_;
Node* parent_;
size_t count_;
size_t height_;
size_t size_;
Node(const Key& data, Node* p = nullptr)
: data_(data),
left_(nullptr),
right_(nullptr),
parent_(p),
count_(1),
height_(1),
size_(1) {}
};
std::unique_ptr<Node> root_;
Compare cmp_;
// 中序遍历
static void inorder_impl(const std::unique_ptr<Node>& n,
const std::function<void(const Key&)>& f) {
if (!n) return;
inorder_impl(n->left_, f);
f(n->data_);
inorder_impl(n->right_, f);
}
static size_t height(Node* n) {
if (n) {
return n->height_;
} else {
return 0;
}
}
static void update_height(Node* n) {
if (n) n->height_ = 1 + std::max(height(n->left_.get()), height(n->right_.get()));
}
static void update_Info(Node* n) {
if (!n) return;
n->height_ = 1 + std::max(height(n->left_.get()), height(n->right_.get()));
n->size_ =
n->count_ + (n->left_ ? n->left_->size_ : 0) + (n->right_ ? n->right_->size_ : 0);
}
static int balance_factor(Node* n) {
if (n) {
return (int)height(n->left_.get()) - (int)height(n->right_.get());
} else {
return 0;
}
}
// 左旋
void rotate_left(std::unique_ptr<Node>& rootRef) {
auto r = std::move(rootRef->right_);
r->parent_ = rootRef->parent_;
rootRef->right_ = std::move(r->left_);
if (rootRef->right_) rootRef->right_->parent_ = rootRef.get();
rootRef->parent_ = r.get();
r->left_ = std::move(rootRef);
rootRef = std::move(r);
update_Info(rootRef->left_.get());
update_Info(rootRef.get());
}
// 右旋
void rotate_right(std::unique_ptr<Node>& rootRef) {
auto r = std::move(rootRef->left_);
r->parent_ = rootRef->parent_;
rootRef->left_ = std::move(r->right_);
if (rootRef->left_) rootRef->left_->parent_ = rootRef.get();
rootRef->parent_ = r.get();
r->right_ = std::move(rootRef);
rootRef = std::move(r);
update_Info(rootRef->right_.get());
update_Info(rootRef.get());
}
// 维护平衡树
void rebalanceNode(std::unique_ptr<Node>& nodeRef) {
if (!nodeRef) return;
update_Info(nodeRef.get());
int bf = balance_factor(nodeRef.get());
if (bf > 1) { // 左偏
update_height(nodeRef->left_.get()); // 更新高度
int leftBf = balance_factor(nodeRef->left_.get());
if (leftBf < 0) { // 左儿子没有左儿子,左儿子左旋
rotate_left(nodeRef->left_);
}
rotate_right(nodeRef);
} else if (bf < -1) { // 右偏
update_height(nodeRef->right_.get()); // 更新高度
int rightBf = balance_factor(nodeRef->right_.get());
if (rightBf > 0) { // 右儿子没有右儿子,右儿子右旋
rotate_right(nodeRef->right_);
}
rotate_left(nodeRef);
}
}
public:
AVL(Compare cmp = Compare()) : root_(nullptr), cmp_(cmp) {}
// 插入
void insert(const Key& data) {
std::unique_ptr<Node>* p = &root_;
std::vector<std::unique_ptr<Node>*> path;
path.reserve(64);
Node* parentNode = nullptr;
while (*p) {
path.push_back(p);
if (cmp_(data, (*p)->data_)) {
parentNode = (*p).get();
p = &((*p)->left_);
} else if (cmp_((*p)->data_, data)) {
parentNode = (*p).get();
p = &((*p)->right_);
} else {
(*p)->count_++;
for (auto it = path.rbegin(); it != path.rend(); ++it) rebalanceNode(*(*it));
return;
}
}
*p = std::make_unique<Node>(data, parentNode);
for (auto it = path.rbegin(); it != path.rend(); ++it) rebalanceNode(*(*it));
}
// 删除
void erase(const Key& data) {
std::unique_ptr<Node>* p = &root_;
std::vector<std::unique_ptr<Node>*> path;
path.reserve(64);
while (*p) {
path.push_back(p);
if (cmp_(data, (*p)->data_)) {
p = &((*p)->left_);
} else if (cmp_((*p)->data_, data)) {
p = &((*p)->right_);
} else {
break; // 找到节点
}
}
if (!*p) return;
// 如果存在多个相同元素
if ((*p)->count_ > 1) {
(*p)->count_--;
for (auto it = path.rbegin(); it != path.rend(); ++it) update_Info((*it)->get());
return;
}
std::unique_ptr<Node>* target = p;
if ((*target)->left_ && (*target)->right_) {
std::unique_ptr<Node>* cur = &((*target)->right_);
while ((*cur)->left_) {
path.push_back(cur);
cur = &((*cur)->left_);
}
// 移动
(*target)->data_ = (*cur)->data_;
(*target)->count_ = (*cur)->count_;
// 真正要删除的目标
target = cur;
}
if (!(*target)->left_ && !(*target)->right_) {
target->reset();
} else if (!(*target)->left_) {
auto r = std::move((*target)->right_);
r->parent_ = (*target)->parent_;
*target = std::move(r);
} else {
auto l = std::move((*target)->left_);
l->parent_ = (*target)->parent_;
*target = std::move(l);
}
for (auto it = path.rbegin(); it != path.rend(); ++it) rebalanceNode(*(*it));
}
// 中序遍历接口
void inorder(const std::function<void(const Key&)>& f) const { inorder_impl(root_, f); }
void print_inorder() const {
inorder([](const Key& k) { std::cout << k << ' '; });
std::cout << '\n';
}
// 比data小的数
int count_less(const Key& data) const {
auto p = root_.get();
size_t count = 0;
while (p) {
if (cmp_(data, p->data_)) {
p = (p->left_).get();
} else if (cmp_(p->data_, data)) {
count += p->count_;
if (p->left_) count += (p->left_)->size_;
p = (p->right_).get();
} else {
if (p->left_) count += (p->left_)->size_;
return count;
}
}
return count;
}
size_t rank(const Key& data) const { return 1 + count_less(data); }
std::optional<Key> kth(size_t k) const {
if (k <= 0 || !root_ || k > root_->size_) return std::nullopt;
auto p = root_.get();
if (k > p->size_) return std::nullopt;
while (p) {
size_t left_size = p->left_ ? p->left_->size_ : 0;
if (left_size >= k) {
p = p->left_.get();
} else if ((left_size + p->count_) >= k) {
return p->data_;
} else {
k -= (left_size + p->count_);
p = p->right_.get();
}
}
return std::nullopt;
}
// 最大的小于x的数
std::optional<Key> predecessor(const Key& x) const {
auto p = root_.get();
std::optional<Key> ans = std::nullopt;
while (p) {
if (cmp_(p->data_, x)) {
ans = p->data_;
p = (p->right_).get();
} else {
p = (p->left_).get();
}
}
return ans;
}
// 最小的大于x的数
std::optional<Key> successor(const Key& x) const {
auto p = root_.get();
std::optional<Key> ans = std::nullopt;
while (p) {
if (cmp_(x, p->data_)) {
ans = p->data_;
p = (p->left_).get();
} else {
p = (p->right_).get();
}
}
return ans;
}
};

浙公网安备 33010602011771号