3-3 AVL Tree (Self-Balancing Binary Search Tree)

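An AVL tree is a self-balancing binary search tree proposed by Adelson-Velsky and Landis in 1962. In an AVL tree, the heights of the left and right subtrees of every node differ by at most 1 (this difference is called the balance factor), which guarantees O(log n) worst-case time for search, insertion, and deletion.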

Core ideas of the AVL tree:

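  • After every insertion or deletion, check for and repair any imbalance in the tree.
  • Restore balance using rotations.
  • Unlike an ordinary binary search tree (BST), an AVL tree stays balanced at all times and never degenerates into a linked list.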

Balance Factor and Rotations

Balance Factor

The balance factor of a node is the height of its left subtree minus the height of its right subtree:

balanceFactor(node) = height(left subtree) - height(right subtree)

An AVL tree requires every node's balance factor to be in {-1, 0, 1}:

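  • Balance factor = 0: both subtrees have the same height; perfectly balanced.
  • Balance factor = 1: the left subtree is one level taller; slightly left-leaning.
  • Balance factor = -1: the right subtree is one level taller; slightly right-leaning.
  • Balance factor > 1 or < -1: unbalanced; must be repaired with rotations.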
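The definition can be checked directly with a small recursive sketch. It recomputes heights from scratch instead of reading a stored height field, so each call costs O(n) time; that cost is why the implementations below cache a height in every node. The `Node` class here (with optional children in the constructor) and the names `height`, `balance_factor`, and `is_avl` are illustrative assumptions, not part of the code later in this section.

```python
# Sketch: balance factor computed purely from the definition,
# with heights recomputed recursively (no stored height field).

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def height(node):
    """Height of a subtree: an empty subtree is 0, a leaf is 1."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def balance_factor(node):
    """height(left subtree) - height(right subtree)."""
    if node is None:
        return 0
    return height(node.left) - height(node.right)


def is_avl(node):
    """True if every node's balance factor lies in {-1, 0, 1}."""
    if node is None:
        return True
    return (abs(balance_factor(node)) <= 1
            and is_avl(node.left) and is_avl(node.right))


# 20 with children 10 and 30, where 30 has a left child 25
balanced = Node(20, Node(10), Node(30, Node(25)))
# right-leaning chain 10 -> 20 -> 30
chain = Node(10, right=Node(20, right=Node(30)))

print(balance_factor(balanced), is_avl(balanced))  # -1 True
print(balance_factor(chain), is_avl(chain))        # -2 False
```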

Four Types of Rotation

When an insertion or deletion pushes some node's balance factor out of range, rotations restore balance. There are four cases:

1. Right rotation (LL case)

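The node is left-heavy, and its left child is also left-heavy (the extra height is in the left child's left subtree).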

     z                y
    / \              / \
   y   T4    =>    x     z
  / \             / \   / \
 x   T3         T1  T2 T3 T4
/ \
T1 T2

2. Left rotation (RR case)

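The node is right-heavy, and its right child is also right-heavy (the extra height is in the right child's right subtree).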

  z                  y
 / \                / \
T1   y      =>    z     x
    / \          / \   / \
   T2   x      T1  T2 T3 T4
       / \
      T3  T4

3. Left rotation followed by right rotation (LR case)

The node is left-heavy, but its left child is right-heavy. First rotate the left child left, then rotate the node right.

     z              z           x
    / \            / \         / \
   y   T4  =>    x   T4 =>  y     z
  / \           / \         / \   / \
 T1   x       y   T3      T1 T2 T3 T4
     / \     / \
    T2 T3   T1  T2

4. Right rotation followed by left rotation (RL case)

The node is right-heavy, but its right child is left-heavy. First rotate the right child right, then rotate the node left.

  z            z               x
 / \          / \             / \
T1   y  =>  T1   x     =>  z     y
    / \         / \        / \   / \
   x   T4     T2   y     T1 T2 T3 T4
  / \              / \
 T2  T3           T3  T4
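Which of the four cases applies can be decided from the unbalanced node's balance factor together with the balance factor of its taller child. The sketch below assumes that rule. It is the same child-based test that the deletion code later in this section uses; the insertion code instead compares the newly inserted value. `rotation_case` and the helper names are hypothetical.

```python
# Sketch: classify the imbalance at node z as LL, LR, RR, or RL.
# Heights are recomputed recursively for simplicity (no stored height).

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def height(node):
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def balance(node):
    if node is None:
        return 0
    return height(node.left) - height(node.right)


def rotation_case(z):
    """Return 'LL', 'LR', 'RR' or 'RL' for an unbalanced z, else None."""
    bf = balance(z)
    if bf > 1:
        # left side too tall: does the left child lean left or right?
        return "LL" if balance(z.left) >= 0 else "LR"
    if bf < -1:
        # right side too tall: does the right child lean right or left?
        return "RR" if balance(z.right) <= 0 else "RL"
    return None  # balance factor already in {-1, 0, 1}


# Three-node shapes matching the four diagrams above
print(rotation_case(Node(30, Node(20, Node(10)))))              # LL
print(rotation_case(Node(10, right=Node(20, right=Node(30)))))  # RR
print(rotation_case(Node(30, Node(10, right=Node(20)))))        # LR
print(rotation_case(Node(10, right=Node(30, Node(20)))))        # RL
```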

Node Definition

An AVL tree node extends the ordinary BST node with a height field, so that a node's balance factor can be computed in constant time.

C++ Implementation

struct Node 
{
    int data;
    int height;
    Node* left;
    Node* right;

    // constructor
    Node(int val) : data(val), height(1), left(nullptr), right(nullptr) {}
};

C Implementation

#include <stdlib.h>

typedef struct Node 
{
    int data;
    int height;
    struct Node* left;
    struct Node* right;
} Node;

// create a new node with given value
Node* createNode(int val) 
{
    Node* node = (Node*)malloc(sizeof(Node));
    node->data = val;
    node->height = 1;
    node->left = NULL;
    node->right = NULL;
    return node;
}

Python Implementation

class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None

Go Implementation

type Node struct {
    Data   int
    Height int
    Left   *Node
    Right  *Node
}

// create a new node with given value
func NewNode(val int) *Node {
    return &Node{Data: val, Height: 1, Left: nil, Right: nil}
}

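Each node stores four pieces of information: data is the node's value, height is the height of the subtree rooted at the node (a leaf has height 1), and left and right point to the left and right children.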


Helper Functions

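Before implementing rotations and rebalancing, we define three helper functions: one returns a node's height, one computes its balance factor, and one updates its height.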

C++ Implementation

// get height of a node (0 for nullptr)
int getHeight(Node* node) 
{
    if (node == nullptr) 
    {
        return 0;
    }
    return node->height;
}

// calculate balance factor of a node
int getBalance(Node* node) 
{
    if (node == nullptr) 
    {
        return 0;
    }
    return getHeight(node->left) - getHeight(node->right);
}

// update height of a node based on children
void updateHeight(Node* node) 
{
    if (node == nullptr) 
    {
        return;
    }
    int leftHeight = getHeight(node->left);
    int rightHeight = getHeight(node->right);
    node->height = 1 + (leftHeight > rightHeight ? leftHeight : rightHeight);
}

C Implementation

// get height of a node (0 for NULL)
int getHeight(Node* node) 
{
    if (node == NULL) 
    {
        return 0;
    }
    return node->height;
}

// calculate balance factor of a node
int getBalance(Node* node) 
{
    if (node == NULL) 
    {
        return 0;
    }
    return getHeight(node->left) - getHeight(node->right);
}

// update height of a node based on children
void updateHeight(Node* node) 
{
    if (node == NULL) 
    {
        return;
    }
    int leftHeight = getHeight(node->left);
    int rightHeight = getHeight(node->right);
    node->height = 1 + (leftHeight > rightHeight ? leftHeight : rightHeight);
}

Python Implementation

def get_height(node):
    if node is None:
        return 0
    return node.height

def get_balance(node):
    if node is None:
        return 0
    return get_height(node.left) - get_height(node.right)

def update_height(node):
    if node is None:
        return
    node.height = 1 + max(get_height(node.left), get_height(node.right))

Go Implementation

// getHeight returns the height of a node (0 for nil)
func getHeight(node *Node) int {
    if node == nil {
        return 0
    }
    return node.Height
}

// getBalance calculates the balance factor of a node
func getBalance(node *Node) int {
    if node == nil {
        return 0
    }
    return getHeight(node.Left) - getHeight(node.Right)
}

// updateHeight updates the height of a node based on children
func updateHeight(node *Node) {
    if node == nil {
        return
    }
    leftHeight := getHeight(node.Left)
    rightHeight := getHeight(node.Right)
    if leftHeight > rightHeight {
        node.Height = 1 + leftHeight
    } else {
        node.Height = 1 + rightHeight
    }
}
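
  • getHeight returns the height of a node, or 0 for an empty (null) node.
  • getBalance returns the balance factor: left subtree height minus right subtree height.
  • updateHeight recomputes the current node's height from its children: height = 1 + max(leftHeight, rightHeight).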
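Every rotation and insertion relies on the cached height field being correct. For that reason, a checker that recomputes heights and compares them with the stored values is handy while debugging. The sketch below uses the Python helpers from this section unchanged. `check_invariants` is a hypothetical addition, not part of the original code.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None


def get_height(node):
    if node is None:
        return 0
    return node.height


def get_balance(node):
    if node is None:
        return 0
    return get_height(node.left) - get_height(node.right)


def update_height(node):
    if node is None:
        return
    node.height = 1 + max(get_height(node.left), get_height(node.right))


def check_invariants(node):
    """Recompute heights bottom-up; raise ValueError if a stored height is
    stale or a balance factor leaves {-1, 0, 1}. Returns the true height."""
    if node is None:
        return 0
    lh = check_invariants(node.left)
    rh = check_invariants(node.right)
    if node.height != 1 + max(lh, rh):
        raise ValueError(f"stale height at node {node.data}")
    if abs(lh - rh) > 1:
        raise ValueError(f"unbalanced at node {node.data}")
    return node.height


# Build 20 with children 10 and 30; update the parent after its children.
root = Node(20)
root.left = Node(10)
root.right = Node(30)
update_height(root)
print(root.height, get_balance(root))  # 2 0
print(check_invariants(root))          # 2

# Attach 25 below 30 but forget to update heights on the path up.
root.right.left = Node(25)
try:
    check_invariants(root)
except ValueError as err:
    print(err)  # stale height at node 30
```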

Rotations

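Rotations are the core operation an AVL tree uses to restore balance. We implement two basic rotations, right and left. The compound rotations (LR and RL) are built by combining them.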

Right Rotation

A right rotation at node y promotes its left child x to be the new root of the subtree.

     y                x
    / \              / \
   x   T3    =>    T1   y
  / \                  / \
 T1  T2               T2  T3

C++ Implementation

// right rotation around node y
Node* rightRotate(Node* y) 
{
    Node* x = y->left;
    Node* T2 = x->right;

    // perform rotation
    x->right = y;
    y->left = T2;

    // update heights (y first because it is now lower)
    updateHeight(y);
    updateHeight(x);

    return x;
}

C Implementation

// right rotation around node y
Node* rightRotate(Node* y) 
{
    Node* x = y->left;
    Node* T2 = x->right;

    // perform rotation
    x->right = y;
    y->left = T2;

    // update heights (y first because it is now lower)
    updateHeight(y);
    updateHeight(x);

    return x;
}

Python Implementation

def right_rotate(y):
    x = y.left
    T2 = x.right

    # perform rotation
    x.right = y
    y.left = T2

    # update heights (y first because it is now lower)
    update_height(y)
    update_height(x)

    return x

Go Implementation

// rightRotate performs right rotation around node y
func rightRotate(y *Node) *Node {
    x := y.Left
    T2 := x.Right

    // perform rotation
    x.Right = y
    y.Left = T2

    // update heights (y first because it is now lower)
    updateHeight(y)
    updateHeight(x)

    return x
}

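A right rotation promotes y's left child x to be the new subtree root, and x's former right subtree T2 becomes y's left subtree. After the rotation, update y's height first, then x's. This order matters because y is now lower and x's new height depends on it.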
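The sketch below applies `right_rotate` to the LL-shaped chain 30 -> 20 -> 10 and checks two things. First, the new subtree root is 20 with the expected heights. Second, the in-order sequence (the BST ordering) is unchanged. It reuses the Python helpers above in condensed form. `inorder_list` is an illustrative helper that is not part of the original code.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None


def get_height(node):
    return 0 if node is None else node.height


def update_height(node):
    if node is not None:
        node.height = 1 + max(get_height(node.left), get_height(node.right))


def right_rotate(y):
    x = y.left
    T2 = x.right
    x.right = y          # y becomes x's right child
    y.left = T2          # x's old right subtree moves under y
    update_height(y)     # y is now lower, so update it first
    update_height(x)
    return x


def inorder_list(node):
    if node is None:
        return []
    return inorder_list(node.left) + [node.data] + inorder_list(node.right)


# LL-shaped chain: 30 -> left 20 -> left 10, with correct heights
z, y, x = Node(30), Node(20), Node(10)
y.left = x
update_height(y)
z.left = y
update_height(z)

before = inorder_list(z)
root = right_rotate(z)
print(root.data, root.left.data, root.right.data)        # 20 10 30
print(root.height, root.left.height, root.right.height)  # 2 1 1
print(inorder_list(root) == before)                      # True
```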


Left Rotation

A left rotation at node x promotes its right child y to be the new root of the subtree.

   x                  y
  / \                / \
 T1   y      =>    x   T3
     / \          / \
    T2  T3       T1  T2

C++ Implementation

// left rotation around node x
Node* leftRotate(Node* x) 
{
    Node* y = x->right;
    Node* T2 = y->left;

    // perform rotation
    y->left = x;
    x->right = T2;

    // update heights (x first because it is now lower)
    updateHeight(x);
    updateHeight(y);

    return y;
}

C Implementation

// left rotation around node x
Node* leftRotate(Node* x) 
{
    Node* y = x->right;
    Node* T2 = y->left;

    // perform rotation
    y->left = x;
    x->right = T2;

    // update heights (x first because it is now lower)
    updateHeight(x);
    updateHeight(y);

    return y;
}

Python Implementation

def left_rotate(x):
    y = x.right
    T2 = y.left

    # perform rotation
    y.left = x
    x.right = T2

    # update heights (x first because it is now lower)
    update_height(x)
    update_height(y)

    return y

Go Implementation

// leftRotate performs left rotation around node x
func leftRotate(x *Node) *Node {
    y := x.Right
    T2 := y.Left

    // perform rotation
    y.Left = x
    x.Right = T2

    // update heights (x first because it is now lower)
    updateHeight(x)
    updateHeight(y)

    return y
}

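A left rotation promotes x's right child y to be the new subtree root, and y's former left subtree T2 becomes x's right subtree. As with the right rotation, update the lower node x first, then the higher node y.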
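The following sketch shows the hand-off of T2 concretely. It left-rotates a right-heavy node 10 whose right child 20 has a left subtree 15 (playing the role of T2). After the rotation, 15 becomes the right child of 10. The `build` helper, which attaches children and fills in heights, is a hypothetical convenience for constructing the test tree.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None


def get_height(node):
    return 0 if node is None else node.height


def update_height(node):
    if node is not None:
        node.height = 1 + max(get_height(node.left), get_height(node.right))


def left_rotate(x):
    y = x.right
    T2 = y.left
    y.left = x           # x becomes y's left child
    x.right = T2         # y's old left subtree moves under x
    update_height(x)     # x is now lower, so update it first
    update_height(y)
    return y


def build(data, left=None, right=None):
    """Hypothetical helper: make a node, attach children, set its height."""
    node = Node(data)
    node.left, node.right = left, right
    update_height(node)
    return node


# 10 has left child 5 and right child 20; 20 has 15 (T2) and 30 -> 40
x = build(10, build(5), build(20, build(15), build(30, None, build(40))))
print(x.height, get_height(x.left) - get_height(x.right))  # 4 -2

root = left_rotate(x)
print(root.data, root.left.data, root.left.right.data)               # 20 10 15
print(root.height, get_height(root.left) - get_height(root.right))  # 3 0
```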


Insertion

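Insertion into an AVL tree takes two steps. First, insert the node as in an ordinary binary search tree. Then backtrack from the insertion point toward the root, checking for and repairing imbalances.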

After an insertion, four kinds of imbalance can occur:

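  • LL (balance factor > 1 and new value < left child's value): one right rotation.
  • RR (balance factor < -1 and new value > right child's value): one left rotation.
  • LR (balance factor > 1 and new value > left child's value): rotate the left child left, then rotate the node right.
  • RL (balance factor < -1 and new value < right child's value): rotate the right child right, then rotate the node left.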
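Each of the four rules can be exercised with only three keys. Each sequence below triggers a different case at the first node, yet all four end with 20 as the root and 10 and 30 as its children. The code is the Python insertion from this section, condensed but with the same logic. `preorder_list` is an illustrative helper.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None


def get_height(node):
    return 0 if node is None else node.height


def get_balance(node):
    return 0 if node is None else get_height(node.left) - get_height(node.right)


def update_height(node):
    node.height = 1 + max(get_height(node.left), get_height(node.right))


def right_rotate(y):
    x, T2 = y.left, y.left.right
    x.right, y.left = y, T2
    update_height(y)      # y is now lower: update it first
    update_height(x)
    return x


def left_rotate(x):
    y, T2 = x.right, x.right.left
    y.left, x.right = x, T2
    update_height(x)      # x is now lower: update it first
    update_height(y)
    return y


def insert(node, data):
    if node is None:
        return Node(data)
    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    else:
        return node                                  # duplicates ignored
    update_height(node)
    balance = get_balance(node)
    if balance > 1 and data < node.left.data:       # LL
        return right_rotate(node)
    if balance < -1 and data > node.right.data:     # RR
        return left_rotate(node)
    if balance > 1 and data > node.left.data:       # LR
        node.left = left_rotate(node.left)
        return right_rotate(node)
    if balance < -1 and data < node.right.data:     # RL
        node.right = right_rotate(node.right)
        return left_rotate(node)
    return node


def preorder_list(node):
    if node is None:
        return []
    return [node.data] + preorder_list(node.left) + preorder_list(node.right)


for seq, case in [([30, 20, 10], "LL"), ([10, 20, 30], "RR"),
                  ([30, 10, 20], "LR"), ([10, 30, 20], "RL")]:
    root = None
    for v in seq:
        root = insert(root, v)
    print(case, seq, "->", preorder_list(root))  # each ends with [20, 10, 30]
```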

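Take the insertion sequence 10, 20, 30, 25 as an example of how the tree is rebalanced:

Insert 10:  10            balanced
Insert 20:  10            balanced
              \
               20
Insert 30:  10            unbalanced! (balance factor of 10 = -2)
              \           => left rotation (RR case)
               20
                 \
                  30

After the left rotation:
            20            balanced
           /  \
          10   30

Insert 25:  20            still balanced (balance factor of 20 = -1,
           /  \           balance factor of 30 = 1), no rotation needed
          10   30
              /
             25

Inserting 25 therefore does not trigger an RL rotation. An RL rotation occurs, for example, with the sequence 10, 30, 20: node 10 reaches balance factor -2 while its right child 30 has balance factor 1.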
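The trace can be replayed with the Python insertion code from this section (condensed, same logic). The sketch records the preorder sequence and the root's balance factor after each insertion. `preorder_list` is an illustrative helper. After 25 is inserted, the root's balance factor is -1, so no rotation is triggered at that step.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None


def get_height(node):
    return 0 if node is None else node.height


def get_balance(node):
    return 0 if node is None else get_height(node.left) - get_height(node.right)


def update_height(node):
    node.height = 1 + max(get_height(node.left), get_height(node.right))


def right_rotate(y):
    x, T2 = y.left, y.left.right
    x.right, y.left = y, T2
    update_height(y)
    update_height(x)
    return x


def left_rotate(x):
    y, T2 = x.right, x.right.left
    y.left, x.right = x, T2
    update_height(x)
    update_height(y)
    return y


def insert(node, data):
    if node is None:
        return Node(data)
    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    else:
        return node
    update_height(node)
    balance = get_balance(node)
    if balance > 1 and data < node.left.data:       # LL
        return right_rotate(node)
    if balance < -1 and data > node.right.data:     # RR
        return left_rotate(node)
    if balance > 1 and data > node.left.data:       # LR
        node.left = left_rotate(node.left)
        return right_rotate(node)
    if balance < -1 and data < node.right.data:     # RL
        node.right = right_rotate(node.right)
        return left_rotate(node)
    return node


def preorder_list(node):
    if node is None:
        return []
    return [node.data] + preorder_list(node.left) + preorder_list(node.right)


root = None
for v in [10, 20, 30, 25]:
    root = insert(root, v)
    print(f"insert {v}: preorder={preorder_list(root)}, bf(root)={get_balance(root)}")
# insert 10: preorder=[10], bf(root)=0
# insert 20: preorder=[10, 20], bf(root)=-1
# insert 30: preorder=[20, 10, 30], bf(root)=0
# insert 25: preorder=[20, 10, 30, 25], bf(root)=-1
```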

C++ Implementation

// insert a value into the AVL tree
Node* insert(Node* node, int data) 
{
    // standard BST insertion
    if (node == nullptr) 
    {
        return new Node(data);
    }

    if (data < node->data) 
    {
        node->left = insert(node->left, data);
    } 
    else if (data > node->data) 
    {
        node->right = insert(node->right, data);
    } 
    else 
    {
        // duplicate values not allowed
        return node;
    }

    // update height of current node
    updateHeight(node);

    // check balance and rebalance if needed
    int balance = getBalance(node);

    // LL case: left-heavy, left child is left-heavy
    if (balance > 1 && data < node->left->data) 
    {
        return rightRotate(node);
    }

    // RR case: right-heavy, right child is right-heavy
    if (balance < -1 && data > node->right->data) 
    {
        return leftRotate(node);
    }

    // LR case: left-heavy, left child is right-heavy
    if (balance > 1 && data > node->left->data) 
    {
        node->left = leftRotate(node->left);
        return rightRotate(node);
    }

    // RL case: right-heavy, right child is left-heavy
    if (balance < -1 && data < node->right->data) 
    {
        node->right = rightRotate(node->right);
        return leftRotate(node);
    }

    return node;
}

C Implementation

// insert a value into the AVL tree
Node* insert(Node* node, int data) 
{
    // standard BST insertion
    if (node == NULL) 
    {
        return createNode(data);
    }

    if (data < node->data) 
    {
        node->left = insert(node->left, data);
    } 
    else if (data > node->data) 
    {
        node->right = insert(node->right, data);
    } 
    else 
    {
        // duplicate values not allowed
        return node;
    }

    // update height of current node
    updateHeight(node);

    // check balance and rebalance if needed
    int balance = getBalance(node);

    // LL case
    if (balance > 1 && data < node->left->data) 
    {
        return rightRotate(node);
    }

    // RR case
    if (balance < -1 && data > node->right->data) 
    {
        return leftRotate(node);
    }

    // LR case
    if (balance > 1 && data > node->left->data) 
    {
        node->left = leftRotate(node->left);
        return rightRotate(node);
    }

    // RL case
    if (balance < -1 && data < node->right->data) 
    {
        node->right = rightRotate(node->right);
        return leftRotate(node);
    }

    return node;
}

Python Implementation

def insert(node, data):
    # standard BST insertion
    if node is None:
        return Node(data)

    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    else:
        # duplicate values not allowed
        return node

    # update height of current node
    update_height(node)

    # check balance and rebalance if needed
    balance = get_balance(node)

    # LL case
    if balance > 1 and data < node.left.data:
        return right_rotate(node)

    # RR case
    if balance < -1 and data > node.right.data:
        return left_rotate(node)

    # LR case
    if balance > 1 and data > node.left.data:
        node.left = left_rotate(node.left)
        return right_rotate(node)

    # RL case
    if balance < -1 and data < node.right.data:
        node.right = right_rotate(node.right)
        return left_rotate(node)

    return node

Go Implementation

// insert a value into the AVL tree
func insert(node *Node, data int) *Node {
    // standard BST insertion
    if node == nil {
        return NewNode(data)
    }

    if data < node.Data {
        node.Left = insert(node.Left, data)
    } else if data > node.Data {
        node.Right = insert(node.Right, data)
    } else {
        // duplicate values not allowed
        return node
    }

    // update height of current node
    updateHeight(node)

    // check balance and rebalance if needed
    balance := getBalance(node)

    // LL case
    if balance > 1 && data < node.Left.Data {
        return rightRotate(node)
    }

    // RR case
    if balance < -1 && data > node.Right.Data {
        return leftRotate(node)
    }

    // LR case
    if balance > 1 && data > node.Left.Data {
        node.Left = leftRotate(node.Left)
        return rightRotate(node)
    }

    // RL case
    if balance < -1 && data < node.Right.Data {
        node.Right = rightRotate(node.Right)
        return leftRotate(node)
    }

    return node
}

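Insertion first inserts the node recursively as in an ordinary BST. On the way back up, it updates each node's height and uses the balance factor to decide which of the four imbalance cases applies, if any. It then performs the matching rotation to restore balance.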
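A practical consequence of rebalancing on every insertion is that even sorted input keeps the AVL tree's height logarithmic, whereas the same input turns a plain BST into a chain of height n. The sketch below inserts 1..1000 in order with the Python insertion code from this section (condensed). It checks the height against two bounds: the minimum possible height of any binary tree with n nodes, and the classic AVL upper bound of about 1.44 * log2(n + 2). The exact height is not asserted; only the bounds are. `inorder_list` is an illustrative helper.

```python
import math


class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None


def get_height(node):
    return 0 if node is None else node.height


def get_balance(node):
    return 0 if node is None else get_height(node.left) - get_height(node.right)


def update_height(node):
    node.height = 1 + max(get_height(node.left), get_height(node.right))


def right_rotate(y):
    x, T2 = y.left, y.left.right
    x.right, y.left = y, T2
    update_height(y)
    update_height(x)
    return x


def left_rotate(x):
    y, T2 = x.right, x.right.left
    y.left, x.right = x, T2
    update_height(x)
    update_height(y)
    return y


def insert(node, data):
    if node is None:
        return Node(data)
    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    else:
        return node
    update_height(node)
    balance = get_balance(node)
    if balance > 1 and data < node.left.data:       # LL
        return right_rotate(node)
    if balance < -1 and data > node.right.data:     # RR
        return left_rotate(node)
    if balance > 1 and data > node.left.data:       # LR
        node.left = left_rotate(node.left)
        return right_rotate(node)
    if balance < -1 and data < node.right.data:     # RL
        node.right = right_rotate(node.right)
        return left_rotate(node)
    return node


def inorder_list(node):
    if node is None:
        return []
    return inorder_list(node.left) + [node.data] + inorder_list(node.right)


n = 1000
root = None
for v in range(1, n + 1):          # sorted input: worst case for a plain BST
    root = insert(root, v)

h = get_height(root)
lower = math.ceil(math.log2(n + 1))   # no binary tree with n nodes is shorter
upper = 1.4405 * math.log2(n + 2)     # classic AVL height bound
print(f"n={n}, AVL height={h}, plain BST height would be {n}")
print(lower <= h <= upper)  # True
```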


Deletion

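Deletion also takes two steps. First, delete the node as in an ordinary BST. Then backtrack toward the root, checking for and repairing imbalances. The imbalance cases after a deletion are similar to those after an insertion, but the test differs slightly. A deletion has no newly inserted value to compare against, so the rotation type is chosen from the balance factor of the taller child.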

Deleting a node falls into one of three cases:

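  • Leaf node: remove it directly.
  • Exactly one child: replace the deleted node with its child.
  • Two children: replace the node's value with its in-order successor (the minimum of its right subtree), then delete the successor node.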
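The three cases can be walked through on a small tree. The sketch builds the perfect tree 50 / (30, 70) / (20, 40, 60, 80), then performs three deletions:

  • 20, a leaf.
  • 30, which by then has only the child 40.
  • 50, which has two children and is replaced by its successor 60.

None of these deletions unbalances the tree, so only the BST part of the algorithm does any work. For brevity, the four rotation cases are condensed into one hypothetical `rebalance` helper that picks the case from the taller child's balance factor, as the deletion code above does. This is a variant of the section's code, not a copy. `preorder_list` is illustrative.

```python
class Node:
    def __init__(self, data):
        self.data, self.height = data, 1
        self.left = self.right = None


def get_height(n):
    return n.height if n else 0


def get_balance(n):
    return get_height(n.left) - get_height(n.right) if n else 0


def update_height(n):
    n.height = 1 + max(get_height(n.left), get_height(n.right))


def right_rotate(y):
    x = y.left
    y.left, x.right = x.right, y
    update_height(y)
    update_height(x)
    return x


def left_rotate(x):
    y = x.right
    x.right, y.left = y.left, x
    update_height(x)
    update_height(y)
    return y


def rebalance(n):
    # hypothetical helper: LL/LR/RR/RL chosen by the taller child's balance
    update_height(n)
    b = get_balance(n)
    if b > 1:
        if get_balance(n.left) < 0:
            n.left = left_rotate(n.left)       # LR: reduce to LL
        return right_rotate(n)
    if b < -1:
        if get_balance(n.right) > 0:
            n.right = right_rotate(n.right)    # RL: reduce to RR
        return left_rotate(n)
    return n


def insert(n, k):
    if n is None:
        return Node(k)
    if k < n.data:
        n.left = insert(n.left, k)
    elif k > n.data:
        n.right = insert(n.right, k)
    else:
        return n
    return rebalance(n)


def delete_node(n, k):
    if n is None:
        return None
    if k < n.data:
        n.left = delete_node(n.left, k)
    elif k > n.data:
        n.right = delete_node(n.right, k)
    else:
        if n.left is None or n.right is None:  # leaf or one child
            return n.left or n.right
        succ = n.right                         # two children: successor
        while succ.left is not None:
            succ = succ.left
        n.data = succ.data
        n.right = delete_node(n.right, succ.data)
    return rebalance(n)


def preorder_list(n):
    return [] if n is None else [n.data] + preorder_list(n.left) + preorder_list(n.right)


root = None
for v in [50, 30, 70, 20, 40, 60, 80]:
    root = insert(root, v)
print(preorder_list(root))     # [50, 30, 20, 40, 70, 60, 80]
root = delete_node(root, 20)   # leaf
print(preorder_list(root))     # [50, 30, 40, 70, 60, 80]
root = delete_node(root, 30)   # one child: 40 takes its place
print(preorder_list(root))     # [50, 40, 70, 60, 80]
root = delete_node(root, 50)   # two children: successor 60 moves up
print(preorder_list(root))     # [60, 40, 70, 80]
```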

C++ Implementation

// find node with minimum value in a subtree
Node* minValueNode(Node* node) 
{
    Node* current = node;
    while (current->left != nullptr) 
    {
        current = current->left;
    }
    return current;
}

// delete a value from the AVL tree
Node* deleteNode(Node* root, int data) 
{
    // standard BST delete
    if (root == nullptr) 
    {
        return root;
    }

    if (data < root->data) 
    {
        root->left = deleteNode(root->left, data);
    } 
    else if (data > root->data) 
    {
        root->right = deleteNode(root->right, data);
    } 
    else 
    {
        // node found, perform deletion
        if (root->left == nullptr || root->right == nullptr) 
        {
            Node* temp = root->left ? root->left : root->right;

            // no child case
            if (temp == nullptr) 
            {
                temp = root;
                root = nullptr;
            } 
            else 
            {
                // one child case: copy child contents
                *root = *temp;
            }
            delete temp;
        } 
        else 
        {
            // two children: get inorder successor
            Node* temp = minValueNode(root->right);
            root->data = temp->data;
            root->right = deleteNode(root->right, temp->data);
        }
    }

    // subtree became empty (the deleted node was a leaf)
    if (root == nullptr) 
    {
        return root;
    }

    // update height
    updateHeight(root);

    // check balance
    int balance = getBalance(root);

    // LL case
    if (balance > 1 && getBalance(root->left) >= 0) 
    {
        return rightRotate(root);
    }

    // LR case
    if (balance > 1 && getBalance(root->left) < 0) 
    {
        root->left = leftRotate(root->left);
        return rightRotate(root);
    }

    // RR case
    if (balance < -1 && getBalance(root->right) <= 0) 
    {
        return leftRotate(root);
    }

    // RL case
    if (balance < -1 && getBalance(root->right) > 0) 
    {
        root->right = rightRotate(root->right);
        return leftRotate(root);
    }

    return root;
}

C Implementation

// find node with minimum value in a subtree
Node* minValueNode(Node* node) 
{
    Node* current = node;
    while (current->left != NULL) 
    {
        current = current->left;
    }
    return current;
}

// delete a value from the AVL tree
Node* deleteNode(Node* root, int data) 
{
    // standard BST delete
    if (root == NULL) 
    {
        return root;
    }

    if (data < root->data) 
    {
        root->left = deleteNode(root->left, data);
    } 
    else if (data > root->data) 
    {
        root->right = deleteNode(root->right, data);
    } 
    else 
    {
        // node found, perform deletion
        if (root->left == NULL || root->right == NULL) 
        {
            Node* temp = root->left ? root->left : root->right;

            // no child case
            if (temp == NULL) 
            {
                temp = root;
                root = NULL;
            } 
            else 
            {
                // one child case: copy child contents
                *root = *temp;
            }
            free(temp);
        } 
        else 
        {
            // two children: get inorder successor
            Node* temp = minValueNode(root->right);
            root->data = temp->data;
            root->right = deleteNode(root->right, temp->data);
        }
    }

    // subtree became empty (the deleted node was a leaf)
    if (root == NULL) 
    {
        return root;
    }

    // update height
    updateHeight(root);

    // check balance
    int balance = getBalance(root);

    // LL case
    if (balance > 1 && getBalance(root->left) >= 0) 
    {
        return rightRotate(root);
    }

    // LR case
    if (balance > 1 && getBalance(root->left) < 0) 
    {
        root->left = leftRotate(root->left);
        return rightRotate(root);
    }

    // RR case
    if (balance < -1 && getBalance(root->right) <= 0) 
    {
        return leftRotate(root);
    }

    // RL case
    if (balance < -1 && getBalance(root->right) > 0) 
    {
        root->right = rightRotate(root->right);
        return leftRotate(root);
    }

    return root;
}

Python Implementation

def min_value_node(node):
    current = node
    while current.left is not None:
        current = current.left
    return current

def delete_node(root, data):
    # standard BST delete
    if root is None:
        return root

    if data < root.data:
        root.left = delete_node(root.left, data)
    elif data > root.data:
        root.right = delete_node(root.right, data)
    else:
        # node found, perform deletion
        if root.left is None:
            temp = root.right
            root = None
            return temp
        elif root.right is None:
            temp = root.left
            root = None
            return temp

        # two children: get inorder successor
        temp = min_value_node(root.right)
        root.data = temp.data
        root.right = delete_node(root.right, temp.data)

    # defensive check: the leaf and one-child cases above already returned
    if root is None:
        return root

    # update height
    update_height(root)

    # check balance
    balance = get_balance(root)

    # LL case
    if balance > 1 and get_balance(root.left) >= 0:
        return right_rotate(root)

    # LR case
    if balance > 1 and get_balance(root.left) < 0:
        root.left = left_rotate(root.left)
        return right_rotate(root)

    # RR case
    if balance < -1 and get_balance(root.right) <= 0:
        return left_rotate(root)

    # RL case
    if balance < -1 and get_balance(root.right) > 0:
        root.right = right_rotate(root.right)
        return left_rotate(root)

    return root

Go Implementation

// minValueNode finds the node with minimum value in a subtree
func minValueNode(node *Node) *Node {
    current := node
    for current.Left != nil {
        current = current.Left
    }
    return current
}

// deleteNode deletes a value from the AVL tree
func deleteNode(root *Node, data int) *Node {
    // standard BST delete
    if root == nil {
        return root
    }

    if data < root.Data {
        root.Left = deleteNode(root.Left, data)
    } else if data > root.Data {
        root.Right = deleteNode(root.Right, data)
    } else {
        // node found, perform deletion
        if root.Left == nil {
            return root.Right
        } else if root.Right == nil {
            return root.Left
        }

        // two children: get inorder successor
        temp := minValueNode(root.Right)
        root.Data = temp.Data
        root.Right = deleteNode(root.Right, temp.Data)
    }

    // defensive check: the leaf and one-child cases above already returned
    if root == nil {
        return root
    }

    // update height
    updateHeight(root)

    // check balance
    balance := getBalance(root)

    // LL case
    if balance > 1 && getBalance(root.Left) >= 0 {
        return rightRotate(root)
    }

    // LR case
    if balance > 1 && getBalance(root.Left) < 0 {
        root.Left = leftRotate(root.Left)
        return rightRotate(root)
    }

    // RR case
    if balance < -1 && getBalance(root.Right) <= 0 {
        return leftRotate(root)
    }

    // RL case
    if balance < -1 && getBalance(root.Right) > 0 {
        root.Right = rightRotate(root.Right)
        return leftRotate(root)
    }

    return root
}

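Deletion works like insertion except for two points. First, when the node to delete has two children, its value is replaced by its in-order successor, and the successor is then deleted recursively. Second, the rotation type is chosen from the child's balance factor:

  • For a left-heavy node, a left child with balance factor >= 0 means LL, and < 0 means LR.
  • For a right-heavy node, a right child with balance factor <= 0 means RR, and > 0 means RL.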
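Deletion can produce a situation that insertion never does: the taller child has balance factor 0. The sketch below builds 20 / (10 / (5, 15), 30) and deletes 30. Node 20 then has balance factor 2, while its left child 10 has balance factor 0. The `>= 0` test classifies this as LL, so a single right rotation makes 10 the new root. As a condensed variant of the section's code, the four cases are folded into one hypothetical `rebalance` helper.

```python
class Node:
    def __init__(self, data):
        self.data, self.height = data, 1
        self.left = self.right = None


def get_height(n):
    return n.height if n else 0


def get_balance(n):
    return get_height(n.left) - get_height(n.right) if n else 0


def update_height(n):
    n.height = 1 + max(get_height(n.left), get_height(n.right))


def right_rotate(y):
    x = y.left
    y.left, x.right = x.right, y
    update_height(y)
    update_height(x)
    return x


def left_rotate(x):
    y = x.right
    x.right, y.left = y.left, x
    update_height(x)
    update_height(y)
    return y


def rebalance(n):
    # hypothetical helper: LL/LR/RR/RL chosen by the taller child's balance
    update_height(n)
    b = get_balance(n)
    if b > 1:
        if get_balance(n.left) < 0:            # only LR needs the extra step
            n.left = left_rotate(n.left)
        return right_rotate(n)                 # child balance >= 0: LL
    if b < -1:
        if get_balance(n.right) > 0:
            n.right = right_rotate(n.right)
        return left_rotate(n)                  # child balance <= 0: RR
    return n


def insert(n, k):
    if n is None:
        return Node(k)
    if k < n.data:
        n.left = insert(n.left, k)
    elif k > n.data:
        n.right = insert(n.right, k)
    else:
        return n
    return rebalance(n)


def delete_node(n, k):
    if n is None:
        return None
    if k < n.data:
        n.left = delete_node(n.left, k)
    elif k > n.data:
        n.right = delete_node(n.right, k)
    else:
        if n.left is None or n.right is None:
            return n.left or n.right
        succ = n.right
        while succ.left is not None:
            succ = succ.left
        n.data = succ.data
        n.right = delete_node(n.right, succ.data)
    return rebalance(n)


def preorder_list(n):
    return [] if n is None else [n.data] + preorder_list(n.left) + preorder_list(n.right)


root = None
for v in [20, 10, 30, 5, 15]:
    root = insert(root, v)
print(preorder_list(root), get_balance(root))     # [20, 10, 5, 15, 30] 1

root = delete_node(root, 30)   # 20 becomes +2 while its left child 10 is 0
print(preorder_list(root))                         # [10, 5, 20, 15]
print(get_balance(root), get_balance(root.right))  # -1 1
```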


Complete Implementation

The program below is a complete AVL tree with insertion, deletion, and traversals. It inserts the sequence 10, 20, 30, 40, 50, 25 and then deletes 30.

Complete C++ Implementation

#include <iostream>
using namespace std;

struct Node 
{
    int data;
    int height;
    Node* left;
    Node* right;

    Node(int val) : data(val), height(1), left(nullptr), right(nullptr) {}
};

// get height of a node (0 for nullptr)
int getHeight(Node* node) 
{
    if (node == nullptr) return 0;
    return node->height;
}

// calculate balance factor
int getBalance(Node* node) 
{
    if (node == nullptr) return 0;
    return getHeight(node->left) - getHeight(node->right);
}

// update height based on children
void updateHeight(Node* node) 
{
    if (node == nullptr) return;
    node->height = 1 + (getHeight(node->left) > getHeight(node->right) 
                        ? getHeight(node->left) : getHeight(node->right));
}

// right rotation
Node* rightRotate(Node* y) 
{
    Node* x = y->left;
    Node* T2 = x->right;
    x->right = y;
    y->left = T2;
    updateHeight(y);
    updateHeight(x);
    return x;
}

// left rotation
Node* leftRotate(Node* x) 
{
    Node* y = x->right;
    Node* T2 = y->left;
    y->left = x;
    x->right = T2;
    updateHeight(x);
    updateHeight(y);
    return y;
}

// insert a value
Node* insert(Node* node, int data) 
{
    if (node == nullptr) return new Node(data);

    if (data < node->data)
        node->left = insert(node->left, data);
    else if (data > node->data)
        node->right = insert(node->right, data);
    else
        return node;

    updateHeight(node);
    int balance = getBalance(node);

    // LL
    if (balance > 1 && data < node->left->data)
        return rightRotate(node);
    // RR
    if (balance < -1 && data > node->right->data)
        return leftRotate(node);
    // LR
    if (balance > 1 && data > node->left->data) 
    {
        node->left = leftRotate(node->left);
        return rightRotate(node);
    }
    // RL
    if (balance < -1 && data < node->right->data) 
    {
        node->right = rightRotate(node->right);
        return leftRotate(node);
    }

    return node;
}

// find minimum value node
Node* minValueNode(Node* node) 
{
    Node* current = node;
    while (current->left != nullptr)
        current = current->left;
    return current;
}

// delete a value
Node* deleteNode(Node* root, int data) 
{
    if (root == nullptr) return root;

    if (data < root->data)
        root->left = deleteNode(root->left, data);
    else if (data > root->data)
        root->right = deleteNode(root->right, data);
    else 
    {
        if (root->left == nullptr || root->right == nullptr) 
        {
            Node* temp = root->left ? root->left : root->right;
            if (temp == nullptr) 
            {
                temp = root;
                root = nullptr;
            } 
            else 
            {
                *root = *temp;
            }
            delete temp;
        } 
        else 
        {
            Node* temp = minValueNode(root->right);
            root->data = temp->data;
            root->right = deleteNode(root->right, temp->data);
        }
    }

    if (root == nullptr) return root;

    updateHeight(root);
    int balance = getBalance(root);

    // LL
    if (balance > 1 && getBalance(root->left) >= 0)
        return rightRotate(root);
    // LR
    if (balance > 1 && getBalance(root->left) < 0) 
    {
        root->left = leftRotate(root->left);
        return rightRotate(root);
    }
    // RR
    if (balance < -1 && getBalance(root->right) <= 0)
        return leftRotate(root);
    // RL
    if (balance < -1 && getBalance(root->right) > 0) 
    {
        root->right = rightRotate(root->right);
        return leftRotate(root);
    }

    return root;
}

// inorder traversal (left, root, right)
void inorder(Node* root) 
{
    if (root != nullptr) 
    {
        inorder(root->left);
        cout << root->data << " ";
        inorder(root->right);
    }
}

// preorder traversal (root, left, right)
void preorder(Node* root) 
{
    if (root != nullptr) 
    {
        cout << root->data << " ";
        preorder(root->left);
        preorder(root->right);
    }
}

int main() 
{
    Node* root = nullptr;

    // insert sequence: 10, 20, 30, 40, 50, 25
    int values[] = {10, 20, 30, 40, 50, 25};
    cout << "Inserting: ";
    for (int v : values) 
    {
        cout << v << " ";
        root = insert(root, v);
    }
    cout << "\n\n";

    cout << "Inorder traversal: ";
    inorder(root);
    cout << "\n";

    cout << "Preorder traversal: ";
    preorder(root);
    cout << "\n\n";

    // delete 30
    cout << "Deleting 30...\n";
    root = deleteNode(root, 30);

    cout << "Inorder traversal: ";
    inorder(root);
    cout << "\n";

    cout << "Preorder traversal: ";
    preorder(root);
    cout << "\n";

    return 0;
}

Running this program produces the following output:

Inserting: 10 20 30 40 50 25 

Inorder traversal: 10 20 25 30 40 50 
Preorder traversal: 30 20 10 25 40 50 

Deleting 30...
Inorder traversal: 10 20 25 40 50 
Preorder traversal: 40 20 10 25 50 
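The output above comes from one fixed sequence. To gain more confidence in insertion and deletion together, a randomized test can compare the tree against Python's built-in set after every operation and recheck the AVL invariants each time. This sketch is written in Python rather than C++. It also condenses the four rotation cases into one `rebalance` helper that chooses the case from the taller child's balance factor (the deletion rule above, which for insertions selects the same case as the value comparison). `rebalance`, `check`, and `stress` are hypothetical names.

```python
import random


class Node:
    def __init__(self, data):
        self.data, self.height = data, 1
        self.left = self.right = None


def get_height(n):
    return n.height if n else 0


def get_balance(n):
    return get_height(n.left) - get_height(n.right) if n else 0


def update_height(n):
    n.height = 1 + max(get_height(n.left), get_height(n.right))


def right_rotate(y):
    x = y.left
    y.left, x.right = x.right, y
    update_height(y)
    update_height(x)
    return x


def left_rotate(x):
    y = x.right
    x.right, y.left = y.left, x
    update_height(x)
    update_height(y)
    return y


def rebalance(n):
    # the four cases, chosen by the taller child's balance factor
    update_height(n)
    b = get_balance(n)
    if b > 1:
        if get_balance(n.left) < 0:
            n.left = left_rotate(n.left)      # LR: reduce to LL
        return right_rotate(n)
    if b < -1:
        if get_balance(n.right) > 0:
            n.right = right_rotate(n.right)   # RL: reduce to RR
        return left_rotate(n)
    return n


def insert(n, k):
    if n is None:
        return Node(k)
    if k < n.data:
        n.left = insert(n.left, k)
    elif k > n.data:
        n.right = insert(n.right, k)
    else:
        return n
    return rebalance(n)


def delete_node(n, k):
    if n is None:
        return None
    if k < n.data:
        n.left = delete_node(n.left, k)
    elif k > n.data:
        n.right = delete_node(n.right, k)
    else:
        if n.left is None or n.right is None:
            return n.left or n.right
        succ = n.right
        while succ.left is not None:
            succ = succ.left
        n.data = succ.data
        n.right = delete_node(n.right, succ.data)
    return rebalance(n)


def check(n):
    """Return the true height; raise ValueError on a stale height or imbalance."""
    if n is None:
        return 0
    lh, rh = check(n.left), check(n.right)
    if n.height != 1 + max(lh, rh) or abs(lh - rh) > 1:
        raise ValueError(f"AVL invariant broken at {n.data}")
    return n.height


def inorder_list(n):
    return [] if n is None else inorder_list(n.left) + [n.data] + inorder_list(n.right)


def stress(seed=0, ops=2000, key_range=200):
    rng = random.Random(seed)
    root, model = None, set()
    for _ in range(ops):
        k = rng.randrange(key_range)
        if rng.random() < 0.6:
            root = insert(root, k)
            model.add(k)
        else:
            root = delete_node(root, k)
            model.discard(k)
        check(root)                                # heights and balance
        if inorder_list(root) != sorted(model):    # contents and BST order
            raise AssertionError("tree contents differ from the reference set")
    return root, model


root, model = stress()
print(len(model), check(root))
```

As a spot check, inserting 10, 20, 30, 40, 50, 25 with this sketch and then deleting 30 yields the preorder sequence 40 20 10 25 50.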

Complete C Implementation

#include <stdio.h>
#include <stdlib.h>

typedef struct Node 
{
    int data;
    int height;
    struct Node* left;
    struct Node* right;
} Node;

// create a new node
Node* createNode(int val) 
{
    Node* node = (Node*)malloc(sizeof(Node));
    node->data = val;
    node->height = 1;
    node->left = NULL;
    node->right = NULL;
    return node;
}

// get height of a node
int getHeight(Node* node) 
{
    if (node == NULL) return 0;
    return node->height;
}

// calculate balance factor
int getBalance(Node* node) 
{
    if (node == NULL) return 0;
    return getHeight(node->left) - getHeight(node->right);
}

// update height based on children
void updateHeight(Node* node) 
{
    if (node == NULL) return;
    int lh = getHeight(node->left);
    int rh = getHeight(node->right);
    node->height = 1 + (lh > rh ? lh : rh);
}

// right rotation
Node* rightRotate(Node* y) 
{
    Node* x = y->left;
    Node* T2 = x->right;
    x->right = y;
    y->left = T2;
    updateHeight(y);
    updateHeight(x);
    return x;
}

// left rotation
Node* leftRotate(Node* x) 
{
    Node* y = x->right;
    Node* T2 = y->left;
    y->left = x;
    x->right = T2;
    updateHeight(x);
    updateHeight(y);
    return y;
}

// insert a value
Node* insert(Node* node, int data) 
{
    if (node == NULL) return createNode(data);

    if (data < node->data)
        node->left = insert(node->left, data);
    else if (data > node->data)
        node->right = insert(node->right, data);
    else
        return node;

    updateHeight(node);
    int balance = getBalance(node);

    // LL
    if (balance > 1 && data < node->left->data)
        return rightRotate(node);
    // RR
    if (balance < -1 && data > node->right->data)
        return leftRotate(node);
    // LR
    if (balance > 1 && data > node->left->data) 
    {
        node->left = leftRotate(node->left);
        return rightRotate(node);
    }
    // RL
    if (balance < -1 && data < node->right->data) 
    {
        node->right = rightRotate(node->right);
        return leftRotate(node);
    }

    return node;
}

// find minimum value node
Node* minValueNode(Node* node) 
{
    Node* current = node;
    while (current->left != NULL)
        current = current->left;
    return current;
}

// delete a value
Node* deleteNode(Node* root, int data) 
{
    if (root == NULL) return root;

    if (data < root->data)
        root->left = deleteNode(root->left, data);
    else if (data > root->data)
        root->right = deleteNode(root->right, data);
    else 
    {
        if (root->left == NULL || root->right == NULL) 
        {
            Node* temp = root->left ? root->left : root->right;
            if (temp == NULL) 
            {
                temp = root;
                root = NULL;
            } 
            else 
            {
                *root = *temp;
            }
            free(temp);
        } 
        else 
        {
            Node* temp = minValueNode(root->right);
            root->data = temp->data;
            root->right = deleteNode(root->right, temp->data);
        }
    }

    if (root == NULL) return root;

    updateHeight(root);
    int balance = getBalance(root);

    // LL
    if (balance > 1 && getBalance(root->left) >= 0)
        return rightRotate(root);
    // LR
    if (balance > 1 && getBalance(root->left) < 0) 
    {
        root->left = leftRotate(root->left);
        return rightRotate(root);
    }
    // RR
    if (balance < -1 && getBalance(root->right) <= 0)
        return leftRotate(root);
    // RL
    if (balance < -1 && getBalance(root->right) > 0) 
    {
        root->right = rightRotate(root->right);
        return leftRotate(root);
    }

    return root;
}

// inorder traversal
void inorder(Node* root) 
{
    if (root != NULL) 
    {
        inorder(root->left);
        printf("%d ", root->data);
        inorder(root->right);
    }
}

// preorder traversal
void preorder(Node* root) 
{
    if (root != NULL) 
    {
        printf("%d ", root->data);
        preorder(root->left);
        preorder(root->right);
    }
}

int main() 
{
    Node* root = NULL;

    // insert sequence: 10, 20, 30, 40, 50, 25
    int values[] = {10, 20, 30, 40, 50, 25};
    int len = sizeof(values) / sizeof(values[0]);

    printf("Inserting: ");
    for (int i = 0; i < len; i++) 
    {
        printf("%d ", values[i]);
        root = insert(root, values[i]);
    }
    printf("\n\n");

    printf("Inorder traversal: ");
    inorder(root);
    printf("\n");

    printf("Preorder traversal: ");
    preorder(root);
    printf("\n\n");

    // delete 30
    printf("Deleting 30...\n");
    root = deleteNode(root, 30);

    printf("Inorder traversal: ");
    inorder(root);
    printf("\n");

    printf("Preorder traversal: ");
    preorder(root);
    printf("\n");

    return 0;
}

Running this program produces the following output:

Inserting: 10 20 30 40 50 25 

Inorder traversal: 10 20 25 30 40 50 
Preorder traversal: 30 20 10 25 40 50 

Deleting 30...
Inorder traversal: 10 20 25 40 50 
Preorder traversal: 40 20 10 25 50 

Complete Python Implementation

class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None

def get_height(node):
    if node is None:
        return 0
    return node.height

def get_balance(node):
    if node is None:
        return 0
    return get_height(node.left) - get_height(node.right)

def update_height(node):
    if node is None:
        return
    node.height = 1 + max(get_height(node.left), get_height(node.right))

def right_rotate(y):
    x = y.left
    T2 = x.right
    x.right = y
    y.left = T2
    update_height(y)
    update_height(x)
    return x

def left_rotate(x):
    y = x.right
    T2 = y.left
    y.left = x
    x.right = T2
    update_height(x)
    update_height(y)
    return y

def insert(node, data):
    if node is None:
        return Node(data)

    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    else:
        return node

    update_height(node)
    balance = get_balance(node)

    # LL
    if balance > 1 and data < node.left.data:
        return right_rotate(node)
    # RR
    if balance < -1 and data > node.right.data:
        return left_rotate(node)
    # LR
    if balance > 1 and data > node.left.data:
        node.left = left_rotate(node.left)
        return right_rotate(node)
    # RL
    if balance < -1 and data < node.right.data:
        node.right = right_rotate(node.right)
        return left_rotate(node)

    return node

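# find minimum value node (leftmost node of the subtree)
def min_value_node(node):
    current = node
    while current.left is not None:
        current = current.left
    return current

# delete a value and rebalance on the way back up
def delete_node(root, data):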
    if root is None:
        return root

    if data < root.data:
        root.left = delete_node(root.left, data)
    elif data > root.data:
        root.right = delete_node(root.right, data)
    else:
        # zero or one child: splice the node out
        if root.left is None:
            return root.right
        elif root.right is None:
            return root.left

        # two children: copy the in-order successor, then delete it from the right subtree
        temp = min_value_node(root.right)
        root.data = temp.data
        root.right = delete_node(root.right, temp.data)

    if root is None:
        return root

    update_height(root)
    balance = get_balance(root)

    # LL
    if balance > 1 and get_balance(root.left) >= 0:
        return right_rotate(root)
    # LR
    if balance > 1 and get_balance(root.left) < 0:
        root.left = left_rotate(root.left)
        return right_rotate(root)
    # RR
    if balance < -1 and get_balance(root.right) <= 0:
        return left_rotate(root)
    # RL
    if balance < -1 and get_balance(root.right) > 0:
        root.right = right_rotate(root.right)
        return left_rotate(root)

    return root

# inorder traversal (left, root, right)
def inorder(root):
    if root is not None:
        inorder(root.left)
        print(root.data, end=" ")
        inorder(root.right)

# preorder traversal (root, left, right)
def preorder(root):
    if root is not None:
        print(root.data, end=" ")
        preorder(root.left)
        preorder(root.right)

if __name__ == "__main__":
    root = None
    values = [10, 20, 30, 40, 50, 25]

    print("Inserting:", " ".join(str(v) for v in values))
    for v in values:
        root = insert(root, v)

    print("\nInorder traversal: ", end="")
    inorder(root)
    print("\nPreorder traversal:", end=" ")
    preorder(root)
    print()

    # delete 30
    print("\nDeleting 30...")
    root = delete_node(root, 30)

    print("Inorder traversal: ", end="")
    inorder(root)
    print("\nPreorder traversal:", end=" ")
    preorder(root)
    print()

Complete Go implementation

package main

import "fmt"

type Node struct {
    Data   int
    Height int
    Left   *Node
    Right  *Node
}

func NewNode(val int) *Node {
    return &Node{Data: val, Height: 1}
}

// get height of a node (0 for nil)
func getHeight(node *Node) int {
    if node == nil {
        return 0
    }
    return node.Height
}

// calculate balance factor
func getBalance(node *Node) int {
    if node == nil {
        return 0
    }
    return getHeight(node.Left) - getHeight(node.Right)
}

// update height based on children
func updateHeight(node *Node) {
    if node == nil {
        return
    }
    lh := getHeight(node.Left)
    rh := getHeight(node.Right)
    if lh > rh {
        node.Height = 1 + lh
    } else {
        node.Height = 1 + rh
    }
}

// right rotation
func rightRotate(y *Node) *Node {
    x := y.Left
    T2 := x.Right
    x.Right = y
    y.Left = T2
    updateHeight(y)
    updateHeight(x)
    return x
}

// left rotation
func leftRotate(x *Node) *Node {
    y := x.Right
    T2 := y.Left
    y.Left = x
    x.Right = T2
    updateHeight(x)
    updateHeight(y)
    return y
}

// insert a value
func insert(node *Node, data int) *Node {
    if node == nil {
        return NewNode(data)
    }

    if data < node.Data {
        node.Left = insert(node.Left, data)
    } else if data > node.Data {
        node.Right = insert(node.Right, data)
    } else {
        return node
    }

    updateHeight(node)
    balance := getBalance(node)

    // LL
    if balance > 1 && data < node.Left.Data {
        return rightRotate(node)
    }
    // RR
    if balance < -1 && data > node.Right.Data {
        return leftRotate(node)
    }
    // LR
    if balance > 1 && data > node.Left.Data {
        node.Left = leftRotate(node.Left)
        return rightRotate(node)
    }
    // RL
    if balance < -1 && data < node.Right.Data {
        node.Right = rightRotate(node.Right)
        return leftRotate(node)
    }

    return node
}

// find minimum value node
func minValueNode(node *Node) *Node {
    current := node
    for current.Left != nil {
        current = current.Left
    }
    return current
}

// delete a value
func deleteNode(root *Node, data int) *Node {
    if root == nil {
        return root
    }

    if data < root.Data {
        root.Left = deleteNode(root.Left, data)
    } else if data > root.Data {
        root.Right = deleteNode(root.Right, data)
    } else {
        if root.Left == nil {
            return root.Right
        } else if root.Right == nil {
            return root.Left
        }

        // two children: copy the in-order successor, then delete it from the right subtree
        temp := minValueNode(root.Right)
        root.Data = temp.Data
        root.Right = deleteNode(root.Right, temp.Data)
    }

    if root == nil {
        return root
    }

    updateHeight(root)
    balance := getBalance(root)

    // LL
    if balance > 1 && getBalance(root.Left) >= 0 {
        return rightRotate(root)
    }
    // LR
    if balance > 1 && getBalance(root.Left) < 0 {
        root.Left = leftRotate(root.Left)
        return rightRotate(root)
    }
    // RR
    if balance < -1 && getBalance(root.Right) <= 0 {
        return leftRotate(root)
    }
    // RL
    if balance < -1 && getBalance(root.Right) > 0 {
        root.Right = rightRotate(root.Right)
        return leftRotate(root)
    }

    return root
}

// inorder traversal (left, root, right)
func inorder(root *Node) {
    if root != nil {
        inorder(root.Left)
        fmt.Print(root.Data, " ")
        inorder(root.Right)
    }
}

// preorder traversal (root, left, right)
func preorder(root *Node) {
    if root != nil {
        fmt.Print(root.Data, " ")
        preorder(root.Left)
        preorder(root.Right)
    }
}

func main() {
    var root *Node

    // insert sequence: 10, 20, 30, 40, 50, 25
    values := []int{10, 20, 30, 40, 50, 25}

    fmt.Print("Inserting: ")
    for _, v := range values {
        fmt.Print(v, " ")
        root = insert(root, v)
    }
    fmt.Print("\n\n")

    fmt.Print("Inorder traversal: ")
    inorder(root)
    fmt.Println()

    fmt.Print("Preorder traversal: ")
    preorder(root)
    fmt.Print("\n\n")

    // delete 30
    fmt.Println("Deleting 30...")
    root = deleteNode(root, 30)

    fmt.Print("Inorder traversal: ")
    inorder(root)
    fmt.Println()

    fmt.Print("Preorder traversal: ")
    preorder(root)
    fmt.Println()
}

Running the program prints:

Inserting: 10 20 30 40 50 25

Inorder traversal: 10 20 25 30 40 50
Preorder traversal: 30 20 10 25 40 50

Deleting 30...
Inorder traversal: 10 20 25 40 50
Preorder traversal: 40 20 10 25 50

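After inserting the sequence 10, 20, 30, 40, 50, 25, the inorder traversal yields the sorted sequence 10 20 25 30 40 50 and the preorder traversal is 30 20 10 25 40 50, so the root is 30 and the tree is balanced. Node 30 has two children, so deleting it copies its in-order successor 40 (the smallest key in its right subtree) into the root and then removes 40 from the right subtree. The preorder traversal becomes 40 20 10 25 50: 40 is the new root, and the tree is still balanced.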
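To see which of the four rotation cases fire during this run, here is a minimal Python sketch that repeats the insertion logic of the Python implementation and records every rebalancing step. The `log` parameter and the `(key, case, node)` tuples are hypothetical additions for illustration only.

```python
# Minimal AVL insertion that logs each rebalancing case.
# The `log` parameter and its (key, case, node) tuples are illustrative
# additions; they are not part of the implementations above.

class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1  # a new leaf has height 1
        self.left = None
        self.right = None

def height(node):
    return node.height if node else 0

def update_height(node):
    node.height = 1 + max(height(node.left), height(node.right))

def right_rotate(y):
    x = y.left
    y.left, x.right = x.right, y  # x's right subtree moves under y
    update_height(y)
    update_height(x)
    return x

def left_rotate(x):
    y = x.right
    x.right, y.left = y.left, x  # y's left subtree moves under x
    update_height(x)
    update_height(y)
    return y

def insert(node, data, log):
    if node is None:
        return Node(data)
    if data < node.data:
        node.left = insert(node.left, data, log)
    elif data > node.data:
        node.right = insert(node.right, data, log)
    else:
        return node  # duplicates are ignored
    update_height(node)
    balance = height(node.left) - height(node.right)
    if balance > 1:  # left side too tall: LL or LR
        case = "LL" if data < node.left.data else "LR"
        log.append((data, case, node.data))
        if case == "LR":
            node.left = left_rotate(node.left)
        return right_rotate(node)
    if balance < -1:  # right side too tall: RR or RL
        case = "RR" if data > node.right.data else "RL"
        log.append((data, case, node.data))
        if case == "RL":
            node.right = right_rotate(node.right)
        return left_rotate(node)
    return node

def preorder(node):
    if node is None:
        return []
    return [node.data] + preorder(node.left) + preorder(node.right)

root, log = None, []
for v in [10, 20, 30, 40, 50, 25]:
    root = insert(root, v, log)

for key, case, pivot in log:
    print(f"inserting {key}: {case} case at node {pivot}")
print("preorder:", preorder(root))
```

Traced by hand, inserting 30 triggers an RR case at node 10, inserting 50 an RR case at node 30, and inserting 25 an RL case at node 20; inserting 40 needs no rotation. The final preorder matches the 30 20 10 25 40 50 shown above.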


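Properties of AVL trees

Time complexity

Because an AVL tree always stays balanced (its height is O(log n)), every basic operation runs in O(log n) time even in the worst case.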
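The O(log n) height claim can be checked with a short recurrence. Let N(h) be the minimum number of nodes in an AVL tree of height h, counting a single node as height 1 as the code above does. The sparsest such tree has one subtree of height h-1 and one of height h-2, so N(h) = N(h-1) + N(h-2) + 1, which grows like the Fibonacci numbers. A minimal sketch (the name `min_nodes` is illustrative):

```python
import math

# N(h): minimum number of nodes in an AVL tree of height h,
# where a single node has height 1 (same convention as the code above).
#     N(h) = N(h-1) + N(h-2) + 1,  N(0) = 0,  N(1) = 1
def min_nodes(h):
    if h == 0:
        return 0
    prev, cur = 0, 1  # N(0), N(1)
    for _ in range(h - 1):
        prev, cur = cur, prev + cur + 1
    return cur

for h in range(1, 11):
    n = min_nodes(h)
    # the ratio h / log2(N(h) + 2) stays below roughly 1.44
    print(f"h={h:2d}  N(h)={n:4d}  h/log2(N(h)+2)={h / math.log2(n + 2):.2f}")
```

Since N(h) grows exponentially in h, an AVL tree with n nodes has height at most about 1.44·log2(n), which is where the O(log n) bound comes from.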

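Operation   Average    Worst case   Notes
Search      O(log n)   O(log n)     Same algorithm as a plain BST, but the tree height is guaranteed
Insert      O(log n)   O(log n)     BST insertion + at most one single or double rotation
Delete      O(log n)   O(log n)     BST deletion + up to O(log n) rotations (at most one rebalance per level)
Space       O(n)       O(n)         Each node stores one extra height field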
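The insertion row rests on a specific property: after one single or double rotation at the lowest unbalanced node, that subtree gets back its pre-insertion height, so no ancestor needs rebalancing. Deletion has no such guarantee, which is why it may rotate at several levels. The sketch below reuses the insertion logic of the Python implementation and adds a hypothetical `stats` counter to check the claim on random keys.

```python
import random

# Counts how many rebalancing steps (one single or one double rotation)
# each insertion triggers. The `stats` dict is an illustrative addition.

class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1
        self.left = None
        self.right = None

def height(node):
    return node.height if node else 0

def update_height(node):
    node.height = 1 + max(height(node.left), height(node.right))

def right_rotate(y):
    x = y.left
    y.left, x.right = x.right, y
    update_height(y)
    update_height(x)
    return x

def left_rotate(x):
    y = x.right
    x.right, y.left = y.left, x
    update_height(x)
    update_height(y)
    return y

def insert(node, data, stats):
    if node is None:
        return Node(data)
    if data < node.data:
        node.left = insert(node.left, data, stats)
    elif data > node.data:
        node.right = insert(node.right, data, stats)
    else:
        return node
    update_height(node)
    balance = height(node.left) - height(node.right)
    if balance > 1:
        stats["rebalances"] += 1
        if data > node.left.data:  # LR: turn it into LL first
            node.left = left_rotate(node.left)
        return right_rotate(node)
    if balance < -1:
        stats["rebalances"] += 1
        if data < node.right.data:  # RL: turn it into RR first
            node.right = right_rotate(node.right)
        return left_rotate(node)
    return node

random.seed(42)
root, worst, total = None, 0, 0
for key in random.sample(range(1_000_000), 10_000):
    stats = {"rebalances": 0}
    root = insert(root, key, stats)
    worst = max(worst, stats["rebalances"])
    total += stats["rebalances"]

print("insertions that needed rebalancing:", total)
print("max rebalancing steps in one insertion:", worst)
```

The maximum stays at 1 regardless of the seed, because the rotated subtree's height returns to what it was before the insertion.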

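AVL tree vs. plain BST

Property                 Plain BST                               AVL tree
Worst-case search        O(n) (degenerates into a linked list)   O(log n)
Worst-case insertion     O(n)                                    O(log n)
Worst-case deletion      O(n)                                    O(log n)
Extra space              none                                    one extra height field per node
Insert/delete overhead   no rotations                            rotations to maintain balance
Best suited for          randomly ordered keys, few deletions    lookup-heavy workloads that need predictable performance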
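The first row of the comparison is easy to reproduce: feeding sorted keys to a plain BST produces a chain. This sketch inserts 1..1000 into both kinds of tree and measures their heights level by level; `bst_insert`, `avl_insert` and `tree_height` are hypothetical helper names, with `avl_insert` following the insertion logic of the Python implementation.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.height = 1  # only maintained by the AVL insert
        self.left = None
        self.right = None

def height(node):
    return node.height if node else 0

def update_height(node):
    node.height = 1 + max(height(node.left), height(node.right))

def right_rotate(y):
    x = y.left
    y.left, x.right = x.right, y
    update_height(y)
    update_height(x)
    return x

def left_rotate(x):
    y = x.right
    x.right, y.left = y.left, x
    update_height(x)
    update_height(y)
    return y

def avl_insert(node, data):
    if node is None:
        return Node(data)
    if data < node.data:
        node.left = avl_insert(node.left, data)
    elif data > node.data:
        node.right = avl_insert(node.right, data)
    else:
        return node
    update_height(node)
    balance = height(node.left) - height(node.right)
    if balance > 1:
        if data > node.left.data:  # LR
            node.left = left_rotate(node.left)
        return right_rotate(node)  # LL (or LR after the first rotation)
    if balance < -1:
        if data < node.right.data:  # RL
            node.right = right_rotate(node.right)
        return left_rotate(node)  # RR (or RL after the first rotation)
    return node

def bst_insert(root, data):
    """Plain BST insertion, iterative so a degenerate chain cannot hit the recursion limit."""
    new = Node(data)
    if root is None:
        return new
    cur = root
    while True:
        if data < cur.data:
            if cur.left is None:
                cur.left = new
                return root
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = new
                return root
            cur = cur.right

def tree_height(root):
    """Height counted level by level (a single node has height 1)."""
    h = 0
    level = [root] if root else []
    while level:
        h += 1
        level = [c for node in level for c in (node.left, node.right) if c is not None]
    return h

n = 1000
bst = avl = None
for key in range(1, n + 1):  # sorted input: the worst case for a plain BST
    bst = bst_insert(bst, key)
    avl = avl_insert(avl, key)

print("plain BST height:", tree_height(bst))
print("AVL tree height: ", tree_height(avl))
```

With sorted input the plain BST's height equals n (1000 here), while the AVL tree stays within the logarithmic bound; no AVL tree with 1000 nodes can be taller than 14 levels.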

Balance factor constraint

For every node in an AVL tree:
    -1 <= balanceFactor(node) <= 1

where:
    balanceFactor(node) = height(left subtree) - height(right subtree)
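The constraint can be checked mechanically. Below is a minimal sketch of a validator, assuming nodes with `data`, `left` and `right` fields like the implementations above; `check_avl` is a hypothetical name, and it recomputes heights instead of trusting the stored `height` field. It verifies the BST ordering and that every balance factor lies in {-1, 0, 1}.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def check_avl(node, lo=float("-inf"), hi=float("inf")):
    """Return the subtree height if it is a valid AVL tree, otherwise raise ValueError."""
    if node is None:
        return 0
    if not (lo < node.data < hi):
        raise ValueError(f"BST order violated at {node.data}")
    lh = check_avl(node.left, lo, node.data)
    rh = check_avl(node.right, node.data, hi)
    if not -1 <= lh - rh <= 1:
        raise ValueError(f"balance factor {lh - rh} at node {node.data}")
    return 1 + max(lh, rh)

# The tree produced above after inserting 10, 20, 30, 40, 50, 25
good = Node(30, Node(20, Node(10), Node(25)), Node(40, None, Node(50)))
print("height:", check_avl(good))

# A right-leaning chain 10 -> 20 -> 30 breaks the constraint at the root
bad = Node(10, None, Node(20, None, Node(30)))
try:
    check_avl(bad)
except ValueError as err:
    print("rejected:", err)
```

The chain is rejected at node 10, whose balance factor is 0 - 2 = -2; this is exactly the RR situation that a left rotation repairs.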

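This strict balance-factor constraint keeps the height of an AVL tree close to log2(n) (at most about 1.44·log2(n)), which is what keeps its operations efficient. It makes AVL trees especially well suited to lookup-heavy workloads: compared with the looser balance condition of a red-black tree (whose height can reach 2·log2(n+1)), an AVL tree is more strictly balanced, so lookups are slightly faster, but insertions and deletions may require more rotations.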
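To put numbers on this comparison, the sketch below computes the exact worst-case AVL height for n nodes from the minimum-node recurrence N(h) = N(h-1) + N(h-2) + 1 and sets it beside the standard red-black bound 2·log2(n+1). These are worst-case heights only; the function names are illustrative, and real lookup speed also depends on the workload.

```python
import math

def min_nodes(h):
    """Minimum node count of an AVL tree of height h (a single node has height 1)."""
    prev, cur = 0, 1  # N(0), N(1)
    for _ in range(h - 1):
        prev, cur = cur, prev + cur + 1
    return cur if h > 0 else 0

def avl_max_height(n):
    """Largest height any AVL tree with n >= 1 nodes can have."""
    h = 1
    while min_nodes(h + 1) <= n:
        h += 1
    return h

for n in [100, 10_000, 1_000_000]:
    rb_bound = 2 * math.log2(n + 1)  # red-black tree height bound
    print(f"n={n:>9,}  AVL max height={avl_max_height(n):2d}  "
          f"red-black bound={rb_bound:.1f}  log2(n)={math.log2(n):.1f}")
```

For example, an AVL tree with 100 nodes is at most 9 levels tall, while the red-black bound for the same n is about 13.3.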

posted @ 2026-04-16 11:11  游翔