普通平衡树(AVL树)
某考研党复习到平衡树时突然心血来潮，想自己实现一下 AVL 树 QAQ。快一年没敲代码了，码力下降严重，断断续续写了好久 QAQ。最终写了快 300 行，全凭自己的理解写成，也算是完成了当年没完成的心愿吧（自己独立写出一种平衡树）。
代码:
#include <bits/stdc++.h>
using namespace std;

// Capacity of the node pool. Slots are handed out by ++tot and not recycled
// between init() calls, so this bounds the number of distinct values created
// (slot 0 is reserved for the sentinel).
const int maxn = 101010;

// AVL tree over int keys. Equal keys share one node and are counted in `cnt`.
// Nodes live in the fixed array `tr`; index 0 is a sentinel meaning "no node".
// Its dep/sz stay 0 so height/size arithmetic needs no null checks, and its
// links are never written.
struct AVL_tree {
    int tot = 0, root = 0;  // last allocated slot; index of the root (0 = empty tree)
    struct node {
        int ch[2], fa;               // left/right child and parent indices (0 = none)
        int bal, val, dep, sz, cnt;  // dep(left)-dep(right), key, height,
                                     // subtree size counting duplicates, copies of val
        void init() {
            ch[0] = ch[1] = fa = bal = val = dep = sz = cnt = 0;
        }
    };
    node tr[maxn];
    // Result slots written by the query methods pre / Next / rank_of_val / val_of_rank.
    int Next_pos, pre_pos, Rank, ans;

    // Child/parent accessors. They return references so links can be
    // assigned through them (e.g. `ls(p) = q`).
    int& ls(int pos) { return tr[pos].ch[0]; }
    int& rs(int pos) { return tr[pos].ch[1]; }
    int& Fa(int pos) { return tr[pos].fa; }

    // Reset to an empty tree: forget every allocated slot and clear the sentinel.
    void init() {
        tot = 0;
        root = 0;
        tr[0].init();
    }

    // Allocate a fresh leaf holding one copy of `val` whose parent is `fa`.
    // The caller is responsible for linking it into the parent's ch[].
    int creat(int val, int fa) {
        int ret = ++tot;
        tr[ret].init();  // the slot may still hold links from before an init() reset
        tr[ret].val = val;
        tr[ret].fa = fa;
        tr[ret].sz = 1;
        tr[ret].dep = 1;
        tr[ret].cnt = 1;
        return ret;
    }

    // 0 if pos is its parent's left child, 1 otherwise.
    // Only meaningful when pos has a parent.
    int son(int pos) {
        return ls(Fa(pos)) == pos ? 0 : 1;
    }

    // Make `parent` point at new_child where it used to point at old_child.
    // parent == 0 means old_child was the root. In that case nothing is written,
    // so the sentinel's links stay clear; callers update `root` themselves.
    void relink(int parent, int old_child, int new_child) {
        if (parent == 0) return;
        if (ls(parent) == old_child) ls(parent) = new_child;
        else rs(parent) = new_child;
    }

    // Right rotation around pos, which must have a left child. That child becomes
    // the subtree root and pos becomes its right child.
    // Heights, sizes, parent links and `root` are updated.
    void rrotate(int pos) {
        int up = ls(pos), parent = Fa(pos);
        relink(parent, pos, up);
        Fa(up) = parent;
        ls(pos) = rs(up);  // up's right subtree moves under pos
        if (ls(pos)) Fa(ls(pos)) = pos;
        rs(up) = pos;
        Fa(pos) = up;
        maintain1(pos);  // pos is now below `up`, so refresh it first
        maintain1(up);
        if (root == pos) root = up;
    }

    // Left rotation around pos, which must have a right child; mirror image of rrotate.
    void lrotate(int pos) {
        int up = rs(pos), parent = Fa(pos);
        relink(parent, pos, up);
        Fa(up) = parent;
        rs(pos) = ls(up);  // up's left subtree moves under pos
        if (rs(pos)) Fa(rs(pos)) = pos;
        ls(up) = pos;
        Fa(pos) = up;
        maintain1(pos);
        maintain1(up);
        if (root == pos) root = up;
    }

    // Rebalance pos after maintain1 found |bal| > 1.
    // If the taller child leans the same way, or is itself balanced, use a single
    // rotation. A balanced child only occurs after a deletion, and a double
    // rotation there can leave a node with |bal| == 2.
    // If the taller child leans the opposite way, use a double rotation.
    void rotate(int pos) {
        if (tr[pos].bal > 1) {
            if (tr[ls(pos)].bal >= 0) {
                rrotate(pos);
            } else {
                lrotate(ls(pos));
                rrotate(pos);
            }
        } else {
            if (tr[rs(pos)].bal <= 0) {
                lrotate(pos);
            } else {
                rrotate(rs(pos));
                lrotate(pos);
            }
        }
    }

    // Recompute height, balance factor and subtree size of pos from its children.
    void maintain1(int pos) {
        if (pos == 0) return;
        tr[pos].dep = max(tr[ls(pos)].dep, tr[rs(pos)].dep) + 1;
        tr[pos].bal = tr[ls(pos)].dep - tr[rs(pos)].dep;
        tr[pos].sz = tr[ls(pos)].sz + tr[rs(pos)].sz + tr[pos].cnt;
    }

    // maintain1, then rotate if pos became unbalanced.
    void maintain(int pos) {
        if (pos == 0) return;
        maintain1(pos);
        if (tr[pos].bal > 1 || tr[pos].bal < -1) {
            rotate(pos);
        }
    }

