typedef struct node *pos;   /* handle to a single tree node */
typedef struct node *AvlT;  /* handle to an AVL (sub)tree; NULL is the empty tree */

/* One AVL tree node. */
struct node {
    int val;               /* key stored in this node */
    struct node *l, *r;    /* left / right subtrees (NULL when absent) */
    int h;                 /* height of the subtree rooted here; a leaf has 0 */
};
static int Height(pos p)
{
    /* The empty tree has height -1, so a node with two empty children
     * (a leaf) gets height max(-1, -1) + 1 == 0. */
    return p ? p->h : -1;
}
static pos R_rotate(pos k2)
{
    /* Single right rotation (fixes the left-left case).
     * The left child of k2 is lifted to become the new subtree root,
     * k2 becomes its right child, and the pivot's former right subtree
     * is re-attached as k2's left subtree.
     * Returns the new root of this subtree; heights of both moved nodes
     * are refreshed (k2 first, since the pivot's height depends on it). */
    pos pivot = k2->l;

    k2->l = pivot->r;
    pivot->r = k2;

    k2->h = max(Height(k2->l), Height(k2->r)) + 1;
    pivot->h = max(Height(pivot->l), k2->h) + 1;

    return pivot;
}
static pos L_rotate(pos k2)
{
    /* Single left rotation (fixes the right-right case); mirror image of
     * R_rotate. The right child of k2 becomes the new subtree root, k2
     * becomes its left child, and the new root's former left subtree is
     * re-attached as k2's right subtree.
     * Returns the new root of this subtree with both moved nodes' heights
     * updated. */
    pos k1;
    k1 = k2->r;
    k2->r = k1->l;
    k1->l = k2;
    k2->h = max(Height(k2->l), Height(k2->r)) + 1;
    /* After the rotation k1->l is k2, so k1's height depends on k2 and on
     * its own right subtree (not on k1->l a second time). */
    k1->h = max(Height(k1->r), k2->h) + 1;
    return k1;
}
static pos RL_rotate(pos k)
{
    /* Double rotation for the right-left case.
     * Rotating k's right child to the right first turns the shape into a
     * right-right case, which a single left rotation at k then resolves.
     * Returns the new root of this subtree. */
    pos straightened = R_rotate(k->r);
    k->r = straightened;
    return L_rotate(k);
}
static pos LR_rotate(pos k)
{
    /* Double rotation for the left-right case.
     * Rotating k's left child to the left first turns the shape into a
     * left-left case, which a single right rotation at k then resolves.
     * Returns the new root of this subtree. */
    pos straightened = L_rotate(k->l);
    k->l = straightened;
    return R_rotate(k);
}
/* Insert key x into the AVL tree rooted at t and return the (possibly new)
 * root. The caller must store the return value: rotations can change the root.
 *
 * - A key equal to an existing one is ignored (no duplicate node is created).
 * - After inserting into a subtree, if the two children's heights differ by 2
 *   the subtree is rebalanced with a single or double rotation, chosen by
 *   which side of the child x went to.
 * - If allocating the new node fails, x is not inserted and the tree is
 *   returned unchanged (an empty subtree stays NULL, so no imbalance arises).
 */
AvlT inst(int x, AvlT t)
{
    if (t == NULL)
    {
        t = (struct node *) malloc(sizeof *t);
        if (t == NULL)
            return NULL;        /* out of memory: leave this subtree empty */
        t->val = x;
        t->h = 0;               /* a new leaf has height 0 */
        t->l = t->r = NULL;
        return t;
    }

    if (x < t->val)
    {
        t->l = inst(x, t->l);
        if (Height(t->l) - Height(t->r) == 2)
        {
            if (x < t->l->val)
                t = R_rotate(t);    /* left-left: single right rotation */
            else
                t = LR_rotate(t);   /* left-right: double rotation */
        }
    }
    else if (x > t->val)
    {
        t->r = inst(x, t->r);
        if (Height(t->r) - Height(t->l) == 2)
        {
            if (x > t->r->val)
                t = L_rotate(t);    /* right-right: single left rotation */
            else
                t = RL_rotate(t);   /* right-left: double rotation */
        }
    }
    /* x == t->val: duplicate key, nothing to insert. */

    /* Refresh the height of the (possibly rotated) subtree root. */
    t->h = max(Height(t->l), Height(t->r)) + 1;
    return t;
}