    // Insert one copy of val into the non-empty subtree rooted at pos,
    // rebalancing every node on the way back up.
    void insert(int pos, int val) {
        if (tr[pos].val == val) {
            tr[pos].cnt++;
        } else if (val > tr[pos].val) {
            if (rs(pos) == 0) rs(pos) = creat(val, pos);
            else insert(rs(pos), val);
        } else {
            if (ls(pos) == 0) ls(pos) = creat(val, pos);
            else insert(ls(pos), val);
        }
        maintain(pos);
    }

    // Index of the node holding x in the subtree rooted at pos, or 0 if absent.
    int find(int pos, int x) {
        if (pos == 0) return pos;
        if (tr[pos].val == x) {
            return pos;
        }
        if (tr[pos].val > x) return find(ls(pos), x);
        else return find(rs(pos), x);
    }

    // Set pre_pos to the node with the largest value < x in pos's subtree.
    // pre_pos is left untouched if no such node exists (or pos == 0).
    void pre(int pos, int x) {
        if (pos == 0) return;
        if (tr[pos].val >= x) {
            pre(ls(pos), x);
        } else {
            pre_pos = pos;
            pre(rs(pos), x);
        }
    }

    // Set Next_pos to the node with the smallest value > x in pos's subtree.
    // Next_pos is left untouched if no such node exists (or pos == 0).
    void Next(int pos, int x) {
        if (pos == 0) return;
        if (tr[pos].val <= x) {
            Next(rs(pos), x);
        } else {
            Next_pos = pos;
            Next(ls(pos), x);
        }
    }

    // Unlink and clear pos if it has at most one child, splicing that child
    // into its place. Returns false, changing nothing, when pos has two children.
    // Does not rebalance; erase() does that.
    bool del(int pos) {
        if (ls(pos) && rs(pos)) return false;
        int child = ls(pos) ? ls(pos) : rs(pos);  // 0 when pos is a leaf
        int parent = Fa(pos);
        relink(parent, pos, child);
        if (child) Fa(child) = parent;
        if (root == pos) root = child;
        tr[pos].init();
        return true;
    }

    // Add to Rank the number of stored values (counting duplicates) that are < val.
    void rank_of_val(int pos, int val) {
        if (!pos) return;
        if (tr[pos].val < val) {
            Rank += tr[ls(pos)].sz + tr[pos].cnt;
            rank_of_val(rs(pos), val);
        } else {
            rank_of_val(ls(pos), val);
        }
    }

    // Set ans to the remain-th smallest stored value (1-based, counting duplicates).
    // ans is left untouched if remain is out of range.
    void val_of_rank(int pos, int remain) {
        if (!pos)
            return;
        if (tr[ls(pos)].sz < remain) {
            if (tr[ls(pos)].sz + tr[pos].cnt >= remain) {
                ans = tr[pos].val;
                return;
            } else {
                remain -= tr[ls(pos)].sz + tr[pos].cnt;
                val_of_rank(rs(pos), remain);
            }
        } else {
            val_of_rank(ls(pos), remain);
        }
    }

    // Remove one copy of the value stored in live node pos (e.g. a result of find()).
    // Then refresh and rebalance every ancestor of the lowest modified node.
    void erase(int pos) {
        int t = Fa(pos);
        if (tr[pos].cnt == 1) {
            if (!del(pos)) {
                // Two children: move the in-order predecessor (max of the left
                // subtree) into pos, then unlink the predecessor's node, which
                // has no right child.
                int tmp = ls(pos);
                while (rs(tmp)) tmp = rs(tmp);
                t = Fa(tmp);
                tr[pos].val = tr[tmp].val;
                tr[pos].cnt = tr[tmp].cnt;
                del(tmp);
            }
        } else {
            tr[pos].cnt--;
            maintain1(pos);  // heights are unchanged, only sizes shrink
        }
        while (t) {
            maintain(t);
            t = Fa(t);
        }
    }
};
AVL_tree solve;

// Reads n, then n operations "x y" from stdin:
//   1 y: insert y
//   2 y: delete one copy of y ("miss" if absent)
//   3 y: print 1 + the number of stored values < y
//   4 y: print the y-th smallest value (0 if y is out of range)
//   5 y: print the largest value < y  ("not found" if none)
//   6 y: print the smallest value > y ("not found" if none)
int main() {
    int n = 0;
    if (!(cin >> n)) return 0;  // no input at all
    solve.init();
    for (int i = 1; i <= n; i++) {
        int x = 0, y = 0;
        if (!(cin >> x >> y)) break;  // truncated input: stop rather than reuse stale values
        if (x == 1) {
            if (solve.root == 0) {
                solve.root = solve.creat(y, 0);
            } else {
                solve.insert(solve.root, y);
            }
        } else if (x == 2) {
            int pos = solve.find(solve.root, y);
            if (pos == 0) {
                printf("miss\n");
            } else {
                solve.erase(pos);
            }
        } else if (x == 3) {
            solve.Rank = 0;
            solve.rank_of_val(solve.root, y);
            printf("%d\n", solve.Rank + 1);
        } else if (x == 4) {
            solve.ans = 0;
            solve.val_of_rank(solve.root, y);
            printf("%d\n", solve.ans);
        } else if (x == 5) {
            solve.pre_pos = -1;
            if (solve.root != 0) solve.pre(solve.root, y);  // empty tree: nothing to find
            // Only index tr[] when a node was found; -1 is not a valid slot.
            if (solve.pre_pos == -1) printf("not found\n");
            else printf("%d\n", solve.tr[solve.pre_pos].val);
        } else if (x == 6) {
            solve.Next_pos = -1;
            if (solve.root != 0) solve.Next(solve.root, y);
            if (solve.Next_pos == -1) printf("not found\n");
            else printf("%d\n", solve.tr[solve.Next_pos].val);
        }
    }
    return 0;
}

浙公网安备 33010602011771号